Skip to content

Commit 65162d2

Browse files
authored
FAI-922 - expose LimeConfig options in LimeExplainer (#138)
* FAI-922 - expose LimeConfig options in LimeExplainer * FAI-922 - expose LimeConfig options in LimeExplainer * FAI-922 - format check
1 parent 0e89a54 commit 65162d2

File tree

1 file changed

+43
-8
lines changed

1 file changed

+43
-8
lines changed

src/trustyai/explainers/lime.py

Lines changed: 43 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
PredictionProvider,
4141
Saliency,
4242
PerturbationContext,
43+
PredictionInputsDataDistribution,
4344
)
4445

4546
from java.util import Random
@@ -233,14 +234,11 @@ class LimeExplainer:
233234
feature that describe how strongly said feature contributed to the model's output.
234235
"""
235236

236-
def __init__(self, samples=10, **kwargs):
237+
def __init__(self, **kwargs):
237238
r"""Initialize the :class:`LimeExplainer`.
238239
239240
Parameters
240241
----------
241-
samples: int
242-
Number of samples to be generated for the local linear model training.
243-
244242
Keyword Arguments:
245243
* penalise_sparse_balance : bool
246244
(default= ``True``) Whether to penalise features that are likely to produce linearly
@@ -260,21 +258,58 @@ def __init__(self, samples=10, **kwargs):
260258
process.
261259
* trackCounterfactuals : bool
262260
(default= ``False``) Keep track of produced byproduct counterfactuals during LIME run.
261+
* samples: int
262+
(default= ``300``) Number of samples to be generated for the local linear model training.
263+
* encoding_params: Union[list, tuple]
264+
(default= ``(0.07, 0.3)``) Lime encoding parameters, as a tuple/list of two float numbers:
265+
- encoding_params[0] is the width of the Gaussian filter for clustering number features.
266+
- encoding_params[1] is the threshold for clustering number features.
267+
* data_distribution: PredictionInputsDataDistribution
268+
(default= ``PredictionInputsDataDistribution([])``) Data distribution used to find better feature perturbations
269+
* features: int
270+
(default= ``6``) Number of feature to select from the original set of input features
271+
* retries: int
272+
(default= ``3``) Number of retries performed by LIME to find a separable dataset
273+
* dataset_minimum: int
274+
(default= ``10``) Minimum number of samples retained by the proximity filter to be acceptable
275+
* separable_dataset_ratio: float
276+
(default= ``0.1``) Minimum portion of the encoded dataset that needs to have a different label
277+
* kernel_width: float
278+
(default= ``0.5``) Width of the proximity kernel
279+
* proximity_threshold: float
280+
(default= ``0.83``) Proximity threshold used to retain close samples
281+
* adapt_dataset_variance: bool
282+
(default= ``True``) Whether LIME should try to increase the perturbation variance in subsequent retries
283+
* feature_selection: bool
284+
(default= ``True``) Whether LIME should generate saliency for to the most important features only
285+
* filter_interpretable: bool
286+
(default= ``False``) Whether the proximity filter should happen in the interpretable space
263287
264288
"""
265289
self._jrandom = Random()
266290
self._jrandom.setSeed(kwargs.get("seed", 0))
267-
291+
ep = kwargs.get("encoding_params", (0.07, 0.3))
268292
self._lime_config = (
269293
LimeConfig()
270294
.withNormalizeWeights(kwargs.get("normalise_weights", False))
271295
.withPerturbationContext(
272296
PerturbationContext(self._jrandom, kwargs.get("perturbations", 1))
273297
)
274-
.withSamples(samples)
275-
.withEncodingParams(EncodingParams(0.07, 0.3))
276-
.withAdaptiveVariance(True)
298+
.withSamples(kwargs.get("samples", 300))
299+
.withDataDistribution(
300+
kwargs.get("data_distribution", PredictionInputsDataDistribution([]))
301+
)
302+
.withNoOfFeatures(kwargs.get("features", 6))
303+
.withRetries(kwargs.get("retries", 3))
304+
.withProximityFilteredDatasetMinimum(kwargs.get("dataset_minimum", 10))
305+
.withSeparableDatasetRatio(kwargs.get("separable_dataset_ratio", 0.1))
306+
.withProximityKernelWidth(kwargs.get("kernel_width", 0.5))
307+
.withProximityThreshold(kwargs.get("proximity_threshold", 0.83))
308+
.withEncodingParams(EncodingParams(ep[0], ep[1]))
309+
.withAdaptiveVariance(kwargs.get("adapt_dataset_variance", True))
310+
.withFeatureSelection(kwargs.get("feature_selection", True))
277311
.withPenalizeBalanceSparse(kwargs.get("penalise_sparse_balance", True))
312+
.withFilterInterpretable(kwargs.get("filter_interpretable", False))
278313
.withUseWLRLinearModel(kwargs.get("use_wlr_model", True))
279314
.withTrackCounterfactuals(kwargs.get("track_counterfactuals", False))
280315
)

0 commit comments

Comments
 (0)