Source code for paddlets.metrics.base

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

from typing import Any, List, Tuple, Dict
from abc import ABC, abstractmethod

import numpy as np
import pandas as pd
from paddlets import TimeSeries, TSDataset
from paddlets.logger import Logger, raise_if_not, raise_if, raise_log

# Module-scoped logger, named after this module via the project's Logger wrapper.
logger = Logger(__name__)


class Metric(ABC):
    """Abstract base class used to build new Metric.

    Subclasses implement :meth:`metric_fn` and are expected to define the class
    attribute ``_NAME`` (looked up by :meth:`get_metrics_by_names`). Subclasses used
    in ``prob`` mode must also define ``_TYPE``, which is forwarded to
    :meth:`_build_prob_metrics_data` as ``data_type``.

    Args:
        mode(str): Supported metric modes, only normal and prob are valid values. Set to normal for
            non-probability use cases, set to prob for probability use cases.
        kwargs: Keyword parameters of specific metric functions.
    """

    def __init__(self, mode: str = "normal", **kwargs):
        self._kwargs = kwargs
        raise_if_not(
            mode in {"normal", "prob"},
            f"Metric mode should be one of {{`normal`, `prob`}}, got `{mode}`."
        )
        self._mode = mode

    @staticmethod
    def _get_sorted_targets(
        tsdataset_true: "TSDataset",
        tsdataset_pred: "TSDataset",
    ) -> Tuple["TimeSeries", "TimeSeries"]:
        """Fetch the target TimeSeries of both datasets, with columns sorted.

        The None check must run before ``sort_columns`` is called. Otherwise a
        missing target surfaces as an AttributeError instead of the ValueError
        documented by the callers.

        Args:
            tsdataset_true(TSDataset): TSDataset containing ground truth target values.
            tsdataset_pred(TSDataset): TSDataset containing estimated target values.

        Returns:
            Tuple[TimeSeries, TimeSeries]: (true target, pred target), as returned by ``sort_columns()``.

        Raises:
            ValueError.
        """
        target_true = tsdataset_true.get_target()
        target_pred = tsdataset_pred.get_target()
        raise_if(
            target_true is None or target_pred is None,
            "TSDataset target is None!"
        )
        return target_true.sort_columns(), target_pred.sort_columns()

    @staticmethod
    def _align_pred_to_true(target_true: "TimeSeries", target_pred: "TimeSeries") -> "TimeSeries":
        """Reindex the prediction onto the ground truth's time index.

        A new TimeSeries is built from the reindexed prediction data, so its rows
        line up positionally with ``target_true``.

        Args:
            target_true(TimeSeries): Ground truth target.
            target_pred(TimeSeries): Predicted target.

        Returns:
            TimeSeries: Prediction reindexed to ``target_true.time_index`` with ``target_true.freq``.

        Raises:
            ValueError: If any prediction column is entirely NaN after reindexing,
                typically because the two time indexes do not overlap.
        """
        aligned = TimeSeries(
            target_pred.data.reindex(target_true.time_index),
            target_true.freq
        )
        for column in aligned.columns:
            raise_if(
                aligned.data[column].isna().all(),
                "tsdataset true's and pred's time_index do not match!"
            )
        return aligned

    def _build_metrics_data(
        self,
        tsdataset_true: "TSDataset",
        tsdataset_pred: "TSDataset",
    ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """Convert TSDataset of normal mode to ndarray.

        Args:
            tsdataset_true(TSDataset): TSDataset containing Ground truth (correct) target values.
            tsdataset_pred(TSDataset): TSDataset containing Estimated target values.

        Returns:
            Dict[str, Tuple[np.ndarray, np.ndarray]]: Dict of tuple, key is the name of target,
                and value is tuple type (y_true, y_score).

        Raises:
            ValueError.
        """
        target_true, target_pred = self._get_sorted_targets(tsdataset_true, tsdataset_pred)
        # Normal mode expects exactly one prediction column per target column.
        raise_if_not(
            len(target_true.columns) == len(target_pred.columns),
            "In `normal` mode, only point forecasting data is supported!"
        )
        raise_if_not(
            (target_true.columns == target_pred.columns).all(),
            "tsdataset true's and pred's columns are not the same!"
        )
        target_pred = self._align_pred_to_true(target_true, target_pred)
        return {
            target: (target_true.data[target].to_numpy(), target_pred.data[target].to_numpy())
            for target in target_true.columns
        }

    def _build_prob_metrics_data(
        self,
        tsdataset_true: "TSDataset",
        tsdataset_pred: "TSDataset",
        data_type: str,
    ) -> Dict[str, Tuple[np.ndarray, np.ndarray]]:
        """Convert TSDataset of prob mode to ndarray.

        Prediction columns are matched to a target by the part of their name before
        the last ``@`` (e.g. ``"<target>@<suffix>"``; the suffix is presumably a
        quantile level, but that is not checked here).

        Args:
            tsdataset_true(TSDataset): TSDataset containing ground truth (correct) target values.
            tsdataset_pred(TSDataset): TSDataset containing estimated target values.
            data_type(str): ``"quantile"`` keeps every matching prediction column.
                Any other value reduces them to their row-wise median (point estimate).

        Returns:
            Dict[str, Tuple[np.ndarray, np.ndarray]]: Dict of tuple, key is the name of target.
                The value is (y_true, 2-D array of the matching prediction columns) for
                ``"quantile"``, otherwise (y_true, row-wise median of those columns).

        Raises:
            ValueError.
        """
        target_true, target_pred = self._get_sorted_targets(tsdataset_true, tsdataset_pred)

        # Every ground-truth target must appear as a prefix of some prediction column, and vice versa.
        target_set = set(target_true.columns)
        pred_target_set = {col.rsplit("@", 1)[0] for col in target_pred.columns}
        raise_if_not(target_set == pred_target_set, "Prediction is not coherent with ground truth.")

        # Align rows the same way normal mode does.
        target_pred = self._align_pred_to_true(target_true, target_pred)

        true_df = target_true.to_dataframe()
        pred_df = target_pred.to_dataframe()
        # Split each prediction column name once instead of once per target.
        pred_columns = list(pred_df.columns)
        pred_bases = [col.rsplit("@", 1)[0] for col in pred_columns]

        res = {}
        for target_name in true_df.columns:
            cur_cols = [col for col, base in zip(pred_columns, pred_bases) if base == target_name]
            pred_values = pred_df[cur_cols].to_numpy()
            y_true = true_df[target_name].to_numpy()
            if data_type == "quantile":
                res[target_name] = (y_true, pred_values)
            else:  # data_type: "point"
                res[target_name] = (y_true, np.median(pred_values, axis=-1))
        return res

    @abstractmethod
    def metric_fn(
        self,
        y_true: np.ndarray,
        y_pred: np.ndarray,
    ) -> float:
        """Compute metric's value from ndarray.

        Args:
            y_true(np.ndarray): Ground truth (correct) target values.
            y_pred(np.ndarray): Estimated target values.

        Returns:
            float: Computed metric value.

        Raises:
            ValueError.
        """
        pass

    def __call__(
        self,
        tsdataset_true: "TSDataset",
        tsdataset_pred: "TSDataset",
    ) -> Dict[str, float]:
        """Compute metric's value from TSDataset.

        Args:
            tsdataset_true(TSDataset): TSDataset containing ground truth (correct) target values.
            tsdataset_pred(TSDataset): TSDataset containing estimated target values.

        Returns:
            Dict[str, float]: Dict of metrics. key is the name of target, and value is specific metric value.

        Raises:
            ValueError.
        """
        if self._mode == "normal":
            res_array = self._build_metrics_data(tsdataset_true, tsdataset_pred)
        else:  # "prob": subclasses supply _TYPE to choose quantile vs. point data
            res_array = self._build_prob_metrics_data(tsdataset_true, tsdataset_pred, self._TYPE)
        return {
            target: self.metric_fn(y_true, y_pred)
            for target, (y_true, y_pred) in res_array.items()
        }

    @classmethod
    def get_metrics_by_names(cls, names: List[str]) -> List["Metric"]:
        """Instantiate metrics by their registered ``_NAME``.

        Only direct subclasses of ``cls`` are searched, because
        ``__subclasses__`` is not recursive. Each matched subclass is
        instantiated with its default constructor arguments. If two subclasses
        share a name, the first one listed wins.

        Args:
            names(List[str]): List of metric names.

        Returns:
            List[Metric]: List of metric instances, in the same order as ``names``.

        Raises:
            ValueError: If a name does not match any available metric.
        """
        available_metrics = cls.__subclasses__()
        available_names = [metric._NAME for metric in available_metrics]
        metrics = []
        for name in names:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise_if_not(
                name in available_names,
                f"{name} is not available, choose in {available_names}"
            )
            metric_cls = available_metrics[available_names.index(name)]
            metrics.append(metric_cls())
        return metrics