1
1
"""Explainers.shap module"""
2
2
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
3
3
# pylint: disable = unused-argument, consider-using-f-string, invalid-name
4
- from typing import Dict , Optional , List , Union
4
+ from typing import Dict , Optional
5
5
import matplotlib .pyplot as plt
6
6
import matplotlib as mpl
7
7
from bokeh .models import ColumnDataSource , HoverTool
21
21
output_html ,
22
22
feature_html ,
23
23
)
24
-
25
24
from trustyai .model import (
26
- feature ,
27
- Dataset ,
28
- PredictionInput ,
29
25
simple_prediction ,
30
- PredUnionType ,
26
+ )
27
+ from trustyai .utils .data_conversions import (
28
+ OneInputUnionType ,
29
+ OneOutputUnionType ,
30
+ ManyInputsUnionType ,
31
+ many_inputs_convert ,
32
+ data_conversion_docstring ,
31
33
)
32
34
33
35
from org .kie .trustyai .explainability .local .shap import (
@@ -434,20 +436,19 @@ class SHAPExplainer:
434
436
the outputs, as compared to the background inputs?*
435
437
"""
436
438
439
+ @data_conversion_docstring ("many_inputs" )
437
440
def __init__ (
438
441
self ,
439
- background : Union [ np . ndarray , pd . DataFrame , List [ PredictionInput ]] ,
442
+ background : ManyInputsUnionType ,
440
443
link_type : Optional [_ShapConfig .LinkType ] = None ,
441
444
** kwargs ,
442
445
):
443
446
r"""Initialize the :class:`SHAPxplainer`.
444
447
445
448
Parameters
446
449
----------
447
- background : :class:`numpy.array`, :class:`Pandas.DataFrame`
448
- or List[:class:`PredictionInput]
449
- The set of background datapoints as an array, dataframe of shape
450
- ``[n_datapoints, n_features]``, or list of TrustyAI PredictionInputs.
450
+ background : {}
451
+ The set of background datapoints as a: {}
451
452
link_type : :obj:`~_ShapConfig.LinkType`
452
453
A choice of either ``trustyai.explainers._ShapConfig.LinkType.IDENTITY``
453
454
or ``trustyai.explainers._ShapConfig.LinkType.LOGIT``. If the model output is a
@@ -464,10 +465,11 @@ def __init__(
464
465
(default=20) The number of batches passed to the PredictionProvider at once.
465
466
When uusing :class:`~Model` with `arrow=False` this parameter has no effect.
466
467
If `arrow=True`, `batch_sizes` of around
467
- :math:`\frac{2000}{ \mathtt{len(background)}}` can produce significant
468
+ :math:`\frac{{ 2000}}{{ \mathtt{{ len(background)}} }}` can produce significant
468
469
performance gains.
469
470
* trackCounterfactuals : bool
470
471
(default=False) Keep track of produced byproduct counterfactuals during SHAP run.
472
+
471
473
Returns
472
474
-------
473
475
:class:`~SHAPResults`
@@ -477,19 +479,9 @@ def __init__(
477
479
link_type = _ShapConfig .LinkType .IDENTITY
478
480
self ._jrandom = Random ()
479
481
self ._jrandom .setSeed (kwargs .get ("seed" , 0 ))
482
+ self .background = many_inputs_convert (background )
480
483
perturbation_context = PerturbationContext (self ._jrandom , 0 )
481
484
482
- if isinstance (background , np .ndarray ):
483
- self .background = Dataset .numpy_to_prediction_object (background , feature )
484
- elif isinstance (background , pd .DataFrame ):
485
- self .background = Dataset .df_to_prediction_object (background , feature )
486
- elif isinstance (background [0 ], PredictionInput ):
487
- self .background = background
488
- else :
489
- raise AttributeError (
490
- "Unsupported background type: {}" .format (type (background ))
491
- )
492
-
493
485
self ._configbuilder = (
494
486
_ShapConfig .builder ()
495
487
.withLink (link_type )
@@ -503,32 +495,22 @@ def __init__(
503
495
self ._config = self ._configbuilder .build ()
504
496
self ._explainer = _ShapKernelExplainer (self ._config )
505
497
498
+ @data_conversion_docstring ("one_input" , "one_output" )
506
499
def explain (
507
- self , inputs : PredUnionType , outputs : PredUnionType , model : PredictionProvider
500
+ self ,
501
+ inputs : OneInputUnionType ,
502
+ outputs : OneOutputUnionType ,
503
+ model : PredictionProvider ,
508
504
) -> SHAPResults :
509
505
"""Produce a SHAP explanation.
510
506
511
507
Parameters
512
508
----------
513
- inputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Feature`], or :class:`PredictionInput`
514
- The input features to the model, as a:
515
-
516
- * Numpy array of shape ``[1, n_features]``
517
- * Pandas DataFrame with 1 row and ``n_features`` columns
518
- * A List of TrustyAI :class:`Feature`, as created by the :func:`~feature` function
519
- * A TrustyAI :class:`PredictionInput`
520
-
521
- outputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Output`], or :class:`PredictionOutput`
509
+ inputs : {}
510
+ The input features to the model, as a: {}
511
+ outputs : {}
522
512
The corresponding model outputs for the provided features, that is,
523
- ``outputs = model(input_features)``. These can take the form of a:
524
-
525
- * Numpy array of shape ``[1, n_outputs]``
526
- * Pandas DataFrame with 1 row and ``n_outputs`` columns
527
- * A List of TrustyAI :class:`Output`, as created by the :func:`~output` function
528
- * A TrustyAI :class:`PredictionOutput`
529
- model : :obj:`~trustyai.model.PredictionProvider`
530
- The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model` or
531
- :class:`~trustyai.model.ArrowModel`.
513
+ ``outputs = model(input_features)``. These can take the form of a: {}
532
514
533
515
Returns
534
516
-------
0 commit comments