# Source code for paddlets.metrics.utils

#!/usr/bin/env python3
# -*- coding: UTF-8 -*-

from typing import Callable
import functools

import numpy as np

from paddlets.metrics.base import Metric


def _flatten_to_2d(arr) -> np.ndarray:
    """Reshape ``arr`` to ``(arr.shape[0], prod(arr.shape[1:]))``.

    Args:
        arr: Array-like whose first axis is kept as-is.

    Returns:
        np.ndarray: 2-D view/copy of ``arr``; ndarray subclasses are preserved.

    Raises:
        IndexError: If ``arr`` is 0-dimensional (it has no first axis).
    """
    arr = np.asanyarray(arr)
    # Compute the trailing size explicitly instead of passing -1: numpy cannot
    # infer -1 when the array is empty (e.g. shape (0, 3)), and would raise.
    # For a 1-D input shape[1:] is empty and np.prod returns 1.0, hence int().
    trailing = int(np.prod(arr.shape[1:]))
    return np.reshape(arr, (arr.shape[0], trailing))


def ensure_2d(func) -> Callable[..., float]:
    """A decorator ensuring the wrapped metric receives 2-dimensional arrays.

    Both ``y_true`` and ``y_score`` are reshaped to ``(n, -1)``, where ``n`` is
    each array's own first-axis length, so that they fit sklearn.metrics.
    A 1-D input of shape ``(n,)`` becomes ``(n, 1)``.

    Args:
        func(Callable[..., float]): Core function, called as
            ``func(obj, y_true, y_score, **kwargs)``.

    Returns:
        Callable[..., float]: Wrapped core function.
    """

    @functools.wraps(func)
    def wrapper(
        obj: "Metric", y_true: np.ndarray, y_score: np.ndarray, **kwargs
    ) -> float:
        """Core processing logic.

        Args:
            obj(Metric): Metric instance.
            y_true(np.ndarray): Ground truth (correct) labels.
            y_score(np.ndarray): Estimated target values.

        Returns:
            float: metric.

        Raises:
            IndexError: If ``y_true`` or ``y_score`` is 0-dimensional.
        """
        # The quoted "Metric" annotation is intentional: functools.wraps replaces
        # wrapper.__annotations__ with func's anyway, so evaluating the class
        # here on every decoration buys nothing.
        y_true = _flatten_to_2d(y_true)
        y_score = _flatten_to_2d(y_score)
        return func(obj, y_true, y_score, **kwargs)

    return wrapper