Commit 47a8ffb

FAI-884: Add SHAP background generators to bindings (#117)
* unified input/output types, conversion functions, and docstrings
* Initial creation of background generators
* fixed many-x conversions for 1d arrays, added more test cases
* fixing input -> output typo
* fixed randozm -> random typo
* further fleshed out cf generation
* added feature domains into data conversions
* extended cf generation tests
* changed default parameters, linting
1 parent 1563850 commit 47a8ffb

File tree

3 files changed: +244, -0 lines changed

src/trustyai/explainers/lime.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@
     data_conversion_docstring,
     OneOutputUnionType,
 )
+
 from .explanation_results import SaliencyResults
 from trustyai.model import simple_prediction
 

src/trustyai/explainers/shap.py

Lines changed: 156 additions & 0 deletions
@@ -28,14 +28,23 @@
     OneInputUnionType,
     OneOutputUnionType,
     ManyInputsUnionType,
+    ManyOutputsUnionType,
     many_inputs_convert,
     data_conversion_docstring,
+    many_outputs_convert,
 )
 
 from org.kie.trustyai.explainability.local.shap import (
     ShapConfig as _ShapConfig,
     ShapKernelExplainer as _ShapKernelExplainer,
 )
+
+from org.kie.trustyai.explainability.local.shap.background import (
+    RandomGenerator,
+    KMeansGenerator,
+    CounterfactualGenerator,
+)
+
 from org.kie.trustyai.explainability.model import (
     PredictionProvider,
     Saliency,
@@ -410,6 +419,151 @@ def _get_bokeh_plot_dict(self):
         }
 
 
+class BackgroundGenerator:
+    r"""Generate a background for the SHAP explainer via one of three algorithms:
+
+    * `sample`: Randomly sample a set of provided points
+    * `kmeans`: Summarize a set of provided points into k centroids
+    * `counterfactual`: Generate a set of background points that meet certain criteria
+
+    """
+
+    @data_conversion_docstring("many_inputs")
+    def __init__(self, datapoints: ManyInputsUnionType, feature_domains=None, seed=0):
+        r"""Initialize the :class:`BackgroundGenerator`.
+
+        Parameters
+        ----------
+        datapoints : {}
+            The set of datapoints to be used to sample/generate the background, as a: {}
+        feature_domains : list or None
+            The list of feature domains (as returned by :func:`~trustyai.model.feature_domain`);
+            required for counterfactual background generation
+        seed : int
+            The random seed to use in the sampling/generation method
+        """
+        self.datapoints = many_inputs_convert(datapoints, feature_domains)
+        self.feature_domains = feature_domains
+        self.seed = seed
+        self._jrandom = Random()
+        self._jrandom.setSeed(self.seed)
+
+    def sample(self, k=100):
+        r"""Randomly sample `k` datapoints.
+
+        Parameters
+        ----------
+        k : int
+            The number of datapoints to select
+
+        Returns
+        -------
+        :list:`PredictionInput`
+            The background dataset to pass to the :class:`~SHAPExplainer`
+        """
+        perturbation_context = PerturbationContext(self._jrandom, 0)
+        return RandomGenerator(self.datapoints, perturbation_context).generate(k)
+
+    def kmeans(self, k=100):
+        r"""Use k-means clustering over `datapoints` and return k centroids as the
+        background dataset.
+
+        Parameters
+        ----------
+        k : int
+            The number of centroids to find
+
+        Returns
+        -------
+        :list:`PredictionInput`
+            The background dataset to pass to the :class:`~SHAPExplainer`
+        """
+        return KMeansGenerator(self.datapoints, self.seed).generate(k)
+
+    @data_conversion_docstring("many_outputs")
+    def counterfactual(
+        self,
+        goals: ManyOutputsUnionType,
+        model: PredictionProvider,
+        k_per_goal=100,
+        **kwargs,
+    ):
+        r"""Generate a background via the CounterfactualExplainer. This lets you specify
+        exact output values that the background dataset conforms to, and thus set the
+        reference point against which all SHAP values are compared. For example, if your
+        model is a regression model, choosing a counterfactual goal of 0 will create a
+        background dataset where :math:`f(x) \approx 0 \;\forall x \in \text{{background}}`,
+        and as such the SHAP values will compare against zero, which is a useful baseline
+        for regression.
+
+        Parameters
+        ----------
+        goals : {}
+            The set of goal outputs to generate the background around, as a: {}
+        model : :obj:`~trustyai.model.PredictionProvider`
+            The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model`
+        k_per_goal : int
+            The number of background datapoints to generate per goal.
+
+        Keyword Arguments:
+            * k_seeds: int
+                (default=5) For each goal, the number of starting seeds from `datapoints`
+                used to start the search. These are the `k_seeds` points within `datapoints`
+                whose corresponding outputs are closest to the goal output. Choose a larger
+                number to get a more diverse background dataset, but the search might then
+                require a larger `max_attempt_count`, `step_count`, and `timeout_seconds`
+                to get good results.
+            * goal_threshold: float
+                (default=.01) The distance (percentage) threshold defining whether a
+                particular output satisfies the goal. Set to 0 to require an exact match,
+                but this will likely require a larger `max_attempt_count`, `step_count`,
+                and `timeout_seconds` to get good results.
+            * chain: boolean
+                (default=False) If chaining is set to `True`, found counterfactual
+                datapoints are added to the search seeds for subsequent searches. This is
+                useful when a range of counterfactual outputs is desired; for example, if
+                the desired goals are [0, 1, 2, 3], the goal closest to the outputs of
+                `datapoints` is searched for first. The counterfactuals found by that
+                search are then included in the search for the second-closest goal, and
+                so on. This is especially helpful if the extremes of the goal range are
+                far outside the range produced by the `datapoints`.
+            * max_attempt_count: int
+                (default=5) If no valid counterfactual can be found for a starting seed,
+                the point is slightly perturbed and the search is retried. This parameter
+                sets the maximum number of perturbation-retry cycles allowed during
+                generation.
+            * step_count: int
+                (default=5,000) The number of datapoints to evaluate during the search
+            * timeout_seconds: int
+                (default=3) The maximum number of seconds allowed for each counterfactual
+                search
+
+        Returns
+        -------
+        :list:`PredictionInput`
+            The background dataset to pass to the :class:`~SHAPExplainer`
+        """
+        if self.feature_domains is None:
+            raise AttributeError(
+                "Feature domains must be passed to perform"
+                " meaningful counterfactual search"
+            )
+        goals_converted = many_outputs_convert(goals)
+        generator = (
+            CounterfactualGenerator.builder()
+            .withModel(model)
+            .withKSeeds(kwargs.get("k_seeds", 5))
+            .withRandom(self._jrandom)
+            .withTimeoutSeconds(kwargs.get("timeout_seconds", 3))
+            .withStepCount(kwargs.get("step_count", 5_000))
+            .withGoalThreshold(kwargs.get("goal_threshold", 0.01))
+            .withMaxAttemptCount(kwargs.get("max_attempt_count", 5))
+            .build()
+        )
+
+        if len(goals) == 1:
+            background = generator.generate(
+                self.datapoints, goals_converted[0], k_per_goal
+            )
+        else:
+            background = generator.generateRange(
+                self.datapoints, goals_converted, k_per_goal, kwargs.get("chain", False)
+            )
+        return background
+
+
 class SHAPExplainer:
     r"""*"By how much did each feature contribute to the outputs?"*
 
@@ -511,6 +665,8 @@ def explain(
         outputs : {}
             The corresponding model outputs for the provided features, that is,
             ``outputs = model(input_features)``. These can take the form of a: {}
+        model : :obj:`~trustyai.model.PredictionProvider`
+            The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model`
 
         Returns
         -------
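
A minimal usage sketch of the new `BackgroundGenerator` API (not part of this diff): it assumes numpy input is accepted by the `many_inputs` conversion shown above, and that `SHAPExplainer` takes the generated background via a `background` constructor argument, which this commit does not show; variable names are illustrative only.

```python
import numpy as np

from trustyai.explainers.shap import BackgroundGenerator, SHAPExplainer
from trustyai.model import Model

# Toy data and model, mirroring the test suite below
data = np.random.rand(100, 5)
model = Model(lambda x: x.sum(1), arrow=False)

# Randomly sample 10 of the provided points as the background...
random_background = BackgroundGenerator(data, seed=0).sample(10)

# ...or summarize the 100 points into 10 k-means centroids
kmeans_background = BackgroundGenerator(data, seed=0).kmeans(10)

# Assumption: SHAPExplainer accepts the background list via its constructor
explainer = SHAPExplainer(background=kmeans_background)
```
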
Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+"""SHAP background generation test suite"""
+
+import pytest
+import numpy as np
+import math
+
+from trustyai.explainers.shap import BackgroundGenerator
+from trustyai.model import Model, feature_domain
+from trustyai.utils.data_conversions import prediction_object_to_numpy
+
+
+def test_random_generation():
+    """Test that random sampling recovers samples from the distribution"""
+    seed = 0
+    np.random.seed(seed)
+    data = np.random.rand(100, 5)
+    background_ta = BackgroundGenerator(data).sample(5)
+    background = prediction_object_to_numpy(background_ta)
+
+    assert len(background) == 5
+    for row in background:
+        assert row in data
+
+
+def test_kmeans_generation():
+    """Test that k-means recovers centroids of well-clustered data"""
+
+    seed = 0
+    clusters = 5
+    np.random.seed(seed)
+
+    data = []
+    ground_truth = []
+    for cluster in range(clusters):
+        data.append(np.random.rand(100 // clusters, 5) + cluster * 10)
+        ground_truth.append(np.array([cluster * 10] * 5))
+    data = np.vstack(data)
+    ground_truth = np.vstack(ground_truth)
+    background_ta = BackgroundGenerator(data).kmeans(clusters)
+    background = prediction_object_to_numpy(background_ta)
+
+    assert len(background) == clusters
+    for row in background:
+        ground_truth_idx = math.floor(row[0] / 10)
+        assert np.linalg.norm(row - ground_truth[ground_truth_idx]) < 2.5
+
+
+def test_counterfactual_generation_single_goal():
+    """Test that the counterfactual background meets requirements"""
+    seed = 0
+    np.random.seed(seed)
+    data = np.random.rand(100, 5)
+    model = Model(lambda x: x.sum(1), arrow=False)
+    goal = np.array([1.0])
+
+    # check that undomained backgrounds are caught
+    with pytest.raises(AttributeError):
+        BackgroundGenerator(data).counterfactual(goal, model, 10)
+
+    domains = [feature_domain((-10, 10)) for _ in range(5)]
+    background_ta = BackgroundGenerator(data, domains, seed)\
+        .counterfactual(goal, model, 5, step_count=5000, timeout_seconds=2)
+    background = prediction_object_to_numpy(background_ta)
+
+    for row in background:
+        assert np.linalg.norm(goal - model(row.reshape(1, -1))) < .01
+
+
+def test_counterfactual_generation_multi_goal():
+    """Test that the counterfactual background meets requirements for multiple goals"""
+
+    seed = 0
+    np.random.seed(seed)
+    data = np.random.rand(100, 5)
+    model = Model(lambda x: x.sum(1), arrow=False)
+    goals = np.arange(1, 10).reshape(-1, 1)
+    domains = [feature_domain((-10, 10)) for _ in range(5)]
+    background_ta = BackgroundGenerator(data, domains, seed)\
+        .counterfactual(goals, model, 1, step_count=5000, timeout_seconds=2, chain=True)
+    background = prediction_object_to_numpy(background_ta)
+
+    for i, goal in enumerate(goals):
+        assert np.linalg.norm(goal - model(background[i:i+1])) < goal[0] / 100
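
For completeness, a hedged end-to-end sketch showing a counterfactual background feeding a SHAP explanation, mirroring the regression baseline described in the `counterfactual()` docstring. The `background=` constructor argument and the `explain(inputs=..., outputs=..., model=...)` call are assumptions based on the docstring hunks above, not something this diff adds.

```python
import numpy as np

from trustyai.explainers.shap import BackgroundGenerator, SHAPExplainer
from trustyai.model import Model, feature_domain

data = np.random.rand(100, 5)
model = Model(lambda x: x.sum(1), arrow=False)

# Counterfactual generation requires feature domains so the search knows
# how far each feature may be moved
domains = [feature_domain((-10, 10)) for _ in range(5)]

# Build a background whose outputs are all approximately 0, so SHAP values
# are measured against a baseline of f(x) = 0 (the regression example from
# the counterfactual() docstring)
background = BackgroundGenerator(data, domains, seed=0).counterfactual(
    goals=np.array([0.0]),
    model=model,
    k_per_goal=25,
    step_count=5000,
    timeout_seconds=5,
)

# Assumption: the background is passed to the explainer's constructor and
# explain() takes (inputs, outputs, model), per the docstring hunk above
explainer = SHAPExplainer(background=background)
x = data[:1]
explanation = explainer.explain(inputs=x, outputs=model(x), model=model)
```
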
