第三方模型自动集成

PaddleTS 提供的 make_ml_model 接口允许用户基于第三方的 scikit-learnpyod 库分别构建时序预测和时序异常检测模型。基于此能力,用户仅需开发少量代码,即可对他们的时序建模想法进行可行性和效果的快速验证,明显提升了效率。

1. 基于第三方模型进行时序预测

1.1 最小化示例

下方示例展示了如何基于 sklearn.neighbors.KNeighborsRegressor 构建时序预测模型。

from paddlets.datasets.repository import get_dataset
from paddlets.models.ml_model_wrapper import make_ml_model

from sklearn.neighbors import KNeighborsRegressor

# prepare data
tsdataset = get_dataset("UNI_WTH")

# make model based on sklearn.neighbors.KNeighborsRegressor
model = make_ml_model(
    in_chunk_len=3,
    out_chunk_len=1,
    model_class=KNeighborsRegressor
)

# fit
model.fit(train_data=tsdataset)

# predict
predicted_ds = model.predict(tsdataset)
#             WetBulbCelsius
# 2014-01-01           -1.72

1.2 将MLDataLoader转换为可训练/预测的ndarray数据

一些第三方库(如 scikit-learn )的机器学习模型通常会接收numpy.ndarray类型数据作为 fitpredict 方法的输入。但是在PaddleTS中,会使用 paddlets.models.forecasting.ml.adapter.ml_dataloader.MLDataLoader 表示可用于训练/预测的时序数据。因此,make_ml_model 提供了 udf_ml_dataloader_to_fit_ndarrayudf_ml_dataloader_to_predict_ndarray 两个可选参数,用于支持用户将 MLDataLoader 转换为 numpy.ndarray 数据对象。

make_ml_model 默认会使用 default_sklearn_ml_dataloader_to_fit_ndarraydefault_sklearn_ml_dataloader_to_predict_ndarray 两个函数将 MLDataLoader 分别转换为 fitpredict 方法可接收的 numpy.ndarray 数据。同时,用户也可以开发自定义的数据转换函数,用于得到可用于训练/预测的数据。

from paddlets.datasets.repository import get_dataset
from paddlets.models.forecasting.ml.adapter.ml_dataloader import MLDataLoader
from paddlets.models.ml_model_wrapper import make_ml_model

from sklearn.neighbors import KNeighborsRegressor

# prepare data
tsdataset = get_dataset("UNI_WTH")

# develop user-defined convert functions
def udf_ml_dataloader_to_fit_ndarray(
    ml_dataloader: MLDataLoader,
    model_init_params: Dict[str, Any],
    in_chunk_len: int,
    skip_chunk_len: int,
    out_chunk_len: int
):
    # build and return converted numpy.ndarray object that sklearn model fit method accepts.
    pass

def udf_ml_dataloader_to_predict_ndarray(
    ml_dataloader: MLDataLoader,
    model_init_params: Dict[str, Any],
    in_chunk_len: int,
    skip_chunk_len: int,
    out_chunk_len: int
):
    # build and return converted numpy.ndarray object that sklearn model predict method accepts.
    pass

# pass the above 2 udf arguments to make_ml_model
model = make_ml_model(
    in_chunk_len=3,
    out_chunk_len=1,
    model_class=KNeighborsRegressor,
    udf_ml_dataloader_to_fit_ndarray=udf_ml_dataloader_to_fit_ndarray,
    udf_ml_dataloader_to_fit_ndarray=udf_ml_dataloader_to_predict_ndarray
)

# fit
model.fit(train_data=tsdataset)

# predict
predicted_ds = model.predict(tsdataset)

1.3 多时间点时序预测

通过第三方模型构建的时序模型也可以通过调用 recursive_predict 实现多时间点预测。

from paddlets.datasets.repository import get_dataset
from paddlets.models.forecasting.ml.ml_model_wrapper import make_ml_model

# prepare data
tsdataset = get_dataset("UNI_WTH")

# make model
model = make_ml_model(
    in_chunk_len=3,
    out_chunk_len=1,
    model_class=KNeighborsRegressor
)

# fit
model.fit(train_data=tsdataset)

# recursively predict
recursively_predicted_ds = model.recursive_predict(tsdataset=tsdataset, predict_length=4)
#                      WetBulbCelsius
# 2014-01-01 00:00:00           -1.72
# 2014-01-01 01:00:00           -1.88
# 2014-01-01 02:00:00           -2.18
# 2014-01-01 03:00:00           -2.44

2 基于第三方模型进行时序异常检测

2.1 最小化示例

下方示例展示了如何基于 pyod.models.knn.KNN 构建时序异常检测模型。由于它与构建时序预测模型的接口一致,因此你可以参考上文中 1.1 节内容了解如何通过定义一个udf函数来为构建的模型自定义输入的向量。

from paddlets.datasets.repository import get_dataset
from paddlets.models.ml_model_wrapper import make_ml_model

from pyod.models.knn import KNN

# prepare data
tsdataset = get_dataset("WTH")

# make model based on pyod.models.knn.KNN
model = make_ml_model(
    in_chunk_len=3,
    model_class=KNN
)

# fit
model.fit(train_data=tsdataset)

# predict
predicted_ds = model.predict(tsdataset)
#                      WetBulbCelsius
# date
# 2010-01-01 02:00:00               0
# 2010-01-01 03:00:00               0
# 2010-01-01 04:00:00               1
# 2010-01-01 05:00:00               0
# 2010-01-01 06:00:00               0
# ...                             ...
# 2013-12-31 19:00:00               1
# 2013-12-31 20:00:00               1
# 2013-12-31 21:00:00               1
# 2013-12-31 22:00:00               1
# 2013-12-31 23:00:00               1

# [35062 rows x 1 columns]