paddlets.models.representation.task.repr_classifier

class ReprClassifier(repr_model: ReprBaseModel, repr_model_params: Optional[dict] = None, encode_params: Optional[dict] = None, downstream_learner: Optional[Callable] = None, verbose: bool = False)[源代码]

基类:StackingEnsembleBase

表征分类

参数
  • repr_model (ReprBasemodel) – 所用的表征模型类

  • repr_model_params (dict) – 表征模型的参数

  • encode_params (dict) – 表征模型的encode参数

  • downstream_learner (Callable) – 下游分类器,需要是一个sklearn形式的分类器,默认GradientBoostingClassifier()

  • verbose (bool) – 是否开启日志,默认开启

fit(train_tsdatasets: List[TSDataset], train_labels: ndarray) None[源代码]
参数
  • train_tsdatasets (TSDataset) – 训练集,需要是一个List[TSDataset]

  • train_labels – 标签,长度跟训练集长度相同

predict(tsdatasets: List[TSDataset]) ndarray[源代码]

预测

参数

tsdataset_list (TSDataset) – 预测数据

predict_proba(tsdatasets: List[TSDataset]) ndarray[源代码]

预测概率

参数

tsdataset_list (TSDataset) – 预测数据

save(path: str, repr_classifier_file_name: str = 'repr-classifier-partial.pkl') None[源代码]

保存模型

参数
  • path (str) – 保存路径

  • ensemble_file_name (str) – 保存文件名

static load(path: str, repr_classifier_file_name: str = 'repr-classifier-partial.pkl') ReprClassifier[源代码]

加载模型

参数
  • path (str) – 加载路径

  • ensemble_file_name (str) – 保存文件名

返回

加载的模型