Source code for paddlets.models.classify.base

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

import abc

from typing import List, Optional
import numpy as np

from paddlets.datasets import TSDataset


class BaseClassifier(abc.ABC):
    """
    Abstract base class for all classifiers.

    Concrete subclasses must implement :meth:`fit`, :meth:`predict`,
    :meth:`predict_proba`, :meth:`save` and :meth:`load`; until all of them
    are overridden the subclass cannot be instantiated.
    """

    def __init__(self):
        pass

    @abc.abstractmethod
    def fit(
        self,
        train_tsdatasets: List[TSDataset],
        train_labels: np.ndarray,
        valid_tsdatasets: Optional[List[TSDataset]] = None,
        valid_labels: Optional[np.ndarray] = None
    ):
        """
        Fit a BaseClassifier instance.

        Any non-abstract classes inherited from this class should implement this method.

        Args:
            train_tsdatasets(List[TSDataset]): Train set.
            train_labels(np.ndarray): The train data class labels.
            valid_tsdatasets(List[TSDataset]|None): Eval set, used for early stopping.
            valid_labels(np.ndarray|None): The valid data class labels.
        """

    @abc.abstractmethod
    def predict(
        self,
        tsdatasets: List[TSDataset]
    ) -> np.ndarray:
        """
        Predict labels. Results are output as ndarray.

        Args:
            tsdatasets(List[TSDataset]): Data to be predicted.

        Returns:
            np.ndarray.
        """

    @abc.abstractmethod
    def predict_proba(
        self,
        tsdatasets: List[TSDataset]
    ) -> np.ndarray:
        """
        Find probability estimates for each class for all cases.
        Results are output as ndarray.

        Args:
            tsdatasets(List[TSDataset]): Data to be predicted.

        Returns:
            np.ndarray.
        """

    @abc.abstractmethod
    def save(self, path: str) -> None:
        """
        Saves a BaseClassifier instance to a disk file.

        Any non-abstract classes inherited from this class should implement this method.

        Args:
            path(str): A path string containing a model file name.
        """

    # NOTE: @staticmethod must stay the outermost decorator so that the
    # abstract marker set by @abc.abstractmethod is visible to ABCMeta.
    @staticmethod
    @abc.abstractmethod
    def load(path: str) -> "BaseClassifier":
        """
        Loads a :class:`~paddlets.models.classify.base.BaseClassifier` instance from a file.

        Any non-abstract classes inherited from this class should implement this method.

        Args:
            path(str): A path string containing a model file name.

        Returns:
            BaseClassifier: A loaded model.
        """