Xai is a model interpretation module that explains how a complex model arrives at its predictions, helping users quickly understand the relationship between inputs and outputs. The xai module currently has two sub-modules, ante_hoc and post_hoc. The former provides interpretability through the design of the model's network structure itself; the latter is independent of the model's network and instead interprets the original model through a proxy (surrogate) model.

1. Prepare Data

1.1. Get data and import libraries

Load a built-in PaddleTS dataset (ECL).

import numpy as np
import pandas as pd
import paddle
import matplotlib.pyplot as plt


from paddlets import TSDataset, TimeSeries
from paddlets.xai.post_hoc.shap_explainer import ShapExplainer
from paddlets.datasets.repository import get_dataset

# Load the built-in ECL dataset as a TSDataset
data = get_dataset("ECL")

1.2. Split Data

Select the target column and split the dataset chronologically into train/validation/test sets.

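# Use MT_001 as a known covariate
ts_known = TimeSeries(data[['MT_001', ]], freq='1H').copy()

# Keep MT_000 as the only target column and drop the other target columns
ts_cols = data.columns
keep_cols = ['MT_000', ]
remove_cols = []
for col, types in ts_cols.items():
    if types == 'target' and col not in keep_cols:
        remove_cols.append(col)
data.drop(remove_cols)
data.known_cov = ts_known

# Truncate at 2014-06-30, then split into train / validation / test
data, _ = data.split('2014-06-30')
train_data, test_data = data.split('2014-06-15')
train_data, val_data = train_data.split('2014-06-01')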
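The split calls above cut the series by timestamp: the first call discards everything after 2014-06-30, and the next two carve a test period and then a validation period off the end. The pure-Python sketch below mimics that chain on a toy hourly series. `split_at` is a hypothetical helper, and treating the split point as belonging to the left part is an assumption, not documented TSDataset behavior.

```python
from datetime import datetime, timedelta

def split_at(points, split_point):
    """Hypothetical helper: split (timestamp, value) pairs at split_point.

    Assumption: points at or before split_point go to the left part; the
    boundary convention of TSDataset.split may differ.
    """
    left = [p for p in points if p[0] <= split_point]
    right = [p for p in points if p[0] > split_point]
    return left, right

# Toy hourly series from 2014-05-25 00:00 to 2014-07-04 23:00 (value = hour index)
start = datetime(2014, 5, 25)
points = [(start + timedelta(hours=h), h) for h in range(41 * 24)]

# Same chain of split points as the tutorial code
data_pts, _ = split_at(points, datetime(2014, 6, 30))
train_pts, test_pts = split_at(data_pts, datetime(2014, 6, 15))
train_pts, val_pts = split_at(train_pts, datetime(2014, 6, 1))

print(len(train_pts), len(val_pts), len(test_pts))
```

Because each split is taken from what remains of the previous one, the three parts are chronological and non-overlapping, so the test period is never seen during training.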

2. Prepare model parameters

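Define the window and training parameters for the base model.

in_chunk_len = 24     # length of the input window (time steps)
out_chunk_len = 24    # number of future time steps to predict
skip_chunk_len = 0    # gap between the input window and the output window
sampling_stride = 24  # step between the starts of consecutive training samples
max_epochs = 10       # maximum number of training epochs
patience = 5          # early stopping: epochs without validation improvement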
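These parameters define how the series is cut into training samples. As a rough, hypothetical sketch (not PaddleTS's actual sampler, whose window alignment may differ), each sample reads `in_chunk_len` steps, skips `skip_chunk_len` steps, and predicts the next `out_chunk_len` steps, and the next sample starts `sampling_stride` steps later:

```python
def make_windows(n_steps, in_chunk_len, out_chunk_len, skip_chunk_len, sampling_stride):
    """Return (input_range, output_range) index pairs for a series of n_steps.

    Conceptual sketch only: PaddleTS's internal sampler may align windows
    differently (for example, anchoring them at the end of the series).
    """
    windows = []
    span = in_chunk_len + skip_chunk_len + out_chunk_len  # steps used by one sample
    for start in range(0, n_steps - span + 1, sampling_stride):
        in_end = start + in_chunk_len
        out_start = in_end + skip_chunk_len
        windows.append((range(start, in_end), range(out_start, out_start + out_chunk_len)))
    return windows

# One week of hourly data with the parameters above
windows = make_windows(7 * 24, in_chunk_len=24, out_chunk_len=24,
                       skip_chunk_len=0, sampling_stride=24)
print(len(windows), windows[0])
```

With `sampling_stride` equal to `in_chunk_len` (24), consecutive hourly samples start one day apart and their input windows do not overlap; a stride of 1 would produce many more, heavily overlapping samples.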

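3. Construct and Fit

Construct the pipeline (StandardScaler followed by NBEATSModel) and fit it on the training data, with the validation data used for early stopping.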

from paddlets.models.forecasting import NBEATSModel
from paddlets.transform import StandardScaler
from paddlets.pipeline.pipeline import Pipeline

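# Each stage is a (class, init-params) tuple: scale the data, then fit N-BEATS
pipeline_list = [(StandardScaler, {}),
                 (NBEATSModel, {'in_chunk_len': in_chunk_len,
                                'out_chunk_len': out_chunk_len,
                                'skip_chunk_len': skip_chunk_len,
                                'sampling_stride': sampling_stride,
                                'max_epochs': max_epochs,
                                'patience': patience})]
pipe = Pipeline(pipeline_list)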
pipe.fit(train_data, val_data)
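The first pipeline stage standardizes the data before N-BEATS sees it. The toy scaler below is a minimal sketch of z-score scaling written for illustration, not the implementation of `paddlets.transform.StandardScaler`. It shows the usual rule, assumed here to apply to the pipeline as well, that statistics are fitted on training data only and reused unchanged for later data:

```python
from statistics import mean, pstdev

class ToyStandardScaler:
    """Minimal z-score scaler; a sketch, not paddlets.transform.StandardScaler."""

    def fit(self, values):
        self.mean_ = mean(values)
        self.std_ = pstdev(values) or 1.0  # guard against a constant series
        return self

    def transform(self, values):
        return [(v - self.mean_) / self.std_ for v in values]

    def inverse_transform(self, values):
        return [v * self.std_ + self.mean_ for v in values]

train_vals = [10.0, 12.0, 14.0, 16.0]
val_vals = [18.0, 20.0]

scaler = ToyStandardScaler().fit(train_vals)  # statistics come from training data only
scaled_val = scaler.transform(val_vals)       # validation reuses the training mean/std
restored = scaler.inverse_transform(scaled_val)
print(scaled_val, restored)
```

Fitting the statistics on the validation or test periods would leak future information into training, which is why the scaler is fitted once and then only applied.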

4. Xai

Interpret the prediction results with the Kernel SHAP method.
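Kernel SHAP estimates Shapley values: each input's contribution is its marginal effect on the prediction, averaged over all coalitions of the other inputs, with absent inputs filled in from background data. For a handful of inputs the values can be computed exactly by enumeration, as in the sketch below. The toy `model` and the single all-zero `baseline` are assumptions for illustration; this is not how PaddleTS computes the values internally.

```python
from itertools import combinations
from math import factorial

def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x, filling absent features from baseline.

    Enumerates all coalitions, so the cost grows as 2**n; Kernel SHAP exists
    to approximate these values when n is large.
    """
    n = len(x)

    def value(coalition):
        # Features in the coalition take x's value; the rest take the baseline's.
        z = [x[i] if i in coalition else baseline[i] for i in range(n)]
        return f(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            # Shapley weight for coalitions of this size
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for members in combinations(others, size):
                coalition = set(members)
                total += weight * (value(coalition | {i}) - value(coalition))
        phi.append(total)
    return phi

# Toy "model": a nonlinear function of three lagged inputs
model = lambda z: 2 * z[0] + z[1] * z[2]
x = [1.0, 2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
phi = shapley_values(model, x, baseline)
print(phi, sum(phi), model(x) - model(baseline))
```

The result shows the additivity property that force plots rely on: the contributions sum to the difference between the prediction and the baseline output. The interaction term `z[1] * z[2]` is split equally between the two inputs involved.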

4.1. Initialize the interpreter

ShapExplainer: bridges a PaddleTS model (here, the fitted pipeline) and the SHAP explainer, helping users understand how the inputs drive the model's output. It is built from the model and the training data, which supplies the background samples.

se = ShapExplainer(pipe, train_data, background_sample_number=100, keep_index=True, use_paddleloader=False)
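`background_sample_number=100` sets how many background samples the explainer uses. The sketch below assumes they are drawn from the training windows and used as in shap's Kernel SHAP: an absent feature takes each background row's value in turn, and the outputs are averaged. The empty coalition then yields the model's mean output over the background, which plays the role of the base value. The random pool, `model`, and both helpers are hypothetical stand-ins.

```python
import random

def expected_value(f, background):
    """Base value: the model's average output over the background samples."""
    return sum(f(b) for b in background) / len(background)

def coalition_value(f, x, coalition, background):
    """Interventional value of a coalition: features in `coalition` are fixed to x,
    the rest are taken from each background row, and the outputs are averaged."""
    n = len(x)
    outs = [f([x[i] if i in coalition else b[i] for i in range(n)]) for b in background]
    return sum(outs) / len(outs)

random.seed(0)
pool = [[random.gauss(0, 1) for _ in range(3)] for _ in range(1000)]  # stand-in for training windows
background = random.sample(pool, 100)  # analogous to background_sample_number=100
model = lambda z: 2 * z[0] + z[1] * z[2]
x = [1.0, 2.0, 3.0]

base = expected_value(model, background)
print(base, coalition_value(model, x, {0}, background))
```

A larger background gives a more stable base value but multiplies the number of model calls per coalition, which is one reason to keep it modest.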

4.2. Explain test sample

ShapExplainer.explain: computes the SHAP values, i.e. the feature contributions, for the samples to be explained.

shap_value = se.explain(test_data, nsamples=100)
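Exact enumeration needs 2^n model calls, which is infeasible for dozens of lagged inputs, so the explainer estimates contributions from a limited budget of model evaluations, controlled by `nsamples` (an interpretation based on shap's Kernel SHAP parameter of the same name). The sketch below uses a different, simpler estimator, permutation sampling, only to illustrate the budget-versus-accuracy trade-off; it is not Kernel SHAP's weighted regression.

```python
import random

def permutation_shapley(f, x, baseline, n_permutations, seed=0):
    """Monte Carlo Shapley estimate via random feature orderings.

    This is permutation sampling, not the weighted-regression estimator used by
    Kernel SHAP, but it shows the same trade-off: more model evaluations give
    estimates closer to the exact values.
    """
    rng = random.Random(seed)
    n = len(x)
    phi = [0.0] * n
    for _ in range(n_permutations):
        order = list(range(n))
        rng.shuffle(order)
        z = list(baseline)
        prev = f(z)
        for i in order:
            z[i] = x[i]           # add feature i to the coalition
            cur = f(z)
            phi[i] += cur - prev  # marginal contribution of i in this ordering
            prev = cur
    return [p / n_permutations for p in phi]

model = lambda z: 2 * z[0] + z[1] * z[2]
x, baseline = [1.0, 2.0, 3.0], [0.0, 0.0, 0.0]
for m in (1, 10, 1000):
    print(m, permutation_shapley(model, x, baseline, m))
```

Every sampled ordering distributes exactly f(x) - f(baseline) among the inputs, so the estimates always sum correctly; what improves with more orderings is the split between interacting inputs.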

4.3. Feature contribution figure

ShapExplainer.force_plot: draws the feature contributions for the selected sample and output time step(s) as additive layers. In the plot, lag_0 denotes the last time step of the input window (in_chunk_len), and lag_1 denotes the first time step of the output window (out_chunk_len).

se.force_plot(out_chunk_indice=[5, ], sample_index=0, contribution_threshold=0.05)
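A force plot stacks the contributions on top of the base value, pushing the prediction up (positive contributions) or down (negative ones). The sketch below computes that layout for a toy set of contributions. `force_layout` is hypothetical, and treating `contribution_threshold` as an absolute cutoff below which contributions are merged into an "other" bar is an assumption about the PaddleTS parameter.

```python
def force_layout(base_value, contributions, threshold):
    """Hypothetical sketch of what a force plot encodes.

    `contributions` maps feature names to SHAP values. Features whose |value|
    is below `threshold` are merged into one 'other' bar (an assumption; the
    PaddleTS threshold may be defined differently), so the bars still add up
    to the prediction.
    """
    shown = {k: v for k, v in contributions.items() if abs(v) >= threshold}
    other = sum(v for v in contributions.values() if abs(v) < threshold)
    if other:
        shown["other"] = other
    # Bars pushing the prediction up (largest first) and down (most negative first)
    positive = sorted(((k, v) for k, v in shown.items() if v > 0), key=lambda kv: -kv[1])
    negative = sorted(((k, v) for k, v in shown.items() if v < 0), key=lambda kv: kv[1])
    prediction = base_value + sum(shown.values())
    return positive, negative, prediction

contribs = {"lag_0": 0.30, "lag_1": -0.12, "lag_5": 0.02, "lag_7": -0.01}
pos, neg, pred = force_layout(0.5, contribs, 0.05)
print(pos, neg, pred)
```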


4.4. Feature importance display

ShapExplainer.summary_plot: computes the feature contributions for the specified output time step and displays them sorted by magnitude.

se.summary_plot(out_chunk_indice=[5, ], sample_index=0)
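A summary view orders features by the size of their contributions. The sketch below ranks toy features by mean absolute SHAP value across samples. shap's own summary plots order features by absolute SHAP values summed over samples, which gives the same ranking, but whether PaddleTS's `summary_plot` uses this exact statistic is an assumption.

```python
def rank_features(shap_rows, feature_names):
    """Rank features by mean absolute SHAP value across samples (largest first)."""
    n = len(shap_rows)
    scores = {
        name: sum(abs(row[j]) for row in shap_rows) / n
        for j, name in enumerate(feature_names)
    }
    return sorted(scores.items(), key=lambda kv: kv[1], reverse=True)

# Two explained samples, three features each (made-up values)
rows = [[0.30, -0.12, 0.02], [0.10, -0.40, 0.00]]
ranking = rank_features(rows, ["lag_0", "lag_1", "lag_2"])
print(ranking)
```

Using absolute values means a feature that strongly lowers the prediction ranks as high as one that strongly raises it.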


4.5. Multi-dimensional output contribution display: feature variables

Note: The following shows the contribution of each feature variable, aggregated over all input time steps and all output time steps.



4.6. Multi-dimensional output contribution display: input time steps

Note: The following shows the contribution of each input time step, aggregated over all feature variables and all output time steps.



4.7. Multi-dimensional output contribution display: input and output time steps

Note: The following shows the contribution of each pair of input time step and output time step, aggregated over all feature variables.



4.8. Multi-dimensional output contribution display: feature variables and output time steps

Note: The following shows the contribution of each pair of feature variable and output time step, aggregated over all input time steps.



4.9. Multi-dimensional output contribution display: feature variables and input time steps

Note: The following shows the contribution of each pair of input time step and feature variable, aggregated over all output time steps.

se.plot(method='IV', figsize=(30, 5))
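Sections 4.5 to 4.9 all view the same three-dimensional array of contributions (output time step x input time step x feature variable) from different angles, collapsing the axes that are not displayed. The sketch below performs that collapse on a made-up nested list. The axis order and the use of summed absolute values are assumptions, not a description of `ShapExplainer.plot`.

```python
def aggregate(values, keep):
    """Collapse a nested contribution array values[o][i][v] onto the axes in `keep`.

    Axes: 'O' = output time step, 'I' = input time step, 'V' = feature variable.
    Contributions over the dropped axes are combined as sums of absolute values
    (an assumption; PaddleTS may use another statistic).
    Returns a dict keyed by tuples of the kept indices, in O, I, V order.
    """
    out = {}
    for o, per_out in enumerate(values):
        for i, per_in in enumerate(per_out):
            for v, value in enumerate(per_in):
                idx = {"O": o, "I": i, "V": v}
                key = tuple(idx[a] for a in "OIV" if a in keep)
                out[key] = out.get(key, 0.0) + abs(value)
    return out

# 2 output steps x 3 input steps x 2 variables of made-up contributions
toy_shap = [
    [[0.1, -0.2], [0.0, 0.3], [-0.1, 0.1]],
    [[0.2, 0.0], [-0.1, 0.1], [0.0, -0.4]],
]
print(aggregate(toy_shap, "V"))   # per feature variable (as in 4.5)
print(aggregate(toy_shap, "I"))   # per input time step (as in 4.6)
print(aggregate(toy_shap, "IV"))  # input step x variable (as in 4.9, method='IV')
```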
