paddlets.models.classify.dl.paddle_base
- class PaddleBaseClassifier(loss_fn: Optional[Callable[..., paddle.Tensor]] = None, optimizer_fn: Callable[..., paddle.optimizer.Optimizer] = paddle.optimizer.Adam, optimizer_params: Dict[str, Any] = {'learning_rate': 0.001}, eval_metrics: List[str] = [], callbacks: List[Callback] = [], batch_size: int = 32, max_epochs: int = 10, verbose: int = 1, patience: int = 4, seed: Optional[int] = None)
Bases: BaseClassifier
Base class for all Paddle-based deep time series classification models.
- Parameters
loss_fn (Callable[..., paddle.Tensor]|None) – Loss function.
optimizer_fn (Callable[..., Optimizer]) – Optimizer algorithm.
optimizer_params (Dict[str, Any]) – Optimizer parameters.
eval_metrics (List[str]) – Names of the metrics used to evaluate the model.
callbacks (List[Callback]) – Customized callback functions.
batch_size (int) – Number of samples per batch.
max_epochs (int) – Maximum number of training epochs.
verbose (int) – Verbosity mode.
patience (int) – Number of epochs to wait without improvement before stopping training early.
seed (int|None) – Global random seed.
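The patience parameter follows the usual early-stopping idea. The sketch below is a minimal illustration, assuming a lower validation loss counts as an improvement; which metric PaddleTS actually monitors, and its exact comparison rule, are not specified here. early_stopping_epoch is a hypothetical helper, not part of PaddleTS.

```python
from typing import List, Optional


def early_stopping_epoch(val_losses: List[float], patience: int) -> Optional[int]:
    """Return the 0-based epoch at which training would stop, or None.

    Hypothetical helper: assumes lower loss is better and that training stops
    once `patience` consecutive epochs pass without a new best loss.
    """
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # New best value: reset the counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None  # never triggered; training ran for all given epochs


# With patience=4 (the default), four epochs without a new best stop training.
losses = [0.9, 0.7, 0.6, 0.65, 0.62, 0.61, 0.63, 0.5]
print(early_stopping_epoch(losses, patience=4))  # 6
```

Note that the improvement at the last epoch (0.5) is never reached: training has already stopped at epoch 6.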
- _loss_fn
Loss function.
- Type
Callable[…, paddle.Tensor]|None
- _optimizer_fn
Optimizer algorithm.
- Type
Callable[…, Optimizer]
- _optimizer_params
Optimizer parameters.
- Type
Dict[str, Any]
- _eval_metrics
Names of the metrics used to evaluate the model.
- Type
List[str]
- _batch_size
Number of samples per batch.
- Type
int
- _max_epochs
Maximum number of training epochs.
- Type
int
- _verbose
Verbosity mode.
- Type
int
- _patience
Number of epochs to wait without improvement before stopping training early.
- Type
int
- _seed
Global random seed.
- Type
int|None
- _classes_
ndarray of class labels, possibly strings.
- Type
ndarray
- _n_class
Number of unique labels.
- Type
int
- _stop_training
Whether training should be stopped.
- Type
bool
- _fit_params
Parameters inferred automatically from the TSDataset.
- Type
Dict[str, Any]
- _network
Network structure.
- Type
paddle.nn.Layer
- _optimizer
Optimizer.
- Type
Optimizer
- _metrics_names
List of metric names.
- Type
List[str]
- _metric_container_dict
Dict of metric containers.
- Type
Dict[str, MetricContainer]
- _callback_container
Container holding a list of callbacks.
- Type
- check_tsdataset(tsdataset: TSDataset)
Ensures the robustness of the input data (consistent feature order) and checks whether the data types are compatible. The data are handled as follows:
1. Floating-point values: converted to np.float32.
2. Missing values: a warning is issued.
3. Other types: illegal.
- Parameters
tsdataset (TSDataset) – Data to be checked.
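The three dtype rules above can be sketched on plain NumPy columns. This is a hypothetical stand-in (check_columns is not a PaddleTS function) that works on a dict of arrays instead of a TSDataset, applies the rules literally (so integer columns count as "other" and are rejected, which the real method may handle differently), and leaves out the feature-order check.

```python
import warnings
from typing import Dict

import numpy as np


def check_columns(columns: Dict[str, np.ndarray]) -> Dict[str, np.ndarray]:
    """Hypothetical stand-in for check_tsdataset operating on NumPy columns.

    Applies the documented rules literally: floating-point columns are cast
    to np.float32, missing values (NaN) only trigger a warning, and any other
    dtype is rejected. The feature-order check is not modelled.
    """
    checked = {}
    for name, values in columns.items():
        if not np.issubdtype(values.dtype, np.floating):
            # Rule 3: anything that is not floating point is treated as illegal.
            raise ValueError(f"column {name!r} has unsupported dtype {values.dtype}")
        values = values.astype(np.float32)  # Rule 1: convert to float32
        if np.isnan(values).any():  # Rule 2: warn about missing values
            warnings.warn(f"column {name!r} contains missing values")
        checked[name] = values
    return checked


data = {"x": np.array([1.0, np.nan, 3.0]), "y": np.array([0.5, 0.25, 1.0])}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = check_columns(data)
print(result["x"].dtype, len(caught))  # float32 1
```

Raising an exception for illegal types, rather than silently dropping the column, is a choice made for this sketch.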
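- fit(train_tsdatasets: List[TSDataset], train_labels: ndarray, valid_tsdatasets: Optional[List[TSDataset]] = None, valid_labels: Optional[ndarray] = None)
Trains the neural network stored in self._network on train_tsdatasets and train_labels, using valid_tsdatasets and valid_labels (if provided) for validation.
- Parameters
train_tsdatasets (List[TSDataset]) – Training data.
train_labels (np.ndarray) – Class labels of the training data.
valid_tsdatasets (List[TSDataset]|None) – Validation data.
valid_labels (np.ndarray|None) – Class labels of the validation data.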
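To show roughly how batch_size and max_epochs limit the work that fit performs, the sketch below splits samples into mini-batches the way a typical data loader does. It assumes one optimizer step per batch and that the final partial batch is kept. PaddleTS's internal data loading may differ on both points. iter_batches is a hypothetical helper.

```python
from typing import Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def iter_batches(samples: Sequence[T], batch_size: int) -> Iterator[List[T]]:
    """Yield consecutive batches of at most batch_size samples.

    Hypothetical helper: assumes the final, smaller batch is kept rather
    than dropped.
    """
    for start in range(0, len(samples), batch_size):
        yield list(samples[start:start + batch_size])


samples = list(range(70))  # hypothetical 70 training samples
batches = list(iter_batches(samples, batch_size=32))  # default batch_size
print([len(b) for b in batches])  # [32, 32, 6]

max_epochs = 10  # default max_epochs
# Upper bound on optimizer steps; early stopping (patience) can end sooner.
print(len(batches) * max_epochs)  # 30
```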
- predict(tsdatasets: List[TSDataset]) → ndarray
Predicts class labels and returns them as an ndarray.
- Parameters
tsdatasets (List[TSDataset]) – Data to be predicted.
- Returns
np.ndarray – Predicted class labels.
- predict_proba(tsdatasets: List[TSDataset]) → ndarray
Estimates the probability of each class for every sample.
- Parameters
tsdatasets (List[TSDataset]) – Data to be predicted.
- Returns
np.ndarray – Class probability estimates.
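One common way that _classes_, predict_proba, and predict relate is sketched below. It assumes labels are encoded with np.unique, so that _classes_ holds the sorted distinct (possibly string) labels. It also assumes that predict_proba returns one row per sample and one column per class in that order, and that predict chooses the most probable class. None of this is confirmed to be PaddleTS's exact code. The probability values are made up for illustration.

```python
import numpy as np

# Hypothetical training labels; _classes_ may hold strings like these.
train_labels = np.array(["walk", "run", "walk", "sit", "run"])

# Sorted distinct labels play the role of _classes_; their count is _n_class.
classes, encoded = np.unique(train_labels, return_inverse=True)
n_class = len(classes)

# Hypothetical predict_proba output: one row per sample, one column per class
# (columns follow the order of `classes`), each row summing to 1.
proba = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.3, 0.6],
    [0.2, 0.5, 0.3],
])
assert proba.shape[1] == n_class

# Most probable column per row, mapped back to the original label values.
predicted = classes[np.argmax(proba, axis=1)]
print(predicted)  # ['run' 'walk' 'sit']
```

Integer encoding lets the network train on class indices, and the `classes` lookup restores the original string labels.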
- score(tsdatasets: List[TSDataset], labels: ndarray) → float
Scores the labels predicted for tsdatasets against the ground-truth labels.
- Parameters
tsdatasets (List[TSDataset]) – Data to be predicted.
labels (np.ndarray) – Ground-truth class labels.
- Returns
float – Accuracy of predict(tsdatasets) against labels.
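The accuracy returned by score is the fraction of samples whose predicted label matches the ground truth. The sketch below shows only that comparison, applied to already-predicted labels. accuracy is a hypothetical helper, not the PaddleTS implementation.

```python
import numpy as np


def accuracy(predicted: np.ndarray, labels: np.ndarray) -> float:
    """Fraction of samples whose predicted label equals the ground truth."""
    return float(np.mean(predicted == labels))


predicted = np.array(["run", "walk", "sit", "walk"])
labels = np.array(["run", "walk", "walk", "walk"])
print(accuracy(predicted, labels))  # 0.75 (3 of 4 correct)
```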
- save(path: str) → None
Saves a PaddleBaseClassifier instance to a disk file.
- Parameters
path (str) – A path string containing a model file name.
- Raises
ValueError –
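A save/load pair is a serialize-then-restore round trip. The sketch below illustrates that pattern with the standard-library pickle module on a hypothetical ToyClassifier. It does not reproduce PaddleTS's on-disk format; a real trained model presumably also has to persist its paddle network weights. It also does not model the condition under which save raises ValueError.

```python
import os
import pickle
import tempfile


class ToyClassifier:
    """Hypothetical stand-in for a classifier whose state is plain attributes."""

    def __init__(self, batch_size: int = 32, max_epochs: int = 10):
        self._batch_size = batch_size
        self._max_epochs = max_epochs

    def save(self, path: str) -> None:
        # Write the instance attributes (not the class itself) to one file.
        with open(path, "wb") as f:
            pickle.dump(vars(self), f)

    @staticmethod
    def load(path: str) -> "ToyClassifier":
        # Rebuild an instance without calling __init__, then restore its state.
        with open(path, "rb") as f:
            state = pickle.load(f)
        obj = ToyClassifier.__new__(ToyClassifier)
        obj.__dict__.update(state)
        return obj


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.pkl")
    ToyClassifier(batch_size=64).save(path)
    restored = ToyClassifier.load(path)
print(restored._batch_size, restored._max_epochs)  # 64 10
```

Pickling the attribute dict instead of the object keeps the file independent of where the class is defined.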
- static load(path: str) → PaddleBaseClassifier
Loads a PaddleBaseClassifier from a file.
- Parameters
path (str) – A path string containing a model file name.
- Returns
The loaded PaddleBaseClassifier instance.
- Return type