2
2
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
3
3
# pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name
4
4
from typing import Dict
5
+
6
+ import bokeh .models
5
7
import matplotlib .pyplot as plt
6
8
import matplotlib as mpl
7
9
from bokeh .models import ColumnDataSource , HoverTool
10
12
11
13
from trustyai import _default_initializer # pylint: disable=unused-import
12
14
from trustyai .utils ._visualisation import (
13
- ExplanationVisualiser ,
14
15
DEFAULT_STYLE as ds ,
15
16
DEFAULT_RC_PARAMS as drcp ,
16
17
bold_red_html ,
17
18
bold_green_html ,
18
19
output_html ,
19
20
feature_html ,
20
21
)
21
-
22
+ from . explanation_results import SaliencyResults
22
23
from trustyai .model import simple_prediction , PredUnionType
23
24
24
25
from org .kie .trustyai .explainability .local .lime import (
29
30
EncodingParams ,
30
31
PredictionProvider ,
31
32
Saliency ,
32
- SaliencyResults ,
33
33
PerturbationContext ,
34
34
)
35
35
38
38
LimeConfig = _LimeConfig
39
39
40
40
41
- class LimeResults (ExplanationVisualiser ):
41
+ class LimeResults (SaliencyResults ):
42
42
"""Wraps LIME results. This object is returned by the :class:`~LimeExplainer`,
43
43
and provides a variety of methods to visualize and interact with the explanation.
44
44
"""
45
45
46
46
def __init__ (self , saliencyResults : SaliencyResults ):
47
47
"""Constructor method. This is called internally, and shouldn't ever need to be used
48
48
manually."""
49
- self ._saliency_results = saliencyResults
49
+ self ._java_saliency_results = saliencyResults
50
50
51
- def map (self ) -> Dict [str , Saliency ]:
51
+ def saliency_map (self ) -> Dict [str , Saliency ]:
52
52
"""
53
53
Return a dictionary of found saliencies.
54
54
@@ -59,7 +59,7 @@ def map(self) -> Dict[str, Saliency]:
59
59
"""
60
60
return {
61
61
entry .getKey (): entry .getValue ()
62
- for entry in self ._saliency_results .saliencies .entrySet ()
62
+ for entry in self ._java_saliency_results .saliencies .entrySet ()
63
63
}
64
64
65
65
def as_dataframe (self ) -> pd .DataFrame :
@@ -77,11 +77,11 @@ def as_dataframe(self) -> pd.DataFrame:
77
77
* ``${output_name}_value``: The original value of each feature.
78
78
* ``${output_name}_confidence``: The confidence of the reported saliency.
79
79
"""
80
- outputs = self .map ().keys ()
80
+ outputs = self .saliency_map ().keys ()
81
81
82
82
data = {}
83
83
for output in outputs :
84
- pfis = self .map ().get (output ).getPerFeatureImportance ()
84
+ pfis = self .saliency_map ().get (output ).getPerFeatureImportance ()
85
85
data [f"{ output } _features" ] = [
86
86
f"{ pfi .getFeature ().getName ()} " for pfi in pfis
87
87
]
@@ -106,12 +106,12 @@ def as_html(self) -> pd.io.formats.style.Styler:
106
106
"""
107
107
return self .as_dataframe ().style
108
108
109
- def plot (self , decision : str ) -> None :
109
+ def _matplotlib_plot (self , output_name : str ) -> None :
110
110
"""Plot the LIME saliencies."""
111
111
with mpl .rc_context (drcp ):
112
112
dictionary = {}
113
113
for feature_importance in (
114
- self .map ().get (decision ).getPerFeatureImportance ()
114
+ self .saliency_map ().get (output_name ).getPerFeatureImportance ()
115
115
):
116
116
dictionary [
117
117
feature_importance .getFeature ().name
@@ -123,7 +123,7 @@ def plot(self, decision: str) -> None:
123
123
else ds ["positive_primary_colour" ]
124
124
for i in dictionary .values ()
125
125
]
126
- plt .title (f"LIME explanation of { decision } " )
126
+ plt .title (f"LIME explanation of { output_name } " )
127
127
plt .barh (
128
128
range (len (dictionary )),
129
129
dictionary .values (),
@@ -134,64 +134,65 @@ def plot(self, decision: str) -> None:
134
134
plt .tight_layout ()
135
135
plt .show ()
136
136
137
- def _get_bokeh_plot_dict (self ):
138
- plot_dict = {}
139
- for output_name , value in self .map ().items ():
140
- lime_data_source = pd .DataFrame (
141
- [
142
- {
143
- "feature" : str (pfi .getFeature ().getName ()),
144
- "saliency" : pfi .getScore (),
145
- }
146
- for pfi in value .getPerFeatureImportance ()
147
- ]
148
- )
149
- lime_data_source ["color" ] = lime_data_source ["saliency" ].apply (
150
- lambda x : ds ["positive_primary_colour" ]
151
- if x >= 0
152
- else ds ["negative_primary_colour" ]
153
- )
154
- lime_data_source ["saliency_colored" ] = lime_data_source ["saliency" ].apply (
155
- lambda x : (bold_green_html if x >= 0 else bold_red_html )(
156
- "{:.2f}" .format (x )
157
- )
158
- )
137
+ def _get_bokeh_plot (self , output_name ) -> bokeh .models .Plot :
138
+ lime_data_source = pd .DataFrame (
139
+ [
140
+ {
141
+ "feature" : str (pfi .getFeature ().getName ()),
142
+ "saliency" : pfi .getScore (),
143
+ }
144
+ for pfi in self .saliency_map ()[output_name ].getPerFeatureImportance ()
145
+ ]
146
+ )
147
+ lime_data_source ["color" ] = lime_data_source ["saliency" ].apply (
148
+ lambda x : ds ["positive_primary_colour" ]
149
+ if x >= 0
150
+ else ds ["negative_primary_colour" ]
151
+ )
152
+ lime_data_source ["saliency_colored" ] = lime_data_source ["saliency" ].apply (
153
+ lambda x : (bold_green_html if x >= 0 else bold_red_html )("{:.2f}" .format (x ))
154
+ )
159
155
160
- lime_data_source ["color_faded" ] = lime_data_source ["saliency" ].apply (
161
- lambda x : ds ["positive_primary_colour_faded" ]
162
- if x >= 0
163
- else ds ["negative_primary_colour_faded" ]
164
- )
165
- source = ColumnDataSource (lime_data_source )
166
- htool = HoverTool (
167
- names = ["bars" ],
168
- tooltips = "<h3>LIME</h3> {} saliency to {}: @saliency_colored" .format (
169
- feature_html ("@feature" ), output_html (output_name )
170
- ),
171
- )
172
- bokeh_plot = figure (
173
- sizing_mode = "stretch_both" ,
174
- title = "Lime Feature Importances" ,
175
- y_range = lime_data_source ["feature" ],
176
- tools = [htool ],
177
- )
178
- bokeh_plot .hbar (
179
- y = "feature" ,
180
- left = 0 ,
181
- right = "saliency" ,
182
- fill_color = "color_faded" ,
183
- line_color = "color" ,
184
- hover_color = "color" ,
185
- color = "color" ,
186
- height = 0.75 ,
187
- name = "bars" ,
188
- source = source ,
189
- )
190
- bokeh_plot .line ([0 , 0 ], [0 , len (lime_data_source )], color = "#000" )
191
- bokeh_plot .xaxis .axis_label = "Saliency Value"
192
- bokeh_plot .yaxis .axis_label = "Feature"
193
- plot_dict [output_name ] = bokeh_plot
194
- return plot_dict
156
+ lime_data_source ["color_faded" ] = lime_data_source ["saliency" ].apply (
157
+ lambda x : ds ["positive_primary_colour_faded" ]
158
+ if x >= 0
159
+ else ds ["negative_primary_colour_faded" ]
160
+ )
161
+ source = ColumnDataSource (lime_data_source )
162
+ htool = HoverTool (
163
+ names = ["bars" ],
164
+ tooltips = "<h3>LIME</h3> {} saliency to {}: @saliency_colored" .format (
165
+ feature_html ("@feature" ), output_html (output_name )
166
+ ),
167
+ )
168
+ bokeh_plot = figure (
169
+ sizing_mode = "stretch_both" ,
170
+ title = "Lime Feature Importances" ,
171
+ y_range = lime_data_source ["feature" ],
172
+ tools = [htool ],
173
+ )
174
+ bokeh_plot .hbar (
175
+ y = "feature" ,
176
+ left = 0 ,
177
+ right = "saliency" ,
178
+ fill_color = "color_faded" ,
179
+ line_color = "color" ,
180
+ hover_color = "color" ,
181
+ color = "color" ,
182
+ height = 0.75 ,
183
+ name = "bars" ,
184
+ source = source ,
185
+ )
186
+ bokeh_plot .line ([0 , 0 ], [0 , len (lime_data_source )], color = "#000" )
187
+ bokeh_plot .xaxis .axis_label = "Saliency Value"
188
+ bokeh_plot .yaxis .axis_label = "Feature"
189
+ return bokeh_plot
190
+
191
+ def _get_bokeh_plot_dict (self ) -> Dict [str , bokeh .models .Plot ]:
192
+ return {
193
+ output_name : self ._get_bokeh_plot (output_name )
194
+ for output_name in self .saliency_map ().keys ()
195
+ }
195
196
196
197
197
198
class LimeExplainer :
0 commit comments