#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
from typing import Dict, List, Union, Callable, Optional, Any
from IPython.display import display
from tqdm import tqdm
import math
import numpy as np
import pandas as pd
from matplotlib import rcParams
import matplotlib.pyplot as plt
from matplotlib.ticker import FuncFormatter
from paddlets.utils.utils import check_model_fitted
from paddlets.models.forecasting import TFTModel
from paddlets.logger import Logger, raise_if
from paddlets.datasets import TSDataset, TimeSeries
logger = Logger(__name__)
# Matplotlib defaults shared by every figure drawn in this module:
# automatic layout, a 10x5 inch canvas and a 10pt base font.
rcParams['figure.autolayout'] = True
rcParams['figure.figsize'] = [10, 5]
rcParams['font.size'] = 10
class TFTExplainer(TFTModel):
    """
    Inherit TFT, and implement an explainer, which provides display of the explanation result.

    `explain_backtest` and `explain_prediction` collect the interpretable outputs returned by
    `predict_interpretable` (static / historical / future selection weights and attention
    scores) and can optionally render them as styled tables and Matplotlib figures.
    """

    def __init__(self, *args, **kwargs):
        """
        Init base TFT; every argument is forwarded unchanged to `TFTModel.__init__`.
        """
        super().__init__(*args, **kwargs)

    @staticmethod
    def _to_list(value: Any) -> List[Any]:
        """
        Normalize a "single value or sequence" argument into a list.

        Args:
            value(Any): A single value, or a list/tuple of values.

        Returns:
            List[Any]: A new list with the elements of `value` if it is a list/tuple,
            otherwise `[value]`.
        """
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    @staticmethod
    def _relative_time_axis(scores: np.ndarray) -> np.ndarray:
        """
        Build the x-axis used when plotting attention scores.

        Args:
            scores(np.ndarray): Array whose first two dimensions are (horizon, attended time-step).

        Returns:
            np.ndarray: `np.arange(shape[0] - shape[1], shape[0])`. Assuming the attended axis
            covers the lookback window followed by the forecast window (not verified here),
            non-negative values correspond to forecast steps.
        """
        shape = scores.shape
        return np.arange(shape[0] - shape[1], shape[0])

    def _update_fit_params(
        self,
        train_tsdataset: List[TSDataset],
        valid_tsdataset: Optional[List[TSDataset]] = None
    ) -> Dict[str, Any]:
        """
        Rewrite base `_update_fit_params` function to add more data members for explanation function.

        On top of the base behaviour this stores the feature names of each input channel
        (`static_cols`, `history_cols`, `future_cols`) and `mapping`, which links every
        display title to its explanation-output key and feature names.

        Args:
            train_tsdataset(List[TSDataset]): list of train dataset
            valid_tsdataset(List[TSDataset], optional): list of validation dataset

        Returns:
            Dict[str, Any]: model parameters, as returned by the base implementation.
        """
        fit_params = super()._update_fit_params(train_tsdataset, valid_tsdataset)
        # NOTE(review): these orders must match the order in which the base TFT concatenates
        # each channel's inputs, otherwise weights are attributed to the wrong names — confirm.
        self.static_cols = self.static_num_cols + self.static_cat_cols
        self.history_cols = (self.target_cols + self.known_num_cols + self.observed_num_cols
                             + self.known_cat_cols + self.observed_cat_cols)
        self.future_cols = self.known_num_cols + self.known_cat_cols
        self.mapping = {
            'Static Weights': {'arr_key': 'static_weights', 'feat_names': self.static_cols},
            'Historical Weights': {'arr_key': 'historical_selection_weights', 'feat_names': self.history_cols},
            'Future Weights': {'arr_key': 'future_selection_weights', 'feat_names': self.future_cols},
        }
        return fit_params

    def _display_explanations(
        self,
        explain_data: Dict[str, np.ndarray],
        weights_prctile: Union[List[float], float, None] = None,
        observation_index: int = 0,
        horizons: Union[List[int], int, None] = None,
        top_feature_num: Optional[int] = 20,
        unit: Optional[str] = 'Units'
    ):
        """
        Main visualization logic, which contains selection weights and attention scores.

        Args:
            explain_data(Dict[str, np.ndarray]): A dictionary of numpy arrays containing the explanation outputs of the model for a set of observations.
            weights_prctile(Union[List[float], float], Optional): A list (or a single value) of percentiles to compute as a distribution describer for the scores. Defaults to [10, 50, 90].
            observation_index(int): The index with the dataset, corresponding to the observation for which the visualization will be generated.
            horizons(Union[List[int], int], Optional): A list horizon, specified in time-steps units, for which the statistics will be computed. Defaults to [1, 5, 10].
            top_feature_num(int, Optional): An integer specifying the quantity of the top weighted features to display.
            unit(str, Optional): The units associated with the time-steps. This variable is used for labeling the corresponding axes.
        """
        # None sentinels avoid shared mutable defaults; the effective defaults are unchanged.
        weights_prctile = self._to_list([10, 50, 90] if weights_prctile is None else weights_prctile)
        horizons = self._to_list([1, 5, 10] if horizons is None else horizons)
        # ========================
        # Selection Weights
        # ========================
        self._display_selection_weights_stats(outputs_dict=explain_data,
                                              prctiles=weights_prctile,
                                              mapping=self.mapping)
        self._display_sample_wise_selection_stats(weights_arr=explain_data['static_weights'],
                                                  observation_index=observation_index,
                                                  feature_names=self.static_cols,
                                                  top_n=top_feature_num,
                                                  title='Static Features')
        self._display_sample_wise_selection_stats(weights_arr=explain_data['historical_selection_weights'],
                                                  observation_index=observation_index,
                                                  feature_names=self.history_cols,
                                                  top_n=top_feature_num,
                                                  title='Historical Features',
                                                  rank_stepwise=True)
        self._display_sample_wise_selection_stats(weights_arr=explain_data['future_selection_weights'],
                                                  observation_index=observation_index,
                                                  feature_names=self.future_cols,
                                                  top_n=top_feature_num,
                                                  title='Future Features',
                                                  historical=False,
                                                  rank_stepwise=False)
        # ========================
        # Attention Scores
        # ========================
        attention_scores = explain_data['attention_scores']
        # several percentiles per single horizon -- only meaningful in the backtest scenario
        if len(weights_prctile) > 1:
            for horizon in horizons:
                self._display_attention_scores(attention_scores=attention_scores,
                                               horizons=horizon,
                                               prctiles=weights_prctile,
                                               unit=unit)
        # multi-horizon attention, one figure per percentile (prediction and backtest scenarios)
        for prctile in weights_prctile:
            self._display_attention_scores(attention_scores=attention_scores,
                                           horizons=horizons,
                                           prctiles=prctile,
                                           unit=unit)
        # a single specific sample, only when more than one observation is available
        if attention_scores.shape[0] > 1:
            self._display_sample_wise_attention_scores(attention_scores=attention_scores,
                                                       observation_index=observation_index,
                                                       horizons=horizons,
                                                       unit=unit)

    def _check_horizons(
        self,
        horizons: Union[List[int], int],
        out_chunk_len: int,
    ):
        """
        check validation of horizons for backtest and prediction.

        Horizons are 1-based: horizon `h` selects row `h - 1` of the attention scores, so each
        value must lie in [1, out_chunk_len].

        Args:
            horizons(Union[List[int], int]): List of horizon to be explained.
            out_chunk_len(int): The size of the model's forecasting time steps.

        Raises:
            Error raised by `raise_if` when a horizon is outside [1, out_chunk_len].
        """
        for horizon in self._to_list(horizons):
            # horizon 0 or negative would silently index from the end of the array
            raise_if(horizon < 1,
                     f"all horizons should be positive (1-based), got {horizons}.")
            raise_if(horizon > out_chunk_len,
                     f"all horizons should be no bigger than model's `out_chunk_len`: {out_chunk_len}, got {horizons}.")

    def _check_backtest_params(
        self,
        target_length: int,
        in_chunk_len: int,
        skip_chunk_len: int,
        start: Union[pd.Timestamp, int, str, float] = None,
    ):
        """
        For backtest's explanation, check validation of parameters.

        Args:
            target_length(int): The length of target.
            in_chunk_len(int): The size of the lookback window, i.e., the number of time steps feed to the model.
            skip_chunk_len(int): Optional, the number of time steps between in_chunk and out_chunk for a single sample.
            start(Union[pd.Timestamp, int, str ,float]): The first prediction time, already converted to an integer index by the caller.

        Raises:
            Error raised by `check_model_fitted` / `raise_if` when the model is not fitted or `start` is out of range.
        """
        # check whether model fitted or not.
        check_model_fitted(self)
        # start time should be no less than in_chunk_len + skip_chunk_len
        raise_if(start < in_chunk_len + skip_chunk_len,
                 f"Parameter 'start' value should >= in_chunk_len {in_chunk_len} + skip_chunk_len {skip_chunk_len}")
        # start time should be no bigger than target length
        raise_if(start > target_length,
                 f"Parameter 'start' value should not exceed data target_len {target_length}")
        if skip_chunk_len != 0:
            # implicit concatenation keeps source indentation out of the message
            logger.info(f"model.skip_chunk_len is {skip_chunk_len}, "
                        f"backtest will start at index {start + skip_chunk_len} (start + skip_chunk_len)")

    def explain_backtest(
        self,
        data: TSDataset,
        start: Optional[Union[pd.Timestamp, int, str, float]] = None,
        observation_index: Optional[int] = 0,
        horizons: Union[List[int], int, None] = None,
        unit: Optional[str] = 'Units',
        display: Optional[bool] = True
    ):
        """
        Explain backtest data with a simplified version of `utils.backtest`.

        The target (and observed covariates, if any) is repeatedly cut at a moving index that
        advances by the model's `_sampling_stride`, and the full `predict_interpretable` output
        of every window is kept.

        Args:
            data(TSDataset): The TSdataset used for successively generating explanation result and visualizing.
            start(Union[pd.Timestamp, int, str ,float]): The first prediction time, at which a prediction is computed for a future time. Defaults to in_chunk_len + skip_chunk_len.
            observation_index(int, Optional): The index with the dataset, corresponding to the observation for which the visualization will be generated.
            horizons(Union[List[int], int]): A list horizon, specified in time-steps units, for which the statistics will be computed. Defaults to [1].
            unit(str, Optional): The units associated with the time-steps. This variable is used for labeling the corresponding axes.
            display(bool, Optional): Whether to display the explanation results.

        Returns:
            Dict[str, np.ndarray]: Explanation outputs of all windows, concatenated along axis 0
            per key. Empty if no window could be produced.
        """
        if horizons is None:
            horizons = [1]
        predicts_agg: Dict[str, List[np.ndarray]] = {}
        data = data.copy()
        predict_window = self._out_chunk_len
        all_target = data.get_target()
        # fetch once; a falsy observed covariate is treated as absent
        observed_cov = data.get_observed_cov()
        all_observe = observed_cov if observed_cov else None
        if start is None:
            start = self._in_chunk_len + self._skip_chunk_len
        start = all_target.get_index_at_point(start)
        self._check_horizons(horizons, self._out_chunk_len)
        self._check_backtest_params(len(all_target), self._in_chunk_len, self._skip_chunk_len, start)
        predict_rounds = math.ceil((len(all_target) - start) / self._sampling_stride)
        # the series is cut `skip_chunk_len` steps before `start`
        index = start - self._skip_chunk_len
        for _ in tqdm(range(predict_rounds), desc="Backtest Progress"):
            data._target, rest = all_target.split(index)
            data._observed_cov, _observed_rest = all_observe.split(index) if all_observe else (None, None)
            rest_len = len(rest)
            if rest_len < predict_window + self._skip_chunk_len:
                if data.known_cov is not None:
                    target_end_time = data._target.end_time
                    known_index = data.known_cov.get_index_at_point(target_end_time)
                    # stop when the known covariates cannot cover the remaining window
                    if len(data.known_cov) - known_index - 1 < predict_window + self._skip_chunk_len:
                        break
                predict_window = rest_len - self._skip_chunk_len
            output = self.predict_interpretable(data)
            for key, array in output.items():
                predicts_agg.setdefault(key, []).append(array)
            # step to next sample
            index += self._sampling_stride
        outputs = {key: np.concatenate(arrays, axis=0) for key, arrays in predicts_agg.items()}
        if display:
            if not outputs:
                # displaying would fail on missing keys; report instead of crashing
                logger.info("No backtest window was produced, nothing to display.")
            else:
                self._display_explanations(explain_data=outputs,
                                           weights_prctile=self._q_points.tolist(),
                                           observation_index=observation_index,
                                           horizons=horizons,
                                           unit=unit)
        return outputs

    def explain_prediction(
        self,
        data: TSDataset,
        horizons: Union[List[int], int, None] = None,
        unit: Optional[str] = 'Units',
        display: Optional[bool] = True
    ):
        """
        Explain prediction data, in cases of single sample prediction.

        Args:
            data(TSDataset): The TSdataset used for predicting explanation result and visualizing.
            horizons(Union[List[int], int]): A list or a single horizon, specified in time-steps units, for which the statistics will be computed. Defaults to [1].
            unit(str, Optional): The units associated with the time-steps. This variable is used for labeling the corresponding axes.
            display(bool, Optional): Whether to display the explanation results.

        Returns:
            Dict[str, np.ndarray]: Explanation data predicted by the model.
        """
        if horizons is None:
            horizons = [1]
        self._check_horizons(horizons, self._out_chunk_len)
        explain_data = self.predict_interpretable(data)
        if display:
            # a single prediction sample: describe it with the median only
            self._display_explanations(explain_data=explain_data,
                                       weights_prctile=50,
                                       observation_index=0,
                                       horizons=horizons,
                                       unit=unit)
        return explain_data

    def _aggregate_weights(
        self,
        output_arr: np.ndarray,
        prctiles: List[float],
        feat_names: List[str]) -> pd.DataFrame:
        """
        Aggregate selection weights of a set of observations into per-feature percentiles.

        Works for static weights (2D) and temporal weights (3D). For temporal weights the time
        axis is folded into the sample axis, so each percentile describes all (sample, step)
        weights of a feature.

        Args:
            output_arr(np.ndarray): A 2D or 3D array containing the selection weights output by the model; the last axis indexes features.
            prctiles(List[float]): A list of percentiles according to which the distribution of selection weights will be described.
            feat_names(List[str]): A list of strings associated with the relevant attributes (according to their order).

        Returns:
            pd.DataFrame: Indexed by feature name, with one column per percentile.
        """
        # loop-invariant: flatten once rather than once per percentile
        flat = output_arr.reshape(-1, output_arr.shape[-1]) if output_arr.ndim > 2 else output_arr
        return pd.DataFrame({q: np.percentile(flat, q=q, axis=0) for q in prctiles}, index=feat_names)

    def _display_selection_weights_stats(
        self,
        outputs_dict: Dict[str, np.ndarray],
        prctiles: List[float],
        mapping: Dict,
    ):
        """
        Display selection-weight percentile tables for every input channel in `mapping`.

        Tables are sorted (descending) by the 50th percentile when it is among `prctiles`,
        otherwise by the first percentile. Channels without features are skipped.

        Args:
            outputs_dict(Dict[str,np.ndarray]): A dictionary of numpy arrays containing the outputs of the model for a set of observations.
            prctiles(List[float]): A list of percentiles according to which the distribution of selection weights will be described.
            mapping(Dict): Maps a display name to `{'arr_key': ..., 'feat_names': ...}`.
        """
        # use the actual element so the column label matches exactly
        sort_by = next((p for p in prctiles if p == 50.0), prctiles[0])
        for name, config in mapping.items():
            if not config["feat_names"]:
                continue
            weights_agg = self._aggregate_weights(output_arr=outputs_dict[config['arr_key']],
                                                  prctiles=prctiles,
                                                  feat_names=config['feat_names'])
            print(name)
            print('=========')
            display(weights_agg.sort_values([sort_by], ascending=False).style.background_gradient(cmap='viridis'))

    def _display_attention_scores(
        self,
        attention_scores: np.ndarray,
        horizons: Union[int, List[int]],
        prctiles: Union[float, List[float]],
        unit: Optional[str] = 'Units'
    ):
        """
        Plot percentile statistics (across observations) of the attention scores.

        Either several horizons at one percentile, or one horizon at several percentiles.

        Args:
            attention_scores(np.ndarray): A numpy array containing the attention scores for the relevant dataset; axis 0 indexes observations.
            horizons(Union[int, List[int]]): A list or a single (1-based) horizon for which the statistics will be computed.
            prctiles(Union[float, List[float]]): A list or a single percentile to compute as a distribution describer for the scores.
            unit(Optional[str]): The units associated with the time-steps, used for labeling the x axis.

        Raises:
            Error raised by `raise_if` when neither `horizons` nor `prctiles` has exactly one element.
        """
        horizons = self._to_list(horizons)
        prctiles = self._to_list(prctiles)
        raise_if(len(prctiles) != 1 and len(horizons) != 1,
                 "either `horizons` or `prctiles` must contain exactly one element, "
                 f"got horizons={horizons}, prctiles={prctiles}.")
        attn_stats = {prctile: np.percentile(attention_scores, q=prctile, axis=0) for prctile in prctiles}
        fig, ax = plt.subplots()
        if len(prctiles) == 1:
            relevant_prctile = prctiles[0]
            title = f"Multi-Step - Attention ({relevant_prctile}% Percentile)"
            scores_percentile = attn_stats[relevant_prctile]
            x_axis = self._relative_time_axis(scores_percentile)
            for horizon in horizons:  # a single line for each horizon
                ax.plot(x_axis, scores_percentile[horizon - 1], lw=1, label=f"t + {horizon} scores", marker='o')
        else:
            title = f"{horizons[0]} Steps Ahead - Attention Scores"
            for prctile, scores_percentile in attn_stats.items():  # a single line for each percentile
                ax.plot(self._relative_time_axis(scores_percentile), scores_percentile[horizons[0] - 1],
                        lw=1, label=f"{prctile}%", marker='o')
        ax.axvline(x=0, lw=1, color='r', linestyle='--')
        ax.grid(True)
        ax.set_xlabel(f"Relative Time-step [{unit}]")
        ax.set_ylabel('Attention Scores')
        ax.set_title(title)
        ax.legend()
        plt.show(block=False)

    def _display_sample_wise_attention_scores(
        self,
        attention_scores: np.ndarray,
        observation_index: int,
        horizons: Union[int, List[int]],
        unit: Optional[str] = None):
        """
        Plot the attention scores of a single observation for one or more horizons.

        Args:
            attention_scores(np.ndarray): A numpy array containing the attention scores for the relevant dataset; axis 0 indexes observations.
            observation_index(int): The index with the dataset, corresponding to the observation for which the visualization will be generated.
            horizons(Union[int, List[int]]): A list or a single (1-based) horizon for which the scores will be displayed.
            unit(Optional[str]): The units associated with the time-steps, used for labeling the x axis.
        """
        horizons = self._to_list(horizons)
        sample_attn_scores = attention_scores[observation_index, ...]
        fig, ax = plt.subplots()
        x_axis = self._relative_time_axis(sample_attn_scores)
        for step in horizons:
            ax.plot(x_axis, sample_attn_scores[step - 1], marker='o', lw=3, label=f"t+{step}")
        ax.axvline(x=-0.5, lw=1, color='k', linestyle='--')
        ax.grid(True)
        ax.legend()
        ax.set_xlabel('Relative Time-Step ' + (f"[{unit}]" if unit else ""))
        ax.set_ylabel('Attention Score')
        ax.set_title('Attention Mechanism Scores - Per Horizon')
        plt.show(block=False)

    def _display_sample_wise_selection_stats(
        self,
        weights_arr: np.ndarray,
        observation_index: int,
        feature_names: List[str],
        top_n: Optional[int] = None,
        title: Optional[str] = '',
        historical: Optional[bool] = True,
        rank_stepwise: Optional[bool] = False
    ):
        """
        Plot the selection weights of a single observation.

        Always draws a bar chart of the (time-averaged, for temporal channels) weights; for
        temporal channels also draws a feature x time-step heatmap.

        Args:
            weights_arr(np.ndarray): A 2D (samples, features) or 3D (samples, time-steps, features) array of selection weights.
            observation_index(int): The index with the dataset, corresponding to the observation for which the visualization will be generated.
            feature_names(List[str]): A list of strings associated with the relevant attributes (according to their order).
            top_n(Optional[int]): How many of the highest-weighted features the bar chart shows; all when falsy.
            title(Optional[str]): A string which will be used when creating the title for the visualization.
            historical(Optional[bool]): Whether the temporal channel is historical (x labels -N..0) or future (1..N).
            rank_stepwise(Optional[bool]): If True, the heatmap shows per-time-step ranks (0 = lowest weight) instead of raw weights.
        """
        weights_shape = weights_arr.shape
        num_features = weights_shape[-1]
        if num_features <= 1:
            # nothing to compare or rank with fewer than two features
            return
        is_temporal = len(weights_shape) > 2
        top_n = min(num_features, top_n) if top_n else num_features
        sample_weights = weights_arr[observation_index, ...]
        if is_temporal:
            num_temporal_steps = weights_shape[1]
            sample_weights_trans = sample_weights.T  # (features, time-steps)
            mean_weights = sample_weights_trans.mean(axis=1)
            weights_df = pd.DataFrame({'weight': mean_weights}, index=feature_names)
        else:
            weights_df = pd.DataFrame({'weight': sample_weights}, index=feature_names)
        # ========================
        # Aggregative Barplot
        # ========================
        fig, ax = plt.subplots()
        weights_df.sort_values('weight', ascending=False).iloc[:top_n].plot.bar(ax=ax)
        # `Tick.label` was removed in Matplotlib 3.8; tick_params is the supported API
        ax.tick_params(axis='x', labelsize=11, labelrotation=45)
        ax.grid(True)
        ax.set_xlabel('Feature Name')
        ax.set_ylabel('Selection Weight')
        ax.set_title(title + (" - " if title != "" else "") +
                     "Selection Weights " + ("Aggregation " if is_temporal else "") +
                     (f"- Top {top_n}" if top_n < num_features else ""))
        plt.show(block=False)
        if not is_temporal:
            return
        # ========================
        # Temporal Display
        # ========================
        # order features by their average selection weight across time, highest first
        order = mean_weights.argsort()[::-1]
        ordered_weights = sample_weights_trans[order]
        ordered_names = [feature_names[i] for i in order.tolist()]
        if rank_stepwise:
            # a single argsort yields the sorting permutation, not ranks;
            # argsort of argsort gives each feature's rank within its time-step
            ordered_weights = np.argsort(np.argsort(ordered_weights, axis=0), axis=0)
        fig, ax = plt.subplots()
        if historical:
            map_x = dict(enumerate(np.arange(0 - num_temporal_steps, 1)))
        else:
            map_x = dict(enumerate(np.arange(1, num_temporal_steps + 1)))

        def format_fn(tick_val, tick_pos):
            return map_x.get(int(tick_val), '')

        im = ax.pcolor(ordered_weights, edgecolors='gray', linewidths=2)
        # pcolor row i spans [i, i + 1]; put each feature name at the centre of its row
        ax.yaxis.set_ticks(np.arange(len(ordered_names)) + 0.5)
        ax.set_yticklabels(ordered_names)
        ax2 = ax.twiny()
        ax2.set_xticks([])
        ax2.xaxis.set_ticks_position('top')
        ax.set_xlabel(('Historical' if historical else 'Future') + ' Time-Steps')
        ax2.set_xlabel(('Historical' if historical else 'Future') + ' Time-Steps')
        ax.xaxis.set_major_formatter(FuncFormatter(format_fn))
        fig.colorbar(im, orientation="horizontal", pad=0.05, ax=ax2)
        plt.show(block=False)