Skip to content

Commit 3afedc0

Browse files
authored
Expose partial dependence plot in bindings (#133)
* set normalize_weights to False as in Java impl * FAI-842 - draft pdp impl * FAI-842 - improved code, pylint checks fixed * FAI-842 - improved plot() * FAI-842 - lint checks * FAI-842 - reformatting * FAI-842 - added unit tests, plot working with non-numeric data * FAI-842 - made PredictionProviderMetadata private
1 parent c3a67da commit 3afedc0

File tree

3 files changed

+263
-0
lines changed

3 files changed

+263
-0
lines changed

src/trustyai/explainers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@
33
from .counterfactuals import CounterfactualResult, CounterfactualExplainer
44
from .lime import LimeExplainer, LimeResults
55
from .shap import SHAPExplainer, SHAPResults, BackgroundGenerator
6+
from .pdp import PDPExplainer

src/trustyai/explainers/pdp.py

Lines changed: 193 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,193 @@
1+
"""Explainers.pdp module"""
2+
3+
import math
4+
import matplotlib.pyplot as plt
5+
import pandas as pd
6+
from pandas.io.formats.style import Styler
7+
8+
from jpype import (
9+
JImplements,
10+
JOverride,
11+
)
12+
13+
# pylint: disable = import-error
14+
from org.kie.trustyai.explainability.global_ import pdp
15+
16+
# pylint: disable = import-error
17+
from org.kie.trustyai.explainability.model import (
18+
PredictionProvider,
19+
PredictionInputsDataDistribution,
20+
PredictionOutput,
21+
Output,
22+
Type,
23+
Value,
24+
)
25+
26+
from trustyai.utils.data_conversions import ManyInputsUnionType, many_inputs_convert
27+
28+
from .explanation_results import ExplanationResults
29+
30+
31+
class PDPResults(ExplanationResults):
32+
"""
33+
Results class for Partial Dependence Plots
34+
"""
35+
36+
def __init__(self, pdp_graphs):
37+
self.pdp_graphs = pdp_graphs
38+
39+
def as_dataframe(self) -> pd.DataFrame:
40+
"""
41+
Returns
42+
-------
43+
a pd.DataFrame with input values and feature name as
44+
columns and marginal feature outputs as rows
45+
"""
46+
pdp_series_list = []
47+
for pdp_graph in self.pdp_graphs:
48+
inputs = [self._to_plottable(x) for x in pdp_graph.getX()]
49+
outputs = [self._to_plottable(y) for y in pdp_graph.getY()]
50+
pdp_dict = dict(zip(inputs, outputs))
51+
pdp_dict["feature"] = "" + str(pdp_graph.getFeature().getName())
52+
pdp_series = pd.Series(index=inputs + ["feature"], data=pdp_dict)
53+
pdp_series_list.append(pdp_series)
54+
pdp_df = pd.DataFrame(pdp_series_list)
55+
return pdp_df
56+
57+
def as_html(self) -> Styler:
58+
"""
59+
Returns
60+
-------
61+
Style object from the PDP pd.DataFrame (see as_dataframe)
62+
"""
63+
return self.as_dataframe().style
64+
65+
def plot(self, output_name=None, block=True) -> None:
66+
"""
67+
Parameters
68+
----------
69+
output_name: str
70+
name of the output to be plotted
71+
Default to None
72+
block: bool
73+
whether the plotting operation
74+
should be blocking or not
75+
"""
76+
fig, axs = plt.subplots(len(self.pdp_graphs), constrained_layout=True)
77+
p_idx = 0
78+
for pdp_graph in self.pdp_graphs:
79+
if output_name is not None and output_name != str(
80+
pdp_graph.getOutput().getName()
81+
):
82+
continue
83+
fig.suptitle(str(pdp_graph.getOutput().getName()))
84+
pdp_x = []
85+
for i in range(len(pdp_graph.getX())):
86+
pdp_x.append(self._to_plottable(pdp_graph.getX()[i]))
87+
pdp_y = []
88+
for i in range(len(pdp_graph.getY())):
89+
pdp_y.append(self._to_plottable(pdp_graph.getY()[i]))
90+
axs[p_idx].plot(pdp_x, pdp_y)
91+
axs[p_idx].set_title(
92+
str(pdp_graph.getFeature().getName()), loc="left", fontsize="small"
93+
)
94+
axs[p_idx].grid()
95+
p_idx += 1
96+
fig.supylabel("Partial Dependence Plot")
97+
plt.show(block=block)
98+
99+
@staticmethod
100+
def _to_plottable(datum: Value):
101+
plottable = datum.asNumber()
102+
if math.isnan(plottable):
103+
plottable = str(datum.asString())
104+
return plottable
105+
106+
107+
# pylint: disable = too-few-public-methods
108+
class PDPExplainer:
109+
"""
110+
Partial Dependence Plot explainer.
111+
See https://christophm.github.io/interpretable-ml-book/pdp.html
112+
"""
113+
114+
def __init__(self, config=None):
115+
if config is None:
116+
config = pdp.PartialDependencePlotConfig()
117+
self._explainer = pdp.PartialDependencePlotExplainer(config)
118+
119+
def explain(
120+
self, model: PredictionProvider, data: ManyInputsUnionType, num_outputs: int = 1
121+
) -> PDPResults:
122+
"""
123+
Parameters
124+
----------
125+
model: PredictionProvider
126+
the model to explain
127+
data: ManyInputsUnionType
128+
the data used to calculate the PDP
129+
num_outputs: int
130+
the number of outputs to calculate the PDP for
131+
132+
Returns
133+
-------
134+
pdp_results: PDPResults
135+
the partial dependence plots associated to the model outputs
136+
"""
137+
metadata = _PredictionProviderMetadata(many_inputs_convert(data), num_outputs)
138+
pdp_graphs = self._explainer.explainFromMetadata(model, metadata)
139+
return PDPResults(pdp_graphs)
140+
141+
142+
@JImplements(
143+
"org.kie.trustyai.explainability.model.PredictionProviderMetadata", deferred=True
144+
)
145+
class _PredictionProviderMetadata:
146+
"""
147+
Implementation of org.kie.trustyai.explainability.model.PredictionProviderMetadata interface
148+
"""
149+
150+
def __init__(self, data: list, size: int):
151+
"""
152+
Parameters
153+
----------
154+
data: ManyInputsUnionType
155+
the data
156+
size: int
157+
the size of the model output
158+
"""
159+
self.data = PredictionInputsDataDistribution(data)
160+
outputs = []
161+
for _ in range(size):
162+
outputs.append(Output("", Type.UNDEFINED))
163+
self.pred_out = PredictionOutput(outputs)
164+
165+
# pylint: disable = invalid-name
166+
@JOverride
167+
def getDataDistribution(self):
168+
"""
169+
Returns
170+
--------
171+
the underlying data distribution
172+
"""
173+
return self.data
174+
175+
# pylint: disable = invalid-name
176+
@JOverride
177+
def getInputShape(self):
178+
"""
179+
Returns
180+
--------
181+
a PredictionInput from the underlying distribution
182+
"""
183+
return self.data.sample()
184+
185+
# pylint: disable = invalid-name
186+
@JOverride
187+
def getOutputShape(self):
188+
"""
189+
Returns
190+
--------
191+
a PredictionOutput
192+
"""
193+
return self.pred_out

tests/general/test_pdp.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
# pylint: disable=import-error, wrong-import-position, wrong-import-order, invalid-name
2+
"""PDP test suite"""
3+
4+
import numpy as np
5+
import pandas as pd
6+
import pytest
7+
from sklearn.datasets import make_classification
8+
from trustyai.explainers import PDPExplainer
9+
from trustyai.model import Model
10+
from trustyai.utils import TestModels
11+
12+
13+
def create_random_df():
14+
X, _ = make_classification(n_samples=5000, n_features=5, n_classes=2,
15+
n_clusters_per_class=2, class_sep=2, flip_y=0, random_state=23)
16+
17+
return pd.DataFrame({
18+
'x1': X[:, 0],
19+
'x2': X[:, 1],
20+
'x3': X[:, 2],
21+
'x4': X[:, 3],
22+
'x5': X[:, 4],
23+
})
24+
25+
26+
def test_pdp_sumskip():
27+
"""Test PDP with sum skip model on random generated data"""
28+
29+
df = create_random_df()
30+
model = TestModels.getSumSkipModel(0)
31+
pdp_explainer = PDPExplainer()
32+
pdp_results = pdp_explainer.explain(model, df)
33+
assert pdp_results is not None
34+
assert pdp_results.as_dataframe() is not None
35+
36+
37+
def test_pdp_sumthreshold():
38+
"""Test PDP with sum threshold model on random generated data"""
39+
40+
df = create_random_df()
41+
model = TestModels.getLinearThresholdModel([0.1, 0.2, 0.3, 0.4, 0.5], 0)
42+
pdp_explainer = PDPExplainer()
43+
pdp_results = pdp_explainer.explain(model, df)
44+
assert pdp_results is not None
45+
assert pdp_results.as_dataframe() is not None
46+
47+
48+
def pdp_plots(block):
49+
"""Test PDP plots"""
50+
np.random.seed(0)
51+
data = pd.DataFrame(np.random.rand(101, 5))
52+
53+
model_weights = np.random.rand(5)
54+
predict_function = lambda x: np.stack([np.dot(x.values, model_weights), 2 * np.dot(x.values, model_weights)], -1)
55+
model = Model(predict_function, dataframe_input=True)
56+
pdp_explainer = PDPExplainer()
57+
explanation = pdp_explainer.explain(model, data)
58+
59+
explanation.plot(block=block)
60+
explanation.plot(block=block, output_name='output-0')
61+
62+
63+
@pytest.mark.block_plots
64+
def test_lime_plots_blocking():
65+
pdp_plots(True)
66+
67+
68+
def test_lime_plots():
69+
pdp_plots(False)

0 commit comments

Comments
 (0)