paddlets.models.representation.dl.cost

class CoST(segment_size: int, sampling_stride: int = 1, optimizer_fn: ~typing.Callable[[...], ~paddle.optimizer.optimizer.Optimizer] = <class 'paddle.optimizer.momentum.Momentum'>, optimizer_params: ~typing.Dict[str, ~typing.Any] = {'learning_rate': 0.001}, callbacks: ~typing.List[~paddlets.models.common.callbacks.callbacks.Callback] = [], batch_size: int = 128, max_epochs: int = 10, verbose: int = 1, seed: ~typing.Union[None, int] = None, repr_dims: int = 320, hidden_dims: int = 64, num_layers: int = 10, queue_size: int = 256, temperature: float = 0.07, alpha: float = 0.0005)[源代码]

基类:ReprBaseModel

CoST[1] 是2022年提出的一种时序表征模型(适用于长时序预测的新型时序表征框架), 它使用对比学习方法去学习解耦的季节-趋势表示(包含时域对比损失和频域对比损失, 分别学习可区分的趋势表征和季节表征)

[1] Woo G, et al. “CoST: Contrastive Learning of Disentangled Seasonal-Trend Representations for Time Series Forecasting”, https://arxiv.org/pdf/2202.01575.pdf

参数
  • segment_size (int) – 时序片段的长度.

  • sampling_stride (int) – 相邻样本间的采样间隔.

  • optimizer_fn (Callable[..., Optimizer]) – 优化算法.

  • optimizer_params (Dict[str, Any]) – 优化器参数.

  • callbacks (List[Callback]) – 自定义callback函数.

  • batch_size (int) – 训练数据的批大小.

  • max_epochs (int) – 训练的最大轮数.

  • verbose (int) – 模型训练过程中打印日志信息的间隔

  • seed (int|None) – 全局随机数种子, 注: 保证每次模型参数初始化一致.

  • repr_dims (int) – 表征向量的维度.

  • hidden_dims (int) – 空洞卷积网络的隐层通道数.

  • num_layers (int) – 空洞卷积网络的层数.

  • queue_size (int) – 用于保存负例样本的动态队列大小.

  • temperature (float) – 时域对比损失中的温度系数.

  • alpha (float) – 该参数用于调整损失函数中频域对比损失的占比.

_segment_size

时序片段的长度.

Type

int

_sampling_stride

相邻样本间的采样间隔.

Type

int

_optimizer_fn

优化算法.

Type

Callable[…, Optimizer]

_optimizer_params

优化器参数.

Type

Dict[str, Any]

_callbacks

自定义callback函数.

Type

List[Callback]

_batch_size

训练数据的批大小.

Type

int

_max_epochs

训练的最大轮数.

Type

int

_verbose

模型训练过程中打印日志信息的间隔

Type

int

_seed

全局随机数种子, 注: 保证每次模型参数初始化一致.

Type

int|None

_repr_dims

表征向量的维度.

Type

int

_hidden_dims

空洞卷积网络的隐层通道数.

Type

int

_num_layers

空洞卷积网络的层数.

Type

int

_queue_size

用于保存负例样本的动态队列大小.

Type

int

_temperature

时域对比损失中的温度系数.

Type

float

_alpha

该参数用于调整损失函数中频域对比损失的占比.

Type

float