40
40
PredictionProvider ,
41
41
Saliency ,
42
42
PerturbationContext ,
43
+ PredictionInputsDataDistribution ,
43
44
)
44
45
45
46
from java .util import Random
@@ -233,14 +234,11 @@ class LimeExplainer:
233
234
feature that describe how strongly said feature contributed to the model's output.
234
235
"""
235
236
236
- def __init__ (self , samples = 10 , ** kwargs ):
237
+ def __init__ (self , ** kwargs ):
237
238
r"""Initialize the :class:`LimeExplainer`.
238
239
239
240
Parameters
240
241
----------
241
- samples: int
242
- Number of samples to be generated for the local linear model training.
243
-
244
242
Keyword Arguments:
245
243
* penalise_sparse_balance : bool
246
244
(default= ``True``) Whether to penalise features that are likely to produce linearly
@@ -260,21 +258,58 @@ def __init__(self, samples=10, **kwargs):
260
258
process.
261
259
* trackCounterfactuals : bool
262
260
(default= ``False``) Keep track of produced byproduct counterfactuals during LIME run.
261
+ * samples: int
262
+ (default= ``300``) Number of samples to be generated for the local linear model training.
263
+ * encoding_params: Union[list, tuple]
264
+ (default= ``(0.07, 0.3)``) Lime encoding parameters, as a tuple/list of two float numbers:
265
+ - encoding_params[0] is the width of the Gaussian filter for clustering number features.
266
+ - encoding_params[1] is the threshold for clustering number features.
267
+ * data_distribution: PredictionInputsDataDistribution
268
+ (default= ``PredictionInputsDataDistribution([])``) Data distribution used to find better feature perturbations
269
+ * features: int
270
+ (default= ``6``) Number of feature to select from the original set of input features
271
+ * retries: int
272
+ (default= ``3``) Number of retries performed by LIME to find a separable dataset
273
+ * dataset_minimum: int
274
+ (default= ``10``) Minimum number of samples retained by the proximity filter to be acceptable
275
+ * separable_dataset_ratio: float
276
+ (default= ``0.1``) Minimum portion of the encoded dataset that needs to have a different label
277
+ * kernel_width: float
278
+ (default= ``0.5``) Width of the proximity kernel
279
+ * proximity_threshold: float
280
+ (default= ``0.83``) Proximity threshold used to retain close samples
281
+ * adapt_dataset_variance: bool
282
+ (default= ``True``) Whether LIME should try to increase the perturbation variance in subsequent retries
283
+ * feature_selection: bool
284
+ (default= ``True``) Whether LIME should generate saliency for to the most important features only
285
+ * filter_interpretable: bool
286
+ (default= ``False``) Whether the proximity filter should happen in the interpretable space
263
287
264
288
"""
265
289
self ._jrandom = Random ()
266
290
self ._jrandom .setSeed (kwargs .get ("seed" , 0 ))
267
-
291
+ ep = kwargs . get ( "encoding_params" , ( 0.07 , 0.3 ))
268
292
self ._lime_config = (
269
293
LimeConfig ()
270
294
.withNormalizeWeights (kwargs .get ("normalise_weights" , False ))
271
295
.withPerturbationContext (
272
296
PerturbationContext (self ._jrandom , kwargs .get ("perturbations" , 1 ))
273
297
)
274
- .withSamples (samples )
275
- .withEncodingParams (EncodingParams (0.07 , 0.3 ))
276
- .withAdaptiveVariance (True )
298
+ .withSamples (kwargs .get ("samples" , 300 ))
299
+ .withDataDistribution (
300
+ kwargs .get ("data_distribution" , PredictionInputsDataDistribution ([]))
301
+ )
302
+ .withNoOfFeatures (kwargs .get ("features" , 6 ))
303
+ .withRetries (kwargs .get ("retries" , 3 ))
304
+ .withProximityFilteredDatasetMinimum (kwargs .get ("dataset_minimum" , 10 ))
305
+ .withSeparableDatasetRatio (kwargs .get ("separable_dataset_ratio" , 0.1 ))
306
+ .withProximityKernelWidth (kwargs .get ("kernel_width" , 0.5 ))
307
+ .withProximityThreshold (kwargs .get ("proximity_threshold" , 0.83 ))
308
+ .withEncodingParams (EncodingParams (ep [0 ], ep [1 ]))
309
+ .withAdaptiveVariance (kwargs .get ("adapt_dataset_variance" , True ))
310
+ .withFeatureSelection (kwargs .get ("feature_selection" , True ))
277
311
.withPenalizeBalanceSparse (kwargs .get ("penalise_sparse_balance" , True ))
312
+ .withFilterInterpretable (kwargs .get ("filter_interpretable" , False ))
278
313
.withUseWLRLinearModel (kwargs .get ("use_wlr_model" , True ))
279
314
.withTrackCounterfactuals (kwargs .get ("track_counterfactuals" , False ))
280
315
)
0 commit comments