paddlets.models.forecasting.dl.tft

This implementation is based on the article Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting.

class TFTModel(in_chunk_len: int, out_chunk_len: int, hidden_size: int = 64, lstm_layers_num: int = 1, attention_heads_num: int = 1, output_quantiles: ~typing.List[float] = [0.1, 0.5, 0.9], dropout: float = 0.0, skip_chunk_len: int = 0, sampling_stride: int = 1, loss_fn: ~typing.Callable[[...], ~paddle.Tensor] = <bound method QuantileRegression.loss of <paddlets.models.forecasting.dl.distributions.likelihood.QuantileRegression object>>, optimizer_fn: ~typing.Callable[[...], ~paddle.optimizer.optimizer.Optimizer] = <class 'paddle.optimizer.adam.Adam'>, optimizer_params: ~typing.Dict[str, ~typing.Any] = {'learning_rate': 0.0001}, callbacks: ~typing.List[~paddlets.models.common.callbacks.callbacks.Callback] = [], batch_size: int = 128, max_epochs: int = 10, verbose: int = 1, patience: int = 4, seed: int = 0)[源代码]

基类:PaddleBaseModelImpl

TFT模型实现。

参数
  • in_chunk_len (int) – 模型输入的时间序列长度。

  • out_chunk_len (int) – 模型输出的序列长度。

  • hidden_size (int, Optional) – TFT模型隐藏状态h大小。

  • lstm_layers_num (int, Optional) – LSTM网络的层数。

  • attention_heads_num (int, Optional) – 多头注意力模块的数量。

  • output_quantiles (List[float], Optional) – 模型输出的分位数。

  • dropout (float, Optional) – 除了最后一层RNN,神经元随机丢弃的比例。

  • skip_chunk_len (int, Optional) – 可选变量, 输入序列与输出序列之间跳过的序列长度, 既不作为特征也不作为序测目标使用, 默认值为0。

  • sampling_stride (int, optional) – 相邻样本间的采样间隔。

  • loss_fn (Callable, Optional) – 损失函数。

  • optimizer_fn (Callable, Optional) – 优化算法

  • optimizer_params (Dict, Optional) – 优化器参数。

  • callbacks (List[Callback], Optional) – 自定义callback函数。

  • batch_size (int, Optional) – 训练数据或评估数据的批大小。

  • max_epochs (int, Optional) – 训练的最大轮数。

  • verbose (int, Optional) – 模型训练过程中打印日志信息的间隔。

  • patience (int, Optional) – 模型训练过程中, 当评估指标超过一定轮数不再变优,模型提前停止训练。

  • seed (int, Optional) – 全局随机数种子, 注: 保证每次模型参数初始化一致。

predict_interpretable(tsdataset: TSDataset) Dict[str, ndarray][源代码]

输出可解释性结果。

参数

tsdataset (TSDataset) – 需要预测的数据。

返回

可解释性结果

返回类型

results(Dict[str, np.ndarray])