1
1
"""Explainers.shap module"""
2
2
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
3
- # pylint: disable = unused-argument, consider-using-f-string, invalid-name, too-many-arguments
3
+ # pylint: disable = unused-argument, consider-using-f-string, invalid-name
4
4
from typing import Dict , Optional , List , Union
5
5
import matplotlib .pyplot as plt
6
6
import matplotlib as mpl
@@ -434,11 +434,7 @@ class SHAPExplainer:
434
434
def __init__ (
435
435
self ,
436
436
background : Union [np .ndarray , pd .DataFrame , List [PredictionInput ]],
437
- samples = None ,
438
- batch_size = 20 ,
439
- seed = 0 ,
440
437
link_type : Optional [_ShapConfig .LinkType ] = None ,
441
- track_counterfactuals = False ,
442
438
** kwargs ,
443
439
):
444
440
r"""Initialize the :class:`SHAPExplainer`.
@@ -449,23 +445,26 @@ def __init__(
449
445
or List[:class:`PredictionInput`]
450
446
The set of background datapoints as an array, dataframe of shape
451
447
``[n_datapoints, n_features]``, or list of TrustyAI PredictionInputs.
452
- samples: int
453
- The number of samples to use when computing SHAP values. Higher values will increase
454
- explanation accuracy, at the cost of runtime.
455
- batch_size: int
456
- The number of batches passed to the PredictionProvider at once. When using a
457
- :class:`~Model` in the :func:`explain` function, this parameter has no effect. With an
458
- :class:`~ArrowModel`, `batch_sizes` of around
459
- :math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
460
- performance gains.
461
- seed: int
462
- The random seed to be used when generating explanations.
463
448
link_type : :obj:`~_ShapConfig.LinkType`
464
449
A choice of either ``trustyai.explainers._ShapConfig.LinkType.IDENTITY``
465
450
or ``trustyai.explainers._ShapConfig.LinkType.LOGIT``. If the model output is a
466
451
probability, choosing the ``LOGIT`` link will rescale explanations into log-odds units.
467
452
Otherwise, choose ``IDENTITY``.
468
-
453
+ Keyword Arguments:
454
+ * samples: int
455
+ (default=None) The number of samples to use when computing SHAP values. Higher
456
+ values will increase explanation accuracy, at the cost of runtime. If none,
457
+ samples will equal 2048 + 2*n_features
458
+ * seed: int
459
+ (default=0) The random seed to be used when generating explanations.
460
+ * batch_size: int
461
+ (default=20) The number of batches passed to the PredictionProvider at once.
462
+ When using :class:`~Model` with `arrow=False` this parameter has no effect.
463
+ If `arrow=True`, `batch_size` values of around
464
+ :math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
465
+ performance gains.
466
+ * track_counterfactuals : bool
467
+ (default=False) Keep track of produced byproduct counterfactuals during SHAP run.
469
468
Returns
470
469
-------
471
470
:class:`~SHAPResults`
@@ -474,7 +473,7 @@ def __init__(
474
473
if not link_type :
475
474
link_type = _ShapConfig .LinkType .IDENTITY
476
475
self ._jrandom = Random ()
477
- self ._jrandom .setSeed (seed )
476
+ self ._jrandom .setSeed (kwargs . get ( " seed" , 0 ) )
478
477
perturbation_context = PerturbationContext (self ._jrandom , 0 )
479
478
480
479
if isinstance (background , np .ndarray ):
@@ -491,13 +490,13 @@ def __init__(
491
490
self ._configbuilder = (
492
491
_ShapConfig .builder ()
493
492
.withLink (link_type )
494
- .withBatchSize (batch_size )
493
+ .withBatchSize (kwargs . get ( " batch_size" , 20 ) )
495
494
.withPC (perturbation_context )
496
495
.withBackground (self .background )
497
- .withTrackCounterfactuals (track_counterfactuals )
496
+ .withTrackCounterfactuals (kwargs . get ( " track_counterfactuals" , False ) )
498
497
)
499
- if samples is not None :
500
- self ._configbuilder .withNSamples (JInt (samples ))
498
+ if kwargs . get ( " samples" ) is not None :
499
+ self ._configbuilder .withNSamples (JInt (kwargs [ " samples" ] ))
501
500
self ._config = self ._configbuilder .build ()
502
501
self ._explainer = _ShapKernelExplainer (self ._config )
503
502
0 commit comments