
Commit 7957fa9

FAI-416 - TrustyAI explainability benchmarks (#81)
* FAI-416 - lime impact-score benchmark draft
* FAI-416 - added LIME and SHAP impact score benchmark with sumskip model
* FAI-416 - benchmarks with sumthreshold
* FAI-416 - minor adjustments
* updates to benchmarks
* FAI-416 - reporting mean impact score
* FAI-416 - adapted benchmark_commons to work with unified results
* FAI-416 - extended benchmarks to check local saliency f1
* FAI-416 - pylint related fixes
* FAI-416 - pylint related fixes
* FAI-416 - pylint related fixes
* FAI-416 - pylint related fixes
* FAI-416 - pylint related fixes
* FAI-416 - restored correct requirements-dev.txt, fixed import for metrics package
* FAI-416 - dropped stale arrow param
1 parent 8514698 commit 7957fa9

File tree

7 files changed: +322 -12 lines changed

requirements-dev.txt

Lines changed: 1 addition & 1 deletion
```diff
@@ -16,4 +16,4 @@ matplotlib==3.5.1
 pandas==1.2.5
 pytest-xdist==2.5.0
 pytest-benchmark
-bokeh==2.4.3
+bokeh==2.4.3
```

src/trustyai/metrics/__init__.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -1,7 +1,7 @@
 # pylint: disable = import-error, invalid-name, wrong-import-order, no-name-in-module
 """General model classes"""
 from trustyai import _default_initializer  # pylint: disable=unused-import
-from org.kie.kogito.explainability.utils import (
+from org.kie.trustyai.explainability.metrics import (
     ExplainabilityMetrics as _ExplainabilityMetrics,
 )
```

src/trustyai/metrics/saliency.py

Lines changed: 147 additions & 0 deletions
New file:

```python
# pylint: disable = import-error
"""Saliency evaluation metrics"""
from typing import Union

from org.apache.commons.lang3.tuple import (
    Pair as _Pair,
)

from org.kie.trustyai.explainability.model import (
    PredictionInput,
    PredictionInputsDataDistribution,
)
from org.kie.trustyai.explainability.local import LocalExplainer

from jpype import JObject

from trustyai.model import simple_prediction, PredictionProvider
from trustyai.explainers import SHAPExplainer, LimeExplainer

from . import ExplainabilityMetrics


def impact_score(model: PredictionProvider, pred_input: PredictionInput,
                 explainer: Union[LimeExplainer, SHAPExplainer],
                 k: int, is_model_callable: bool = False):
    """
    Parameters
    ----------
    model: trustyai.PredictionProvider
        the model used to generate predictions
    pred_input: trustyai.PredictionInput
        the input to the model
    explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer]
        the explainer to evaluate
    k: int
        the number of top important features
    is_model_callable: bool
        whether to call the model directly or use its predict method

    Returns
    -------
    :float:
        the impact score metric
    """
    if is_model_callable:
        output = model(pred_input)
    else:
        output = model.predict([pred_input])[0].outputs
    pred = simple_prediction(pred_input, output)
    explanation = explainer.explain(inputs=pred_input, outputs=output, model=model)
    saliency = list(explanation.saliency_map().values())[0]
    top_k_features = saliency.getTopFeatures(k)
    return ExplainabilityMetrics.impactScore(model, pred, top_k_features)


def mean_impact_score(explainer: Union[LimeExplainer, SHAPExplainer],
                      model: PredictionProvider, data: list, is_model_callable=False, k=2):
    """
    Parameters
    ----------
    explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer]
        the explainer to evaluate
    model: trustyai.PredictionProvider
        the model used to generate predictions
    data: list[list[trustyai.model.Feature]]
        the inputs to calculate the metric for
    is_model_callable: bool
        whether to call the model directly or use its predict method
    k: int
        the number of top important features

    Returns
    -------
    :float:
        the mean impact score metric across all inputs
    """
    m_is = 0
    for features in data:
        m_is += impact_score(model, features, explainer, k, is_model_callable=is_model_callable)
    return m_is / len(data)


def classification_fidelity(explainer: Union[LimeExplainer, SHAPExplainer],
                            model: PredictionProvider, inputs: list,
                            is_model_callable: bool = False):
    """
    Parameters
    ----------
    explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer]
        the explainer to evaluate
    model: trustyai.PredictionProvider
        the model used to generate predictions
    inputs: list[list[trustyai.model.Feature]]
        the inputs to calculate the metric for
    is_model_callable: bool
        whether to call the model directly or use its predict method

    Returns
    -------
    :float:
        the classification fidelity metric
    """
    pairs = []
    for c_input in inputs:
        if is_model_callable:
            output = model(c_input)
        else:
            output = model.predict([c_input])[0].outputs
        explanation = explainer.explain(inputs=c_input, outputs=output, model=model)
        saliency = list(explanation.saliency_map().values())[0]
        pairs.append(_Pair.of(saliency, simple_prediction(c_input, output)))
    return ExplainabilityMetrics.classificationFidelity(pairs)


# pylint: disable = too-many-arguments
def local_saliency_f1(output_name: str, model: PredictionProvider,
                      explainer: Union[LimeExplainer, SHAPExplainer],
                      distribution: PredictionInputsDataDistribution, k: int,
                      chunk_size: int):
    """
    Parameters
    ----------
    output_name: str
        the name of the output to calculate the metric for
    model: trustyai.PredictionProvider
        the model used to generate predictions
    explainer: Union[trustyai.explainers.LimeExplainer, trustyai.explainers.SHAPExplainer]
        the explainer to evaluate (a raw org.kie.trustyai.explainability.local.LocalExplainer
        is also accepted)
    distribution: org.kie.trustyai.explainability.model.PredictionInputsDataDistribution
        the data distribution to fetch the inputs from
    k: int
        the number of top important features
    chunk_size: int
        the number of inputs to fetch from the distribution per chunk

    Returns
    -------
    :float:
        the local saliency F1 metric
    """
    if not isinstance(explainer, LocalExplainer):
        # pylint: disable = protected-access
        local_explainer = JObject(explainer._explainer, LocalExplainer)
    else:
        local_explainer = explainer
    return ExplainabilityMetrics.getLocalSaliencyF1(output_name, model, local_explainer,
                                                    distribution, k, chunk_size)
```
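
For orientation, here is a minimal usage sketch of the new metrics. It reuses the helpers exercised by the benchmarks below (TestModels, feature, LimeExplainer, mean_impact_score); the input shapes and the k=2 setting mirror tests/benchmarks/benchmark.py rather than being prescribed by the module itself.

```python
# Minimal sketch: mean impact score of LIME on a toy TrustyAI model.
# Setup (model, features, k=2) mirrors tests/benchmarks/benchmark.py.
import numpy as np

from trustyai.explainers import LimeExplainer
from trustyai.model import feature
from trustyai.utils import TestModels
from trustyai.metrics.saliency import mean_impact_score

np.random.seed(0)
model = TestModels.getSumSkipModel(0)   # toy model from TrustyAI's test utilities
explainer = LimeExplainer()

# a handful of random 10-feature numerical inputs
data = [
    [feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number")
     for i in range(10)]
    for _ in range(5)
]

print(mean_impact_score(explainer, model, data, k=2))
```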

tests/benchmarks/benchmark.py

Lines changed: 154 additions & 8 deletions
```diff
@@ -4,13 +4,21 @@
 import sys
 import pytest
 import time
+import numpy as np
+
+from trustyai.explainers import LimeExplainer, SHAPExplainer
+from trustyai.model import feature, PredictionInput
+from trustyai.utils import TestModels
+from trustyai.metrics.saliency import mean_impact_score, classification_fidelity, local_saliency_f1
+
+from org.kie.trustyai.explainability.model import (
+    PredictionInputsDataDistribution,
+)
 
 myPath = os.path.dirname(os.path.abspath(__file__))
 sys.path.insert(0, myPath + "/../general/")
 
 import test_counterfactualexplainer as tcf
-import test_limeexplainer as tlime
-
 
 @pytest.mark.benchmark(
     group="counterfactuals", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
@@ -35,9 +43,147 @@ def test_counterfactual_match_python_model(benchmark):
     """Counterfactual match (Python model)"""
     benchmark(tcf.test_counterfactual_match_python_model)
 
-# @pytest.mark.benchmark(
-#     group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
-# )
-# def test_non_empty_input(benchmark):
-#     """Counterfactual match (Python model)"""
-#     benchmark(tlime.test_non_empty_input)
+
+@pytest.mark.benchmark(
+    group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
+)
+def test_sumskip_lime_impact_score_at_2(benchmark):
+    no_of_features = 10
+    np.random.seed(0)
+    explainer = LimeExplainer()
+    model = TestModels.getSumSkipModel(0)
+    data = []
+    for i in range(100):
+        data.append([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in range(no_of_features)])
+    benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data)
+    benchmark(mean_impact_score, explainer, model, data)
+
+
+@pytest.mark.benchmark(
+    group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
+)
+def test_sumskip_shap_impact_score_at_2(benchmark):
+    no_of_features = 10
+    np.random.seed(0)
+    background = []
+    for i in range(10):
+        background.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in range(no_of_features)]))
+    explainer = SHAPExplainer(background, samples=10000)
+    model = TestModels.getSumSkipModel(0)
+    data = []
+    for i in range(100):
+        data.append([feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in range(no_of_features)])
+    benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data)
+    benchmark(mean_impact_score, explainer, model, data)
+
+
+@pytest.mark.benchmark(
+    group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
+)
+def test_sumthreshold_lime_impact_score_at_2(benchmark):
+    no_of_features = 10
+    np.random.seed(0)
+    explainer = LimeExplainer()
+    center = 100.0
+    epsilon = 10.0
+    model = TestModels.getSumThresholdModel(center, epsilon)
+    data = []
+    for i in range(100):
+        data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])
+    benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data)
+    benchmark(mean_impact_score, explainer, model, data)
+
+
+@pytest.mark.benchmark(
+    group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
+)
+def test_sumthreshold_shap_impact_score_at_2(benchmark):
+    no_of_features = 10
+    np.random.seed(0)
+    background = []
+    for i in range(100):
+        background.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]))
+    explainer = SHAPExplainer(background, samples=10000)
+    center = 100.0
+    epsilon = 10.0
+    model = TestModels.getSumThresholdModel(center, epsilon)
+    data = []
+    for i in range(100):
+        data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])
+    benchmark.extra_info['metric'] = mean_impact_score(explainer, model, data)
+    benchmark(mean_impact_score, explainer, model, data)
+
+
+@pytest.mark.benchmark(
+    group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
+)
+def test_lime_fidelity(benchmark):
+    no_of_features = 10
+    np.random.seed(0)
+    explainer = LimeExplainer()
+    model = TestModels.getEvenSumModel(0)
+    data = []
+    for i in range(100):
+        data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)])
+    benchmark.extra_info['metric'] = classification_fidelity(explainer, model, data)
+    benchmark(classification_fidelity, explainer, model, data)
+
+
+@pytest.mark.benchmark(
+    group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
+)
+def test_shap_fidelity(benchmark):
+    no_of_features = 10
+    np.random.seed(0)
+    background = []
+    for i in range(10):
+        background.append(PredictionInput(
+            [feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in
+             range(no_of_features)]))
+    explainer = SHAPExplainer(background, samples=10000)
+    model = TestModels.getEvenSumModel(0)
+    data = []
+    for i in range(100):
+        data.append([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in
+                     range(no_of_features)])
+    benchmark.extra_info['metric'] = classification_fidelity(explainer, model, data)
+    benchmark(classification_fidelity, explainer, model, data)
+
+
+@pytest.mark.benchmark(
+    group="lime", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
+)
+def test_lime_local_saliency_f1(benchmark):
+    no_of_features = 10
+    np.random.seed(0)
+    explainer = LimeExplainer()
+    model = TestModels.getEvenSumModel(0)
+    output_name = "sum-even-but0"
+    data = []
+    for i in range(100):
+        data.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]))
+    distribution = PredictionInputsDataDistribution(data)
+    benchmark.extra_info['metric'] = local_saliency_f1(output_name, model, explainer, distribution, 2, 10)
+    benchmark(local_saliency_f1, output_name, model, explainer, distribution, 2, 10)
+
+
+@pytest.mark.benchmark(
+    group="shap", min_rounds=10, timer=time.time, disable_gc=True, warmup=True
+)
+def test_shap_local_saliency_f1(benchmark):
+    no_of_features = 10
+    np.random.seed(0)
+    background = []
+    for i in range(10):
+        background.append(PredictionInput(
+            [feature(name=f"f-num{i}", value=np.random.randint(-10, 10), dtype="number") for i in
+             range(no_of_features)]))
+    explainer = SHAPExplainer(background, samples=10000)
+    model = TestModels.getEvenSumModel(0)
+    output_name = "sum-even-but0"
+    data = []
+    for i in range(100):
+        data.append(PredictionInput([feature(name=f"f-num{i}", value=np.random.randint(-100, 100), dtype="number") for i in range(no_of_features)]))
+    distribution = PredictionInputsDataDistribution(data)
+    benchmark.extra_info['metric'] = local_saliency_f1(output_name, model, explainer, distribution, 2, 10)
+    benchmark(local_saliency_f1, output_name, model, explainer, distribution, 2, 10)
```
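
As a usage note (not part of this commit): since each benchmark stores its metric in benchmark.extra_info, one way to collect both timings and metric values is to run the suite through pytest-benchmark's JSON reporting, roughly as sketched below. The --benchmark-only and --benchmark-json flags are standard pytest-benchmark options; the output path is only an example.

```python
# Hypothetical runner sketch: invoke the benchmark suite via pytest-benchmark
# and dump timings plus each test's extra_info (including 'metric') to a JSON report.
import pytest

pytest.main([
    "tests/benchmarks/benchmark.py",
    "--benchmark-only",                   # run only benchmark tests
    "--benchmark-json=benchmarks.json",   # example output path
])
```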

tests/benchmarks/benchmark_common.py

Whitespace-only changes.

tests/general/test_limeexplainer.py

Lines changed: 18 additions & 1 deletion
```diff
@@ -7,7 +7,8 @@
 
 from trustyai.explainers import LimeExplainer
 from trustyai.utils import TestModels
-from trustyai.model import feature, Model
+from trustyai.model import feature, Model, simple_prediction
+from trustyai.metrics import ExplainabilityMetrics
 
 from org.kie.trustyai.explainability.local import (
     LocalExplanationException,
@@ -126,3 +127,19 @@ def test_lime_v2():
     explanation = explainer.explain(inputs=data, outputs=model(data), model=model)
     for score in explanation.as_dataframe()["output-0_score"]:
         assert score != 0
+
+
+def test_impact_score():
+    np.random.seed(0)
+    data = pd.DataFrame(np.random.rand(1, 5))
+    model_weights = np.random.rand(5)
+    predict_function = lambda x: np.dot(x.values, model_weights)
+    model = Model(predict_function, dataframe_input=True)
+    output = model(data)
+    pred = simple_prediction(data, output)
+    explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False)
+    explanation = explainer.explain(inputs=data, outputs=output, model=model)
+    saliency = list(explanation.saliency_map().values())[0]
+    top_features_t = saliency.getTopFeatures(2)
+    impact = ExplainabilityMetrics.impactScore(model, pred, top_features_t)
+    assert impact > 0
+    return impact
```
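
For comparison, the same toy model could be scored through the new trustyai.metrics.saliency.impact_score wrapper instead of calling ExplainabilityMetrics directly. This is only a sketch: it assumes the wrapper accepts the dataframe input with is_model_callable=True, the same way the explainer and simple_prediction accept it in the test above.

```python
# Sketch: impact score via the new wrapper (assumes dataframe input works
# with is_model_callable=True, mirroring test_impact_score above).
import numpy as np
import pandas as pd

from trustyai.explainers import LimeExplainer
from trustyai.model import Model
from trustyai.metrics.saliency import impact_score

np.random.seed(0)
data = pd.DataFrame(np.random.rand(1, 5))
model_weights = np.random.rand(5)
model = Model(lambda x: np.dot(x.values, model_weights), dataframe_input=True)
explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False)

score = impact_score(model, data, explainer, k=2, is_model_callable=True)
assert score > 0
```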
