Skip to content

Commit c3a67da

Browse files
authored
FAI-825: Add feature and output name specification to models (#130)
* Added feature and output name specification to models; all Python models will now map these names
* Unified as_df and as_html methods between SHAP and LIME
* Linting, black, and merging
1 parent 03e0452 commit c3a67da

File tree

8 files changed

+162
-44
lines changed

8 files changed

+162
-44
lines changed

src/trustyai/explainers/counterfactuals.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -224,9 +224,15 @@ def explain(
224224
:class:`~CounterfactualResult`
225225
Object containing the results of the counterfactual explanation.
226226
"""
227+
feature_names = model.feature_names if isinstance(model, Model) else None
228+
output_names = model.output_names if isinstance(model, Model) else None
227229
_prediction = counterfactual_prediction(
228-
input_features=one_input_convert(inputs, feature_domains=feature_domains),
230+
input_features=one_input_convert(
231+
inputs, feature_names=feature_names, feature_domains=feature_domains
232+
),
229233
outputs=goal,
234+
feature_names=feature_names,
235+
output_names=output_names,
230236
data_distribution=data_distribution,
231237
uuid=uuid,
232238
timeout=timeout,

src/trustyai/explainers/lime.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,7 @@ def _matplotlib_plot(self, output_name: str, block=True) -> None:
153153
else ds["positive_primary_colour"]
154154
for i in dictionary.values()
155155
]
156-
plt.title(f"LIME explanation of {output_name}")
156+
plt.title(f"LIME: Feature Importances to {output_name}")
157157
plt.barh(
158158
range(len(dictionary)),
159159
dictionary.values(),
@@ -306,7 +306,9 @@ def explain(
306306
:class:`~LimeResults`
307307
Object containing the results of the LIME explanation.
308308
"""
309-
_prediction = simple_prediction(inputs, outputs)
309+
feature_names = model.feature_names if isinstance(model, Model) else None
310+
output_names = model.output_names if isinstance(model, Model) else None
311+
_prediction = simple_prediction(inputs, outputs, feature_names, output_names)
310312

311313
with Model.ArrowTransmission(model, inputs):
312314
return LimeResults(self._explainer.explainAsync(_prediction, model).get())

src/trustyai/explainers/shap.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ def _matplotlib_plot(self, output_name, block=True) -> None:
254254
plt.xticks(np.arange(len(feature_names)), feature_names)
255255
plt.ylabel(self.saliency_map()[output_name].getOutput().getName())
256256
plt.xlabel("Feature SHAP Value")
257-
plt.title(f"Explanation of {output_name}")
257+
plt.title(f"SHAP: Feature Contributions to {output_name}")
258258
plt.show(block=block)
259259

260260
def _get_bokeh_plot(self, output_name):
@@ -424,7 +424,9 @@ def __init__(self, datapoints: ManyInputsUnionType, feature_domains=None, seed=0
424424
seed : int
425425
The random seed to use in the sampling/generation method
426426
"""
427-
self.datapoints = many_inputs_convert(datapoints, feature_domains)
427+
self.datapoints = many_inputs_convert(
428+
datapoints, feature_domains=feature_domains
429+
)
428430
self.feature_domains = feature_domains
429431
self.seed = 0
430432
self._jrandom = Random()
@@ -620,21 +622,18 @@ def __init__(
620622
link_type = _ShapConfig.LinkType.IDENTITY
621623
self._jrandom = Random()
622624
self._jrandom.setSeed(kwargs.get("seed", 0))
623-
self.background = many_inputs_convert(background)
625+
self._raw_background = background
624626
perturbation_context = PerturbationContext(self._jrandom, 0)
625627

626628
self._configbuilder = (
627629
_ShapConfig.builder()
628630
.withLink(link_type)
629631
.withBatchSize(kwargs.get("batch_size", 20))
630632
.withPC(perturbation_context)
631-
.withBackground(self.background)
632633
.withTrackCounterfactuals(kwargs.get("track_counterfactuals", False))
633634
)
634635
if kwargs.get("samples") is not None:
635636
self._configbuilder.withNSamples(JInt(kwargs["samples"]))
636-
self._config = self._configbuilder.build()
637-
self._explainer = _ShapKernelExplainer(self._config)
638637

639638
@data_conversion_docstring("one_input", "one_output")
640639
def explain(
@@ -660,9 +659,14 @@ def explain(
660659
:class:`~SHAPResults`
661660
Object containing the results of the SHAP explanation.
662661
"""
663-
_prediction = simple_prediction(inputs, outputs)
664662

663+
feature_names = model.feature_names if isinstance(model, Model) else None
664+
output_names = model.output_names if isinstance(model, Model) else None
665+
_prediction = simple_prediction(inputs, outputs, feature_names, output_names)
666+
_background = many_inputs_convert(self._raw_background, feature_names)
667+
config = self._configbuilder.withBackground(_background).build()
668+
explainer = _ShapKernelExplainer(config)
665669
with Model.ArrowTransmission(model, inputs):
666670
return SHAPResults(
667-
self._explainer.explainAsync(_prediction, model).get(), self.background
671+
explainer.explainAsync(_prediction, model).get(), _background
668672
)

src/trustyai/model/__init__.py

Lines changed: 57 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -294,9 +294,7 @@ class Model:
294294
predictive model to interface with the TrustyAI Java library.
295295
"""
296296

297-
def __init__(
298-
self, predict_fun, dataframe_input=False, output_names=None, disable_arrow=False
299-
):
297+
def __init__(self, predict_fun, **kwargs):
300298
"""
301299
Wrap the model as a TrustyAI :obj:`PredictionProvider` Java class.
302300
@@ -306,20 +304,26 @@ def __init__(
306304
A function that takes in a Numpy array or Pandas DataFrame as input and outputs a
307305
Pandas DataFrame or Numpy array. In general, the ``model.predict`` functions of
308306
sklearn-style models meet this requirement.
309-
dataframe_input: bool
310-
Whether `predict_fun` expects a :class:`pandas.DataFrame` as input.
311-
output_names : List[String]:
312-
If the model outputs a numpy array, you can specify the names of the model outputs
313-
here.
314-
disable_arrow: bool
315-
If true, Apache Arrow will not be used to accelerate data transfer between Java
316-
and Python. If false, Arrow will be automatically used in situations where it is
317-
advantageous to do so.
307+
308+
Keyword Arguments:
309+
* dataframe_input: bool
310+
(default= ``False``) Whether `predict_fun` expects a :class:`pandas.DataFrame`
311+
as input.
312+
* feature_names : List[String]:
313+
(default= ``None``) If the model receives a non-pandas input, you can specify the
314+
names of the model input features here, with the ith element of the list
315+
corresponding to the name of the ith feature.
316+
* output_names : List[String]:
317+
(default= ``None``) If the model outputs a non-pandas object, you can specify the
318+
names of the model outputs here, with the ith element of the list corresponding to
319+
the name of the ith output.
320+
* disable_arrow: bool
321+
(default= ``False``) If true, Apache Arrow will not be used to accelerate data
322+
transfer between Java and Python. If false, Arrow will be automatically used in
323+
situations where it is advantageous to do so.
318324
"""
319-
self.disable_arrow = disable_arrow
320325
self.predict_fun = predict_fun
321-
self.output_names = output_names
322-
self.dataframe_input = dataframe_input
326+
self.kwargs = kwargs
323327

324328
self.prediction_provider_arrow = None
325329
self.prediction_provider_normal = None
@@ -328,6 +332,26 @@ def __init__(
328332
# set model to use non-arrow by default, as this requires no dataset information
329333
self._set_nonarrow()
330334

335+
@property
336+
def dataframe_input(self):
337+
"""Get dataframe_input kwarg value"""
338+
return self.kwargs.get("dataframe_input")
339+
340+
@property
341+
def feature_names(self):
342+
"""Get feature_names kwarg value"""
343+
return self.kwargs.get("feature_names")
344+
345+
@property
346+
def output_names(self):
347+
"""Get output_names kwarg value"""
348+
return self.kwargs.get("output_names")
349+
350+
@property
351+
def disable_arrow(self):
352+
"""Get disable_arrow kwarg value"""
353+
return self.kwargs.get("disable_arrow")
354+
331355
def _set_arrow(self, paradigm_input: PredictionInput):
332356
"""
333357
Ready the model for arrow-based prediction communication.
@@ -825,7 +849,10 @@ def feature(
825849
# pylint: disable=line-too-long
826850
@data_conversion_docstring("one_input", "one_output")
827851
def simple_prediction(
828-
input_features: OneInputUnionType, outputs: OneOutputUnionType
852+
input_features: OneInputUnionType,
853+
outputs: OneOutputUnionType,
854+
feature_names: Optional[List[str]] = None,
855+
output_names: Optional[List[str]] = None,
829856
) -> SimplePrediction:
830857
"""Wrap features and outputs into a SimplePrediction. Given a list of features and outputs,
831858
this function will bundle them into Prediction objects for use with the LIME and SHAP
@@ -838,10 +865,15 @@ def simple_prediction(
838865
outputs : {}
839866
The desired model outputs to be searched for in the counterfactual explanation.
840867
These can take the form of a: {}
868+
feature_names: Optional[List[str]]
869+
The names of the features, in the case where the feature object does not contain them
870+
output_names: Optional[List[str]]
871+
The names of the outputs, in the case where the output object does not contain them
841872
"""
842873

843874
return SimplePrediction(
844-
one_input_convert(input_features), one_output_convert(outputs)
875+
one_input_convert(input_features, feature_names),
876+
one_output_convert(outputs, output_names),
845877
)
846878

847879

@@ -850,6 +882,8 @@ def simple_prediction(
850882
def counterfactual_prediction(
851883
input_features: OneInputUnionType,
852884
outputs: OneOutputUnionType,
885+
feature_names: Optional[List[str]] = None,
886+
output_names: Optional[List[str]] = None,
853887
data_distribution: Optional[DataDistribution] = None,
854888
uuid: Optional[_uuid.UUID] = None,
855889
timeout: Optional[float] = None,
@@ -865,6 +899,10 @@ def counterfactual_prediction(
865899
outputs : {}
866900
The desired model outputs to be searched for in the counterfactual explanation.
867901
These can take the form of a: {}
902+
feature_names: Optional[List[str]]
903+
The names of the features, in the case where the feature object does not contain them
904+
output_names: Optional[List[str]]
905+
The names of the outputs, in the case where the output object does not contain them
868906
data_distribution : Optional[:class:`DataDistribution`]
869907
The :class:`DataDistribution` to use when sampling the inputs.
870908
uuid : Optional[:class:`_uuid.UUID`]
@@ -878,8 +916,8 @@ def counterfactual_prediction(
878916
timeout = Long(timeout)
879917

880918
return CounterfactualPrediction(
881-
one_input_convert(input_features),
882-
one_output_convert(outputs),
919+
one_input_convert(input_features, feature_names),
920+
one_output_convert(outputs, output_names),
883921
data_distribution,
884922
uuid,
885923
timeout,

src/trustyai/utils/data_conversions.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from itertools import filterfalse
77

88
import trustyai.model
9-
from trustyai.model.domain import feature_domain
109
from org.kie.trustyai.explainability.model import (
1110
Dataframe,
1211
Feature,
@@ -184,21 +183,29 @@ def domain_insertion(
184183

185184
# === input functions ==============================================================================
186185
def one_input_convert(
187-
python_inputs: OneInputUnionType, feature_domains: FeatureDomain = None
186+
python_inputs: OneInputUnionType,
187+
feature_names: Optional[List[str]] = None,
188+
feature_domains: Optional[List[FeatureDomain]] = None,
188189
) -> PredictionInput:
189190
"""Convert an object of OneInputUnionType into a PredictionInput."""
190191
if isinstance(python_inputs, (int, float, np.number)):
191192
python_inputs = np.array([[python_inputs]])
192-
pi = numpy_to_prediction_object(python_inputs, trustyai.model.feature)[0]
193+
pi = numpy_to_prediction_object(
194+
python_inputs, trustyai.model.feature, names=feature_names
195+
)[0]
193196
elif isinstance(python_inputs, list) and all(
194197
(isinstance(x, (int, float, np.number)) for x in python_inputs)
195198
):
196199
python_inputs = np.array(python_inputs).reshape(1, -1)
197-
pi = numpy_to_prediction_object(python_inputs, trustyai.model.feature)[0]
200+
pi = numpy_to_prediction_object(
201+
python_inputs, trustyai.model.feature, names=feature_names
202+
)[0]
198203
elif isinstance(python_inputs, np.ndarray):
199204
if len(python_inputs.shape) == 1:
200205
python_inputs = python_inputs.reshape(1, -1)
201-
pi = numpy_to_prediction_object(python_inputs, trustyai.model.feature)[0]
206+
pi = numpy_to_prediction_object(
207+
python_inputs, trustyai.model.feature, names=feature_names
208+
)[0]
202209
elif isinstance(python_inputs, pd.DataFrame):
203210
pi = df_to_prediction_object(python_inputs, trustyai.model.feature)[0]
204211
elif isinstance(python_inputs, pd.Series):
@@ -217,13 +224,17 @@ def one_input_convert(
217224

218225

219226
def many_inputs_convert(
220-
python_inputs: ManyInputsUnionType, feature_domains: List[FeatureDomain] = None
227+
python_inputs: ManyInputsUnionType,
228+
feature_names: Optional[List[str]] = None,
229+
feature_domains: Optional[List[FeatureDomain]] = None,
221230
) -> List[PredictionInput]:
222231
"""Convert an object of ManyInputsUnionType into a List[PredictionInput]"""
223232
if isinstance(python_inputs, np.ndarray):
224233
if len(python_inputs.shape) == 1:
225234
python_inputs = python_inputs.reshape(1, -1)
226-
pis = numpy_to_prediction_object(python_inputs, trustyai.model.feature)
235+
pis = numpy_to_prediction_object(
236+
python_inputs, trustyai.model.feature, names=feature_names
237+
)
227238
elif isinstance(python_inputs, pd.DataFrame):
228239
pis = df_to_prediction_object(python_inputs, trustyai.model.feature)
229240
else:
@@ -236,20 +247,28 @@ def many_inputs_convert(
236247

237248

238249
# === output functions =============================================================================
239-
def one_output_convert(python_outputs: OneOutputUnionType) -> PredictionOutput:
250+
def one_output_convert(
251+
python_outputs: OneOutputUnionType, names: Optional[List[str]] = None
252+
) -> PredictionOutput:
240253
"""Convert an object of OneOutputUnionType into a PredictionOutput"""
241254
if isinstance(python_outputs, (int, np.integer, float, np.inexact)):
242255
python_outputs = np.array([[python_outputs]])
243-
po = numpy_to_prediction_object(python_outputs, trustyai.model.output)[0]
256+
po = numpy_to_prediction_object(
257+
python_outputs, trustyai.model.output, names=names
258+
)[0]
244259
elif isinstance(python_outputs, list) and all(
245260
(isinstance(x, (int, float, np.number)) for x in python_outputs)
246261
):
247262
python_outputs = np.array(python_outputs).reshape(1, -1)
248-
po = numpy_to_prediction_object(python_outputs, trustyai.model.output)[0]
263+
po = numpy_to_prediction_object(
264+
python_outputs, trustyai.model.output, names=names
265+
)[0]
249266
elif isinstance(python_outputs, np.ndarray):
250267
if len(python_outputs.shape) == 1:
251268
python_outputs = python_outputs.reshape(1, -1)
252-
po = numpy_to_prediction_object(python_outputs, trustyai.model.output)[0]
269+
po = numpy_to_prediction_object(
270+
python_outputs, trustyai.model.output, names=names
271+
)[0]
253272
elif isinstance(python_outputs, pd.DataFrame):
254273
po = df_to_prediction_object(python_outputs, trustyai.model.output)[0]
255274
elif isinstance(python_outputs, pd.Series):
@@ -265,13 +284,15 @@ def one_output_convert(python_outputs: OneOutputUnionType) -> PredictionOutput:
265284

266285

267286
def many_outputs_convert(
268-
python_outputs: ManyOutputsUnionType,
287+
python_outputs: ManyOutputsUnionType, names: Optional[List[str]] = None
269288
) -> List[PredictionOutput]:
270289
"""Convert an object of ManyOutputsUnionType into a List[PredictionOutput]"""
271290
if isinstance(python_outputs, np.ndarray):
272291
if len(python_outputs.shape) == 1:
273292
python_outputs = python_outputs.reshape(1, -1)
274-
return numpy_to_prediction_object(python_outputs, trustyai.model.output)
293+
return numpy_to_prediction_object(
294+
python_outputs, trustyai.model.output, names=names
295+
)
275296
if isinstance(python_outputs, pd.DataFrame):
276297
return df_to_prediction_object(python_outputs, trustyai.model.output)
277298
# fallback case is List[PredictionOutput]

tests/general/test_counterfactualexplainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ def test_counterfactual_with_domain_argument_overwrite():
166166
a warning"""
167167
np.random.seed(0)
168168
data = np.random.rand(1, 5)
169-
domained_inputs = one_input_convert(data, [feature_domain((-10, 10)) for _ in range(5)])
169+
domained_inputs = one_input_convert(data, feature_domains=[feature_domain((-10, 10)) for _ in range(5)])
170170
model_weights = np.random.rand(5)
171171
model = Model(lambda x: np.dot(x, model_weights))
172172
explainer = CounterfactualExplainer(steps=10_000)

tests/general/test_limeexplainer.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,29 @@ def test_lime_as_html():
165165
explainer = LimeExplainer()
166166
explainer.explain(inputs=data, outputs=model(data), model=model)
167167
assert True
168+
169+
explanation = explainer.explain(inputs=data, outputs=model(data), model=model)
170+
for score in explanation.as_dataframe()["output-0"]['Saliency']:
171+
assert score != 0
172+
173+
174+
def test_lime_numpy():
175+
np.random.seed(0)
176+
data = np.random.rand(101, 5)
177+
model_weights = np.random.rand(5)
178+
predict_function = lambda x: np.stack([np.dot(x, model_weights), 2 * np.dot(x, model_weights)], -1)
179+
fnames = ['f{}'.format(x) for x in "abcde"]
180+
onames = ['o{}'.format(x) for x in "12"]
181+
model = Model(predict_function,
182+
feature_names=fnames,
183+
output_names=onames
184+
)
185+
186+
explainer = LimeExplainer()
187+
explanation = explainer.explain(inputs=data[0], outputs=model(data[0]), model=model)
188+
189+
for oname in onames:
190+
assert oname in explanation.as_dataframe().keys()
191+
for fname in fnames:
192+
assert fname in explanation.as_dataframe()[oname]['Feature'].values
193+

0 commit comments

Comments
 (0)