Skip to content

Commit 1b4329c

Browse files
authored
FAI-887: Generalize ExplanationResults (#115)
* Created unified explainer + saliency result abtract classes, joined plotting syntax * linting and black * fixed doc typo in explanation results
1 parent 5704b99 commit 1b4329c

File tree

10 files changed

+348
-285
lines changed

10 files changed

+348
-285
lines changed

src/trustyai/explainers/counterfactuals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
import uuid as _uuid
1010

1111
from trustyai import _default_initializer # pylint: disable=unused-import
12+
from .explanation_results import ExplanationResults
1213
from trustyai.utils._visualisation import (
13-
ExplanationVisualiser,
1414
DEFAULT_STYLE as ds,
1515
DEFAULT_RC_PARAMS as drcp,
1616
)
@@ -41,7 +41,7 @@
4141
CounterfactualConfig = _CounterfactualConfig
4242

4343

44-
class CounterfactualResult(ExplanationVisualiser):
44+
class CounterfactualResult(ExplanationResults):
4545
"""Wraps Counterfactual results. This object is returned by the
4646
:class:`~CounterfactualExplainer`, and provides a variety of methods to visualize and interact
4747
with the results of the counterfactual explanation.
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
"""Generic class for Explanation and Saliency results"""
2+
from abc import ABC, abstractmethod
3+
from typing import Dict
4+
5+
import bokeh.models
6+
import pandas as pd
7+
from bokeh.io import show
8+
from pandas.io.formats.style import Styler
9+
10+
11+
class ExplanationResults(ABC):
12+
"""Abstract class for explanation visualisers"""
13+
14+
@abstractmethod
15+
def as_dataframe(self) -> pd.DataFrame:
16+
"""Display explanation result as a dataframe"""
17+
18+
@abstractmethod
19+
def as_html(self) -> Styler:
20+
"""Visualise the styled dataframe"""
21+
22+
23+
# pylint: disable=too-few-public-methods
24+
class SaliencyResults(ExplanationResults):
25+
"""Abstract class for saliency visualisers"""
26+
27+
@abstractmethod
28+
def saliency_map(self):
29+
"""Return the Saliencies as a dictionary, keyed by output name"""
30+
31+
@abstractmethod
32+
def _matplotlib_plot(self, output_name: str) -> None:
33+
"""Plot the saliencies of a particular output in matplotlib"""
34+
35+
@abstractmethod
36+
def _get_bokeh_plot(self, output_name: str) -> bokeh.models.Plot:
37+
"""Get a bokeh plot visualizing the saliencies of a particular output"""
38+
39+
def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
40+
"""Get a dictionary containing visualizations of the saliencies of all outputs,
41+
keyed by output name"""
42+
return {
43+
output_name: self._get_bokeh_plot(output_name)
44+
for output_name in self.saliency_map().keys()
45+
}
46+
47+
def plot(self, output_name=None, render_bokeh=False) -> None:
48+
"""
49+
Plot the found feature saliencies.
50+
51+
Parameters
52+
----------
53+
output_name : str
54+
(default=None) The name of the output to be explainer. If `None`, all outputs will
55+
be displayed
56+
render_bokeh : bool
57+
(default: false) Whether to render as bokeh (true) or matplotlib (false)
58+
"""
59+
if output_name is None:
60+
for output_name_iterator in self.saliency_map().keys():
61+
if render_bokeh:
62+
show(self._get_bokeh_plot(output_name_iterator))
63+
else:
64+
self._matplotlib_plot(output_name_iterator)
65+
else:
66+
if render_bokeh:
67+
show(self._get_bokeh_plot(output_name))
68+
else:
69+
self._matplotlib_plot(output_name)

src/trustyai/explainers/lime.py

Lines changed: 70 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33
# pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name
44
from typing import Dict
5+
6+
import bokeh.models
57
import matplotlib.pyplot as plt
68
import matplotlib as mpl
79
from bokeh.models import ColumnDataSource, HoverTool
@@ -10,15 +12,14 @@
1012

1113
from trustyai import _default_initializer # pylint: disable=unused-import
1214
from trustyai.utils._visualisation import (
13-
ExplanationVisualiser,
1415
DEFAULT_STYLE as ds,
1516
DEFAULT_RC_PARAMS as drcp,
1617
bold_red_html,
1718
bold_green_html,
1819
output_html,
1920
feature_html,
2021
)
21-
22+
from .explanation_results import SaliencyResults
2223
from trustyai.model import simple_prediction, PredUnionType
2324

2425
from org.kie.trustyai.explainability.local.lime import (
@@ -29,7 +30,6 @@
2930
EncodingParams,
3031
PredictionProvider,
3132
Saliency,
32-
SaliencyResults,
3333
PerturbationContext,
3434
)
3535

@@ -38,17 +38,17 @@
3838
LimeConfig = _LimeConfig
3939

4040

41-
class LimeResults(ExplanationVisualiser):
41+
class LimeResults(SaliencyResults):
4242
"""Wraps LIME results. This object is returned by the :class:`~LimeExplainer`,
4343
and provides a variety of methods to visualize and interact with the explanation.
4444
"""
4545

4646
def __init__(self, saliencyResults: SaliencyResults):
4747
"""Constructor method. This is called internally, and shouldn't ever need to be used
4848
manually."""
49-
self._saliency_results = saliencyResults
49+
self._java_saliency_results = saliencyResults
5050

51-
def map(self) -> Dict[str, Saliency]:
51+
def saliency_map(self) -> Dict[str, Saliency]:
5252
"""
5353
Return a dictionary of found saliencies.
5454
@@ -59,7 +59,7 @@ def map(self) -> Dict[str, Saliency]:
5959
"""
6060
return {
6161
entry.getKey(): entry.getValue()
62-
for entry in self._saliency_results.saliencies.entrySet()
62+
for entry in self._java_saliency_results.saliencies.entrySet()
6363
}
6464

6565
def as_dataframe(self) -> pd.DataFrame:
@@ -77,11 +77,11 @@ def as_dataframe(self) -> pd.DataFrame:
7777
* ``${output_name}_value``: The original value of each feature.
7878
* ``${output_name}_confidence``: The confidence of the reported saliency.
7979
"""
80-
outputs = self.map().keys()
80+
outputs = self.saliency_map().keys()
8181

8282
data = {}
8383
for output in outputs:
84-
pfis = self.map().get(output).getPerFeatureImportance()
84+
pfis = self.saliency_map().get(output).getPerFeatureImportance()
8585
data[f"{output}_features"] = [
8686
f"{pfi.getFeature().getName()}" for pfi in pfis
8787
]
@@ -106,12 +106,12 @@ def as_html(self) -> pd.io.formats.style.Styler:
106106
"""
107107
return self.as_dataframe().style
108108

109-
def plot(self, decision: str) -> None:
109+
def _matplotlib_plot(self, output_name: str) -> None:
110110
"""Plot the LIME saliencies."""
111111
with mpl.rc_context(drcp):
112112
dictionary = {}
113113
for feature_importance in (
114-
self.map().get(decision).getPerFeatureImportance()
114+
self.saliency_map().get(output_name).getPerFeatureImportance()
115115
):
116116
dictionary[
117117
feature_importance.getFeature().name
@@ -123,7 +123,7 @@ def plot(self, decision: str) -> None:
123123
else ds["positive_primary_colour"]
124124
for i in dictionary.values()
125125
]
126-
plt.title(f"LIME explanation of {decision}")
126+
plt.title(f"LIME explanation of {output_name}")
127127
plt.barh(
128128
range(len(dictionary)),
129129
dictionary.values(),
@@ -134,64 +134,65 @@ def plot(self, decision: str) -> None:
134134
plt.tight_layout()
135135
plt.show()
136136

137-
def _get_bokeh_plot_dict(self):
138-
plot_dict = {}
139-
for output_name, value in self.map().items():
140-
lime_data_source = pd.DataFrame(
141-
[
142-
{
143-
"feature": str(pfi.getFeature().getName()),
144-
"saliency": pfi.getScore(),
145-
}
146-
for pfi in value.getPerFeatureImportance()
147-
]
148-
)
149-
lime_data_source["color"] = lime_data_source["saliency"].apply(
150-
lambda x: ds["positive_primary_colour"]
151-
if x >= 0
152-
else ds["negative_primary_colour"]
153-
)
154-
lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply(
155-
lambda x: (bold_green_html if x >= 0 else bold_red_html)(
156-
"{:.2f}".format(x)
157-
)
158-
)
137+
def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot:
138+
lime_data_source = pd.DataFrame(
139+
[
140+
{
141+
"feature": str(pfi.getFeature().getName()),
142+
"saliency": pfi.getScore(),
143+
}
144+
for pfi in self.saliency_map()[output_name].getPerFeatureImportance()
145+
]
146+
)
147+
lime_data_source["color"] = lime_data_source["saliency"].apply(
148+
lambda x: ds["positive_primary_colour"]
149+
if x >= 0
150+
else ds["negative_primary_colour"]
151+
)
152+
lime_data_source["saliency_colored"] = lime_data_source["saliency"].apply(
153+
lambda x: (bold_green_html if x >= 0 else bold_red_html)("{:.2f}".format(x))
154+
)
159155

160-
lime_data_source["color_faded"] = lime_data_source["saliency"].apply(
161-
lambda x: ds["positive_primary_colour_faded"]
162-
if x >= 0
163-
else ds["negative_primary_colour_faded"]
164-
)
165-
source = ColumnDataSource(lime_data_source)
166-
htool = HoverTool(
167-
names=["bars"],
168-
tooltips="<h3>LIME</h3> {} saliency to {}: @saliency_colored".format(
169-
feature_html("@feature"), output_html(output_name)
170-
),
171-
)
172-
bokeh_plot = figure(
173-
sizing_mode="stretch_both",
174-
title="Lime Feature Importances",
175-
y_range=lime_data_source["feature"],
176-
tools=[htool],
177-
)
178-
bokeh_plot.hbar(
179-
y="feature",
180-
left=0,
181-
right="saliency",
182-
fill_color="color_faded",
183-
line_color="color",
184-
hover_color="color",
185-
color="color",
186-
height=0.75,
187-
name="bars",
188-
source=source,
189-
)
190-
bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000")
191-
bokeh_plot.xaxis.axis_label = "Saliency Value"
192-
bokeh_plot.yaxis.axis_label = "Feature"
193-
plot_dict[output_name] = bokeh_plot
194-
return plot_dict
156+
lime_data_source["color_faded"] = lime_data_source["saliency"].apply(
157+
lambda x: ds["positive_primary_colour_faded"]
158+
if x >= 0
159+
else ds["negative_primary_colour_faded"]
160+
)
161+
source = ColumnDataSource(lime_data_source)
162+
htool = HoverTool(
163+
names=["bars"],
164+
tooltips="<h3>LIME</h3> {} saliency to {}: @saliency_colored".format(
165+
feature_html("@feature"), output_html(output_name)
166+
),
167+
)
168+
bokeh_plot = figure(
169+
sizing_mode="stretch_both",
170+
title="Lime Feature Importances",
171+
y_range=lime_data_source["feature"],
172+
tools=[htool],
173+
)
174+
bokeh_plot.hbar(
175+
y="feature",
176+
left=0,
177+
right="saliency",
178+
fill_color="color_faded",
179+
line_color="color",
180+
hover_color="color",
181+
color="color",
182+
height=0.75,
183+
name="bars",
184+
source=source,
185+
)
186+
bokeh_plot.line([0, 0], [0, len(lime_data_source)], color="#000")
187+
bokeh_plot.xaxis.axis_label = "Saliency Value"
188+
bokeh_plot.yaxis.axis_label = "Feature"
189+
return bokeh_plot
190+
191+
def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
192+
return {
193+
output_name: self._get_bokeh_plot(output_name)
194+
for output_name in self.saliency_map().keys()
195+
}
195196

196197

197198
class LimeExplainer:

0 commit comments

Comments
 (0)