paddlets.models.classify.base

class BaseClassifier[源代码]

基类:ABC

所有分类模型的基类

abstract fit(train_tsdatasets: List[TSDataset], train_labels: ndarray, valid_tsdatasets: Optional[List[TSDataset]] = None, valid_labels: Optional[ndarray] = None)[源代码]

训练一个分类模型基类的实例

任何继承自此类的非抽象子类均需实现此方法。

参数
  • train_tsdataset (TSDataset) – 训练集。

  • train_labels – 训练数据的标签

  • valid_tsdataset (TSDataset|None) – 验证集,用于早停

  • valid_labels – 验证数据的标签

abstract predict(tsdatasets: List[TSDataset]) ndarray[源代码]

预测结果。以数组方式返回

参数

tsdataset (List[TSDataset]) – 被预测数据

返回

多维数组

abstract predict_proba(tsdatasets: List[TSDataset]) ndarray[源代码]

获取每条样本在每个类别上的概率。以多维数组方式返回

参数

tsdataset (List[TSDataset]) – 被预测数据

返回

多维数组

abstract save(path: str) None[源代码]

将一个分类模型基类实例保存在磁盘文件中。

任何继承自此类的非抽象子类均需实现此方法。

参数

path (str) – 一个包含模型文件名的字符串格式的路径。

abstract static load(path: str) BaseClassifier[源代码]

从给定的文件中加载 BaseClassifier 模型实例。

任何继承自此类的非抽象子类均需实现此方法。

参数

path (str) – 一个包含模型文件名的字符串格式的路径。

返回

加载完成的模型。

返回类型

BaseClassifier