Skip to content

Commit 3ed9c6b

Browse files
authored
FAI-863: XAI-Bench (#109)
* submodule * linting round 1
1 parent cc569f3 commit 3ed9c6b

File tree

4 files changed

+33
-1
lines changed

4 files changed

+33
-1
lines changed

.gitmodules

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[submodule "tests/benchmarks/trustyai_xai_bench"]
2+
path = tests/benchmarks/trustyai_xai_bench
3+
url = https://github.com/trustyai-explainability/trustyai_xai_bench

src/trustyai/explainers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Explainers module"""
2-
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long
2+
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long
3+
# pylint: disable = unused-argument
34
from typing import Dict, Optional, List, Union
45
import matplotlib.pyplot as plt
56
import matplotlib as mpl
@@ -425,6 +426,8 @@ def __init__(
425426
penalise_sparse_balance=True,
426427
track_counterfactuals=False,
427428
normalise_weights=False,
429+
use_wlr_model=True,
430+
**kwargs
428431
):
429432
"""Initialize the :class:`LimeExplainer`.
430433
@@ -454,6 +457,7 @@ def __init__(
454457
.withEncodingParams(EncodingParams(0.07, 0.3))
455458
.withAdaptiveVariance(True)
456459
.withPenalizeBalanceSparse(penalise_sparse_balance)
460+
.withUseWLRLinearModel(use_wlr_model)
457461
.withTrackCounterfactuals(track_counterfactuals)
458462
)
459463

@@ -896,6 +900,7 @@ def __init__(
896900
seed=0,
897901
link_type: Optional[_ShapConfig.LinkType] = None,
898902
track_counterfactuals=False,
903+
**kwargs,
899904
):
900905
r"""Initialize the :class:`SHAPExplainer`.
901906

tests/benchmarks/trustyai_xai_bench

Submodule trustyai_xai_bench added at cb90cba

tests/benchmarks/xai_benchmark.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import pytest
2+
from trustyai_xai_bench import run_benchmark_config
3+
4+
5+
@pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False)
def test_level_0(benchmark):
    """Run the fastest XAI benchmark configuration (roughly 4.5 minutes).

    The per-run records returned by the benchmark are stashed in
    ``benchmark.extra_info`` so they survive into the benchmark report.
    """
    runs_df = benchmark(run_benchmark_config, 0)
    benchmark.extra_info['runs'] = runs_df.to_dict('records')
10+
11+
12+
@pytest.mark.skip(reason="full diagnostic benchmark, ~2 hour runtime")
@pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False)
def test_level_1(benchmark):
    """Run the full diagnostic XAI benchmark (skipped by default: ~2 hours).

    Per-run records are attached to ``benchmark.extra_info`` for reporting.
    """
    runs_df = benchmark(run_benchmark_config, 1)
    benchmark.extra_info['runs'] = runs_df.to_dict('records')
17+
18+
19+
@pytest.mark.skip(reason="very thorough benchmark, >>2 hour runtime")
@pytest.mark.benchmark(group="xai_bench", min_rounds=1, warmup=False)
def test_level_2(benchmark):
    """Run the most exhaustive XAI benchmark (skipped by default: >>2 hours).

    Per-run records are attached to ``benchmark.extra_info`` for reporting.
    """
    runs_df = benchmark(run_benchmark_config, 2)
    benchmark.extra_info['runs'] = runs_df.to_dict('records')

0 commit comments

Comments
 (0)