Source code for paddlets.utils.utils

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

from numbers import Integral
import uuid
import hashlib

from inspect import isclass
import pandas as pd

from paddlets.models.base import Trainable
from paddlets.logger import raise_if_not, raise_if, raise_log
from paddlets.datasets.tsdataset import TSDataset

try:
    from paddlets.automl import AutoTS
except Exception as e:
    AutoTS = None

[docs]def check_model_fitted(model: Trainable, msg: str = None): """ check if model has fitted, Raise Exception if not fitted Args: model(Trainable): model instance. msg(str): str, default=None The default error message is, "This %(name)s instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator." For custom messages if "%(name)s" is present in the message string, it is substituted for the estimator name. Eg. : "Estimator, %(name)s, must be fitted before sparsifying". Returns: None Raise: ValueError """ from paddlets.pipeline import Pipeline from paddlets.models.forecasting.ml.ml_base import MLBaseModel from paddlets.models.forecasting.dl.paddle_base import PaddleBaseModel #不需要fit的模型列表 MODEL_NEED_NO_FIT = ["ArimaModel"] if model.__class__.__name__ in MODEL_NEED_NO_FIT: return if isclass(model): raise_log(ValueError(f"{type(model).__name__}is a class, not an instance.")) if msg is None: msg = ( "This %(name)s instance is not fitted yet. Call 'fit' with " "appropriate arguments before using this estimator." ) if not isinstance(model, Trainable): raise_log(ValueError(f"{type(model).__name__} is not a Trainable Object.")) fitted = False # PipeLine if isinstance(model, Pipeline): fitted = model._fitted # Paddle 模型 if isinstance(model, PaddleBaseModel): fitted = True if model._network else False # ML 模型 if isinstance(model, MLBaseModel): #TODO:后续如果将 self._models 提到 MLBaseModel后,这里需要同步修改为判断 self._models ,而不是 "_models" 字符串。 fitted = True if "model" in vars(model) or "_model" in vars(model) else False if AutoTS is not None and isinstance(model, AutoTS): fitted = model.is_refitted() raise_if_not(fitted, msg % {"name": type(model).__name__})
[docs]def get_uuid(prefix: str = "", suffix: str = ""): """ Get a random string of 16 characters. Args: prefix(str, optional): The prefix of the returned string. suffix(str, optional): The suffix of the returned string. Returns: str: String of 16 characters. """ digits = "01234abcdefghijklmnopqrstuvwxyz56789" new_uuid = uuid.uuid1() md = hashlib.md5() md.update(str(new_uuid).encode()) for i in md.digest(): x = (i + 128) % 34 prefix = prefix + digits[x] res = prefix + suffix if suffix is not None else prefix return res
[docs]def check_train_valid_continuity(train_data: TSDataset, valid_data: TSDataset)-> bool: """ Check if train and test TSDataset are continous Args: train_data(TSDataset): Train dataset. test_data(TSDataset): Test dataset. Return: bool: if train and test TSDataset are continous """ train_index = train_data.target.data.index valid_index = valid_data.target.data.index continuious = False if isinstance(train_index, pd.DatetimeIndex): if isinstance(valid_index, pd.DatetimeIndex): continuious = (valid_index[0] - train_index[-1] == pd.to_timedelta(train_index.freq)) elif isinstance(train_index, pd.RangeIndex): if isinstance(valid_index, pd.RangeIndex): continuious = (valid_index[0] - train_index[-1] == train_index.step) else: raise_log("Unsupport data index format") return continuious
[docs]def split_dataset(dataset: TSDataset, split_point: int) -> TSDataset: """ Split dataset (accroding to the max length) Args: dataset(TSDataset): dataset to be splited. split_point(int): split point. Return: TSDataset """ target_index = None observed_index = None known_index = None index_list = [] if dataset.target: target_index = dataset.target.data.index index_list.append(target_index) if dataset.known_cov: known_index = dataset.known_cov.data.index index_list.append(known_index) if dataset.observed_cov: observed_index = dataset.observed_cov.data.index index_list.append(observed_index) #sort to avoid wrong positions index index_list.sort(key=lambda x: x[0]) all_index = pd.concat([x.to_series() for x in index_list]).index.drop_duplicates() max_len = len(all_index) split_index = all_index[split_point-1] raise_if(split_point >= max_len, "split point should smaller than dataset length") raise_if(split_point <= 0, "split point should > 0") raise_if_not(isinstance(split_point, Integral), f"split point should be Integral type, instead of {type(split_point)}") target_pre = None target_after = None if dataset.target: if split_index < target_index[0]: target_after = dataset.target elif split_index >= target_index[-1]: target_pre = dataset.target elif split_index in target_index: if isinstance(dataset.target.data.index, pd.RangeIndex): target_pre, target_after = dataset.target.split(int((split_index - dataset.target.data.index[0]) / dataset.target.data.index.step +1)) else: target_pre, target_after = dataset.target.split(split_index) known_pre = None known_after = None if dataset.known_cov: if split_index < known_index[0]: known_after = dataset.known_cov elif split_index >= known_index[-1]: known_pre = dataset.known_cov elif split_index in known_index: if isinstance(dataset.known_cov.data.index, pd.RangeIndex): known_pre, known_after = dataset.known_cov.split(int((split_index - dataset.known_cov.data.index[0]) / dataset.known_cov.data.index.step + 1)) else: known_pre, known_after = dataset.known_cov.split(split_index) observed_pre = None observed_after = None if dataset.observed_cov: if split_index < observed_index[0]: observed_after = dataset.observed_cov elif split_index >= observed_index[-1]: observed_pre = dataset.observed_cov elif split_index in observed_index: if isinstance(dataset.observed_cov.data.index, pd.RangeIndex): observed_pre, observed_after = dataset.observed_cov.split(int(((split_index - dataset.observed_cov.data.index[0])) / dataset.observed_cov.data.index.step + 1)) else: observed_pre, observed_after = dataset.observed_cov.split(split_index) return (TSDataset(target_pre, observed_pre, known_pre, dataset.static_cov), TSDataset(target_after, observed_after, known_after, dataset.static_cov))
[docs]def get_tsdataset_max_len(dataset:TSDataset) -> int: """ Get dataset max length Args: dataset(TSDataset): dataset use to get length. Return: int """ target_index = None observed_index = None known_index = None index_list = [] if dataset.target: target_index = dataset.target.data.index index_list.append(target_index) if dataset.known_cov: known_index = dataset.known_cov.data.index index_list.append(known_index) if dataset.observed_cov: observed_index = dataset.observed_cov.data.index index_list.append(observed_index) #sort to avoid wrong positions index index_list.sort(key=lambda x: x[0]) all_index = pd.concat([x.to_series() for x in index_list]).index.drop_duplicates() return len(all_index)