Commit c165b1e

FAI-882: Add kwargs to explainers (#113)
* changed non-critical args to kwargs
* fixed missing bullet in lime kwargs docs
1 parent 8704d0b commit c165b1e

File tree

2 files changed (+50, -50 lines)

src/trustyai/explainers/lime.py

Lines changed: 29 additions & 28 deletions
@@ -194,7 +194,6 @@ def _get_bokeh_plot_dict(self):
         return plot_dict
 
 
-# pylint: disable=too-many-arguments
 class LimeExplainer:
     """*"Which features were most important to the results?"*
 
@@ -203,47 +202,49 @@ class LimeExplainer:
     feature that describe how strongly said feature contributed to the model's output.
     """
 
-    def __init__(
-        self,
-        perturbations=1,
-        seed=0,
-        samples=10,
-        penalise_sparse_balance=True,
-        track_counterfactuals=False,
-        normalise_weights=False,
-        use_wlr_model=True,
-        **kwargs,
-    ):
+    def __init__(self, samples=10, **kwargs):
         """Initialize the :class:`LimeExplainer`.
 
         Parameters
         ----------
-        perturbations: int
-            The starting number of feature perturbations within the explanation process.
-        seed: int
-            The random seed to be used.
         samples: int
             Number of samples to be generated for the local linear model training.
-        penalise_sparse_balance : bool
-            Whether to penalise features that are likely to produce linearly inseparable outputs.
-            This can improve the efficacy and interpretability of the outputted saliencies.
-        normalise_weights : bool
-            Whether to normalise the saliencies generated by LIME. If selected, saliencies will be
-            normalized between 0 and 1.
+
+        Keyword Arguments:
+            * penalise_sparse_balance : bool
+                (default=True) Whether to penalise features that are likely to produce linearly
+                inseparable outputs. This can improve the efficacy and interpretability of the
+                outputted saliencies.
+            * normalise_weights : bool
+                (default=False) Whether to normalise the saliencies generated by LIME. If selected,
+                saliencies will be normalized between 0 and 1.
+            * use_wlr_model : bool
+                (default=True) Whether to use a weighted linear regression as the LIME explanatory
+                model. If ``False``, a multilayer perceptron is used, which generally has a slower
+                runtime.
+            * seed: int
+                (default=0) The random seed to be used.
+            * perturbations: int
+                (default=1) The starting number of feature perturbations within the explanation
+                process.
+            * track_counterfactuals : bool
+                (default=False) Keep track of produced byproduct counterfactuals during the LIME run.
         """
         self._jrandom = Random()
-        self._jrandom.setSeed(seed)
+        self._jrandom.setSeed(kwargs.get("seed", 0))
 
         self._lime_config = (
             LimeConfig()
-            .withNormalizeWeights(normalise_weights)
-            .withPerturbationContext(PerturbationContext(self._jrandom, perturbations))
+            .withNormalizeWeights(kwargs.get("normalise_weights", False))
+            .withPerturbationContext(
+                PerturbationContext(self._jrandom, kwargs.get("perturbations", 1))
+            )
             .withSamples(samples)
             .withEncodingParams(EncodingParams(0.07, 0.3))
             .withAdaptiveVariance(True)
-            .withPenalizeBalanceSparse(penalise_sparse_balance)
-            .withUseWLRLinearModel(use_wlr_model)
-            .withTrackCounterfactuals(track_counterfactuals)
+            .withPenalizeBalanceSparse(kwargs.get("penalise_sparse_balance", True))
+            .withUseWLRLinearModel(kwargs.get("use_wlr_model", True))
+            .withTrackCounterfactuals(kwargs.get("track_counterfactuals", False))
         )
 
         self._explainer = _LimeExplainer(self._lime_config)
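For reference, a minimal usage sketch of the new LimeExplainer interface; the `trustyai.explainers` import path is inferred from the file location above, and the argument values are illustrative:

    from trustyai.explainers import LimeExplainer

    # Non-critical settings now travel through **kwargs; omitted keys fall
    # back to the defaults listed in the docstring (seed=0, perturbations=1, ...).
    explainer = LimeExplainer(
        samples=100,             # still an explicit parameter
        seed=42,                 # read via kwargs.get("seed", 0)
        normalise_weights=True,  # read via kwargs.get("normalise_weights", False)
        use_wlr_model=False,     # switches to the slower MLP explanatory model
    )

Note one trade-off of the kwargs.get pattern: unknown keys are silently ignored, so a typo such as `normalize_weights=True` leaves the default in place rather than raising a TypeError as the old explicit signature would.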

src/trustyai/explainers/shap.py

Lines changed: 21 additions & 22 deletions
@@ -1,6 +1,6 @@
 """Explainers.shap module"""
 # pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
-# pylint: disable = unused-argument, consider-using-f-string, invalid-name, too-many-arguments
+# pylint: disable = unused-argument, consider-using-f-string, invalid-name
 from typing import Dict, Optional, List, Union
 import matplotlib.pyplot as plt
 import matplotlib as mpl
@@ -434,11 +434,7 @@ class SHAPExplainer:
     def __init__(
         self,
         background: Union[np.ndarray, pd.DataFrame, List[PredictionInput]],
-        samples=None,
-        batch_size=20,
-        seed=0,
         link_type: Optional[_ShapConfig.LinkType] = None,
-        track_counterfactuals=False,
         **kwargs,
     ):
         r"""Initialize the :class:`SHAPExplainer`.
@@ -449,23 +445,26 @@ def __init__(
             or List[:class:`PredictionInput`]
             The set of background datapoints as an array, dataframe of shape
             ``[n_datapoints, n_features]``, or list of TrustyAI PredictionInputs.
-        samples: int
-            The number of samples to use when computing SHAP values. Higher values will increase
-            explanation accuracy, at the cost of runtime.
-        batch_size: int
-            The number of batches passed to the PredictionProvider at once. When using a
-            :class:`~Model` in the :func:`explain` function, this parameter has no effect. With an
-            :class:`~ArrowModel`, `batch_sizes` of around
-            :math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
-            performance gains.
-        seed: int
-            The random seed to be used when generating explanations.
         link_type : :obj:`~_ShapConfig.LinkType`
             A choice of either ``trustyai.explainers._ShapConfig.LinkType.IDENTITY``
             or ``trustyai.explainers._ShapConfig.LinkType.LOGIT``. If the model output is a
             probability, choosing the ``LOGIT`` link will rescale explanations into log-odds units.
             Otherwise, choose ``IDENTITY``.
-
+        Keyword Arguments:
+            * samples: int
+                (default=None) The number of samples to use when computing SHAP values. Higher
+                values will increase explanation accuracy, at the cost of runtime. If ``None``,
+                samples will equal 2048 + 2*n_features.
+            * seed: int
+                (default=0) The random seed to be used when generating explanations.
+            * batch_size: int
+                (default=20) The number of batches passed to the PredictionProvider at once.
+                When using :class:`~Model` with `arrow=False` this parameter has no effect.
+                If `arrow=True`, `batch_size` values of around
+                :math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
+                performance gains.
+            * track_counterfactuals : bool
+                (default=False) Keep track of produced byproduct counterfactuals during the SHAP run.
 
         Returns
         -------
         :class:`~SHAPResults`
@@ -474,7 +473,7 @@ def __init__(
         if not link_type:
             link_type = _ShapConfig.LinkType.IDENTITY
         self._jrandom = Random()
-        self._jrandom.setSeed(seed)
+        self._jrandom.setSeed(kwargs.get("seed", 0))
         perturbation_context = PerturbationContext(self._jrandom, 0)
 
         if isinstance(background, np.ndarray):
@@ -491,13 +490,13 @@ def __init__(
         self._configbuilder = (
             _ShapConfig.builder()
             .withLink(link_type)
-            .withBatchSize(batch_size)
+            .withBatchSize(kwargs.get("batch_size", 20))
             .withPC(perturbation_context)
             .withBackground(self.background)
-            .withTrackCounterfactuals(track_counterfactuals)
+            .withTrackCounterfactuals(kwargs.get("track_counterfactuals", False))
         )
-        if samples is not None:
-            self._configbuilder.withNSamples(JInt(samples))
+        if kwargs.get("samples") is not None:
+            self._configbuilder.withNSamples(JInt(kwargs["samples"]))
         self._config = self._configbuilder.build()
         self._explainer = _ShapKernelExplainer(self._config)
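Likewise for SHAP, a minimal sketch of the updated constructor; the import path mirrors the file location, and the background data is illustrative:

    import numpy as np
    from trustyai.explainers import SHAPExplainer

    background = np.random.rand(100, 5)  # 100 background datapoints, 5 features

    # background and link_type stay explicit; samples, seed, batch_size and
    # track_counterfactuals now arrive via **kwargs with the defaults above.
    explainer = SHAPExplainer(
        background=background,
        samples=1000,   # if omitted, defaults to 2048 + 2*n_features
        seed=42,
        batch_size=50,  # larger batches can help when arrow=True
    )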
