
Commit 03e0452

Unified as_df and as_html methods between SHAP and LIME (#131)
1 parent f6366c4 commit 03e0452

4 files changed: +122 -79 lines changed


src/trustyai/explainers/lime.py

Lines changed: 45 additions & 23 deletions
@@ -6,9 +6,11 @@
 import bokeh.models
 import matplotlib.pyplot as plt
 import matplotlib as mpl
+import numpy as np
 from bokeh.models import ColumnDataSource, HoverTool
 from bokeh.plotting import figure
 import pandas as pd
+from matplotlib.colors import LinearSegmentedColormap
 
 from trustyai import _default_initializer  # pylint: disable=unused-import
 from trustyai.utils._visualisation import (
@@ -77,42 +79,62 @@ def as_dataframe(self) -> pd.DataFrame:
         Returns
         -------
         pandas.DataFrame
-            DataFrame containing the results of the LIME explanation. For each model output, the
-            table will contain the following columns:
+            Dictionary of DataFrames, keyed by output name, containing the results of the LIME
+            explanation. For each model output, the table will contain the following columns:
+
+            * ``Feature``: The name of the feature
+            * ``Value``: The value of the feature for this particular input.
+            * ``Saliency``: The importance of this feature to the output.
+            * ``Confidence``: The confidence of this explanation as returned by the explainer.
 
-            * ``${output_name}_features``: The names of each input feature.
-            * ``${output_name}_score``: The LIME saliency of this feature.
-            * ``${output_name}_value``: The original value of each feature.
-            * ``${output_name}_confidence``: The confidence of the reported saliency.
         """
         outputs = self.saliency_map().keys()
 
         data = {}
         for output in outputs:
-            pfis = self.saliency_map().get(output).getPerFeatureImportance()
-            data[f"{output}_features"] = [
-                f"{pfi.getFeature().getName()}" for pfi in pfis
-            ]
-            data[f"{output}_score"] = [pfi.getScore() for pfi in pfis]
-            data[f"{output}_value"] = [
-                pfi.getFeature().getValue().as_number() for pfi in pfis
-            ]
-            data[f"{output}_confidence"] = [pfi.getConfidence() for pfi in pfis]
-
-        return pd.DataFrame.from_dict(data)
+            output_rows = []
+            for pfi in self.saliency_map().get(output).getPerFeatureImportance():
+                output_rows.append(
+                    {
+                        "Feature": str(pfi.getFeature().getName().toString()),
+                        "Value": pfi.getFeature().getValue().getUnderlyingObject(),
+                        "Saliency": pfi.getScore(),
+                        "Confidence": pfi.getConfidence(),
+                    }
+                )
+            data[output] = pd.DataFrame(output_rows)
+        return data
 
     def as_html(self) -> pd.io.formats.style.Styler:
         """
-        Return the LIME result as a Pandas Styler object.
+        Return the LIME results as Pandas Styler objects.
 
         Returns
         -------
-        pandas.Styler
-            Styler containing the results of the LIME explanation, in the same
-            schema as in :func:`as_dataframe`. Currently, no default styles are applied
-            in this particular function, making it equivalent to :code:`self.as_dataframe().style`.
+        Dict[str, pandas.Styler]
+            Dictionary of stylers keyed by output name. Each styler contains the results of the
+            LIME explanation for that particular output, in the same
+            schema as in :func:`as_dataframe`. This will:
+
+            * Color each ``Saliency`` based on its magnitude.
         """
-        return self.as_dataframe().style
+
+        htmls = {}
+        for k, df in self.as_dataframe().items():
+            htmls[k] = df.style.background_gradient(
+                LinearSegmentedColormap.from_list(
+                    name="rwg",
+                    colors=[
+                        ds["negative_primary_colour"],
+                        ds["neutral_primary_colour"],
+                        ds["positive_primary_colour"],
+                    ],
+                ),
+                subset="Saliency",
+                vmin=-1 * max(np.abs(df["Saliency"])),
+                vmax=max(np.abs(df["Saliency"])),
+            )
+        return htmls
 
     def _matplotlib_plot(self, output_name: str, block=True) -> None:
         """Plot the LIME saliencies."""

src/trustyai/explainers/shap.py

Lines changed: 35 additions & 52 deletions
@@ -101,44 +101,27 @@ def _saliency_to_dataframe(self, saliency, output_name):
             ],
             0,
         ).tolist()
-        feature_values = [
-            pfi.getFeature().getValue().asNumber()
-            for pfi in saliency.getPerFeatureImportance()[:-1]
-        ]
-        shap_values = [
-            pfi.getScore() for pfi in saliency.getPerFeatureImportance()[:-1]
-        ]
-        feature_names = [
-            str(pfi.getFeature().getName())
-            for pfi in saliency.getPerFeatureImportance()[:-1]
-        ]
-
-        columns = ["Mean Background Value", "Feature Value", "SHAP Value"]
-        visualizer_data_frame = pd.DataFrame(
-            [background_mean_feature_values, feature_values, shap_values],
-            index=columns,
-            columns=feature_names,
-        ).T
-        fnull = self.get_fnull()[output_name]
 
-        return (
-            pd.concat(
-                [
-                    pd.DataFrame(
-                        [["-", "-", fnull]], index=["Background"], columns=columns
-                    ),
-                    visualizer_data_frame,
-                    pd.DataFrame(
-                        [[fnull, sum(shap_values) + fnull, sum(shap_values) + fnull]],
-                        index=["Prediction"],
-                        columns=columns,
-                    ),
-                ]
-            ),
-            feature_names,
-            shap_values,
-            background_mean_feature_values,
-        )
+        data_rows = []
+        for i, pfi in enumerate(saliency.getPerFeatureImportance()[:-1]):
+            data_rows.append(
+                {
+                    "Feature": str(pfi.getFeature().getName().toString()),
+                    "Value": pfi.getFeature().getValue().getUnderlyingObject(),
+                    "Mean Background Value": background_mean_feature_values[i],
+                    "SHAP Value": pfi.getScore(),
+                    "Confidence": pfi.getConfidence(),
+                }
+            )
+
+        fnull = {
+            "Feature": "Background",
+            "Value": None,
+            "Mean Background Value": None,
+            "SHAP Value": self.get_fnull()[output_name],
+        }
+
+        return pd.DataFrame([fnull] + data_rows)
 
     def as_dataframe(self) -> Dict[str, pd.DataFrame]:
         """
@@ -148,16 +131,18 @@ def as_dataframe(self) -> Dict[str, pd.DataFrame]:
         -------
         Dict[str, pandas.DataFrame]
             Dictionary of DataFrames, keyed by output name, containing the results of the SHAP
-            explanation. For each model output, the table will contain the following columns,
-            indexed by feature name:
+            explanation. For each model output, the table will contain the following columns:
 
-            * ``Mean Background Value``: The mean value this feature took in the background
+            * ``Feature``: The name of the feature
             * ``Feature Value``: The value of the feature for this particular input.
+            * ``Mean Background Value``: The mean value this feature took in the background
             * ``SHAP Value``: The found SHAP value of this feature.
+            * ``Confidence``: The confidence of this explanation as returned by the explainer.
+
         """
         df_dict = {}
         for output_name, saliency in self.saliency_map().items():
-            df_dict[output_name] = self._saliency_to_dataframe(saliency, output_name)[0]
+            df_dict[output_name] = self._saliency_to_dataframe(saliency, output_name)
         return df_dict
 
     def as_html(self) -> Dict[str, pd.io.formats.style.Styler]:
@@ -179,23 +164,21 @@ def _color_feature_values(feature_values, background_vals):
         def _color_feature_values(feature_values, background_vals):
             """Internal function for the dataframe visualization"""
             formats = []
-            for i, feature_value in enumerate(feature_values[1:-1]):
+            for i, feature_value in enumerate(feature_values[1:]):
                 if feature_value < background_vals[i]:
                     formats.append(f"background-color:{ds['negative_primary_colour']}")
                 elif feature_value > background_vals[i]:
                     formats.append(f"background-color:{ds['positive_primary_colour']}")
                 else:
                     formats.append(None)
-            return [None] + formats + [None]
+            return [None] + formats
 
         df_dict = {}
-        for i, (output_name, saliency) in enumerate(self.saliency_map().items()):
-            (
-                df,
-                feature_names,
-                shap_values,
-                background_mean_feature_values,
-            ) = self._saliency_to_dataframe(saliency, i)
+        for output_name, saliency in self.saliency_map().items():
+            df = self._saliency_to_dataframe(saliency, output_name)
+            shap_values = df["SHAP Value"].values[1:]
+            background_mean_feature_values = df["Mean Background Value"].values[1:]
+
            style = df.style.background_gradient(
                LinearSegmentedColormap.from_list(
                    name="rwg",
@@ -205,15 +188,15 @@ def _color_feature_values(feature_values, background_vals):
                        ds["positive_primary_colour"],
                    ],
                ),
-                subset=(slice(feature_names[0], feature_names[-1]), "SHAP Value"),
+                subset=(slice(1, None), "SHAP Value"),
                vmin=-1 * max(np.abs(shap_values)),
                vmax=max(np.abs(shap_values)),
            )
            style.set_caption(f"Explanation of {output_name}")
            df_dict[output_name] = style.apply(
                _color_feature_values,
                background_vals=background_mean_feature_values,
-                subset="Feature Value",
+                subset="Value",
                axis=0,
            )
        return df_dict
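
Both as_html() implementations now share the same colouring recipe, which is plain pandas Styler machinery: a three-point diverging colormap plus symmetric vmin/vmax bounds so the neutral colour lands exactly at zero. A standalone sketch of that technique, with placeholder hex colours standing in for the ds theme dictionary:

import numpy as np
import pandas as pd
from matplotlib.colors import LinearSegmentedColormap

df = pd.DataFrame({"Feature": ["f0", "f1", "f2"], "SHAP Value": [-0.4, 0.1, 0.7]})

# Three-point diverging map; the hex values are placeholders for ds's
# negative/neutral/positive primary colours.
cmap = LinearSegmentedColormap.from_list(
    name="rwg", colors=["#cc4444", "#ffffff", "#44cc44"]
)

# Symmetric bounds pin the colormap midpoint to a value of exactly 0, so
# negative saliencies shade toward red and positive ones toward green.
bound = max(np.abs(df["SHAP Value"]))
styler = df.style.background_gradient(cmap, subset="SHAP Value", vmin=-bound, vmax=bound)
html = styler.to_html()  # pandas >= 1.3; in a notebook the styler renders directly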

tests/general/test_limeexplainer.py

Lines changed: 26 additions & 4 deletions
@@ -118,16 +118,24 @@ def test_lime_plots():
 
 def test_lime_v2():
     np.random.seed(0)
-    data = pd.DataFrame(np.random.rand(1, 5))
+    data = pd.DataFrame(np.random.rand(1, 5)).values
+
     model_weights = np.random.rand(5)
-    predict_function = lambda x: np.dot(x.values, model_weights)
+    predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1)
+    model = Model(predict_function)
 
-    model = Model(predict_function, dataframe_input=True)
     explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False)
     explanation = explainer.explain(inputs=data, outputs=model(data), model=model)
-    for score in explanation.as_dataframe()["output-0_score"]:
+
+    for score in explanation.as_dataframe()["output-0"]['Saliency']:
         assert score != 0
 
+    for out_name, df in explanation.as_dataframe().items():
+        assert "Feature" in df
+        assert "output" in out_name
+        assert all([x in str(df) for x in "01234"])
+
+
 def test_impact_score():
     np.random.seed(0)
     data = pd.DataFrame(np.random.rand(1, 5))
@@ -143,3 +151,17 @@ def test_impact_score():
     impact = ExplainabilityMetrics.impactScore(model, pred, top_features_t)
     assert impact > 0
     return impact
+
+
+def test_lime_as_html():
+    np.random.seed(0)
+    data = np.random.rand(1, 5)
+
+    model_weights = np.random.rand(5)
+    predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1)
+
+    model = Model(predict_function, disable_arrow=True)
+
+    explainer = LimeExplainer()
+    explainer.explain(inputs=data, outputs=model(data), model=model)
+    assert True
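
The lambda used in these tests turns a single linear scorer into a two-output model by stacking the score vectors along the last axis, giving the (n_inputs, n_outputs) shape the explainers consume here. A quick shape check of that construction:

import numpy as np

x = np.random.rand(3, 5)  # 3 inputs, 5 features
w = np.random.rand(5)

# np.stack([...], -1) combines two (3,) score vectors into a (3, 2) array:
# one row per input, one column per model output.
out = np.stack([np.dot(x, w), 2 * np.dot(x, w)], -1)
assert out.shape == (3, 2)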

tests/general/test_shap.py

Lines changed: 16 additions & 0 deletions
@@ -96,3 +96,19 @@ def test_shap_as_df():
         assert "Mean Background Value" in df
         assert "output" in out_name
         assert all([x in str(df) for x in "01234"])
+
+
+def test_shap_as_html():
+    np.random.seed(0)
+    data = pd.DataFrame(np.random.rand(101, 5))
+    background = data.iloc[:100].values
+    to_explain = data.iloc[100:101].values
+
+    model_weights = np.random.rand(5)
+    predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1)
+
+    model = Model(predict_function, disable_arrow=True)
+
+    shap_explainer = SHAPExplainer(background=background)
+    explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model)
+    assert True
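
Beyond this smoke test, the unified as_html() can render one styled table per output in a notebook. A sketch reusing the setup from test_shap_as_html above (import paths assumed):

import numpy as np
from IPython.display import display
from trustyai.model import Model  # import path assumed
from trustyai.explainers import SHAPExplainer  # import path assumed

np.random.seed(0)
data = np.random.rand(101, 5)
weights = np.random.rand(5)
model = Model(lambda x: np.stack([np.dot(x, weights), 2 * np.dot(x, weights)], -1),
              disable_arrow=True)

explainer = SHAPExplainer(background=data[:100])
explanation = explainer.explain(inputs=data[100:], outputs=model(data[100:]), model=model)

# One styled table per model output, with the Background (fnull) row first.
for output_name, styler in explanation.as_html().items():
    display(styler)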
