Skip to content

Commit 9b47f08

Browse files
authored
FAI-893: Make test plots non-blocking by default (#120)
* made default plot tests non-blocking, unless -m blocking tag is called * changed argument to block_plots * linting
1 parent 4a96ae3 commit 9b47f08

File tree

8 files changed

+58
-23
lines changed

8 files changed

+58
-23
lines changed

pyproject.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ build-backend = "setuptools.build_meta"
5858
[tool.setuptools]
5959
package-dir = { "" = "src" }
6060

61+
[tool.pytest.ini_options]
62+
addopts = '-m="not block_plots"'
63+
markers = [
64+
"block_plots: Test plots will block execution of subsequent tests until closed"
65+
]
66+
6167
[tool.setuptools.packages.find]
6268
where = ["src"]
6369

src/trustyai/explainers/counterfactuals.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def as_html(self) -> pd.io.formats.style.Styler:
118118
"""
119119
return self.as_dataframe().style
120120

121-
def plot(self) -> None:
121+
def plot(self, block=True) -> None:
122122
"""
123123
Plot the counterfactual result.
124124
"""
@@ -140,7 +140,7 @@ def change_colour(value):
140140
x="features", color={"proposed": colour, "original": "black"}
141141
)
142142
plot.set_title("Counterfactual")
143-
plt.show()
143+
plt.show(block=block)
144144

145145

146146
class CounterfactualExplainer:

src/trustyai/explainers/explanation_results.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def saliency_map(self):
2929
"""Return the Saliencies as a dictionary, keyed by output name"""
3030

3131
@abstractmethod
32-
def _matplotlib_plot(self, output_name: str) -> None:
32+
def _matplotlib_plot(self, output_name: str, block: bool) -> None:
3333
"""Plot the saliencies of a particular output in matplotlib"""
3434

3535
@abstractmethod
@@ -44,7 +44,7 @@ def _get_bokeh_plot_dict(self) -> Dict[str, bokeh.models.Plot]:
4444
for output_name in self.saliency_map().keys()
4545
}
4646

47-
def plot(self, output_name=None, render_bokeh=False) -> None:
47+
def plot(self, output_name=None, render_bokeh=False, block=True) -> None:
4848
"""
4949
Plot the found feature saliencies.
5050
@@ -55,15 +55,17 @@ def plot(self, output_name=None, render_bokeh=False) -> None:
5555
be displayed
5656
render_bokeh : bool
5757
(default: false) Whether to render as bokeh (true) or matplotlib (false)
58+
block: bool
59+
(default: true) Whether displaying the plot blocks subsequent code execution
5860
"""
5961
if output_name is None:
6062
for output_name_iterator in self.saliency_map().keys():
6163
if render_bokeh:
6264
show(self._get_bokeh_plot(output_name_iterator))
6365
else:
64-
self._matplotlib_plot(output_name_iterator)
66+
self._matplotlib_plot(output_name_iterator, block)
6567
else:
6668
if render_bokeh:
6769
show(self._get_bokeh_plot(output_name))
6870
else:
69-
self._matplotlib_plot(output_name)
71+
self._matplotlib_plot(output_name, block)

src/trustyai/explainers/lime.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def as_html(self) -> pd.io.formats.style.Styler:
114114
"""
115115
return self.as_dataframe().style
116116

117-
def _matplotlib_plot(self, output_name: str) -> None:
117+
def _matplotlib_plot(self, output_name: str, block=True) -> None:
118118
"""Plot the LIME saliencies."""
119119
with mpl.rc_context(drcp):
120120
dictionary = {}
@@ -140,7 +140,7 @@ def _matplotlib_plot(self, output_name: str) -> None:
140140
)
141141
plt.yticks(range(len(dictionary)), list(dictionary.keys()))
142142
plt.tight_layout()
143-
plt.show()
143+
plt.show(block=block)
144144

145145
def _get_bokeh_plot(self, output_name) -> bokeh.models.Plot:
146146
lime_data_source = pd.DataFrame(

src/trustyai/explainers/shap.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ def _color_feature_values(feature_values, background_vals):
218218
)
219219
return df_dict
220220

221-
def _matplotlib_plot(self, output_name) -> None:
221+
def _matplotlib_plot(self, output_name, block=True) -> None:
222222
"""Visualize the SHAP explanation of each output as a set of candlestick plots,
223223
one per output."""
224224
with mpl.rc_context(drcp):
@@ -272,7 +272,7 @@ def _matplotlib_plot(self, output_name) -> None:
272272
plt.ylabel(self.saliency_map()[output_name].getOutput().getName())
273273
plt.xlabel("Feature SHAP Value")
274274
plt.title(f"Explanation of {output_name}")
275-
plt.show()
275+
plt.show(block=block)
276276

277277
def _get_bokeh_plot(self, output_name):
278278
fnull = self.get_fnull()[output_name]

tests/general/test_counterfactualexplainer.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# pylint: disable=import-error, wrong-import-position, wrong-import-order, R0801
22
"""Test suite for counterfactual explanations"""
3+
import pytest
34

45
from common import *
56

@@ -92,7 +93,7 @@ def test_counterfactual_match_python_model():
9293
rel=3)
9394

9495

95-
def test_counterfactual_plot():
96+
def counterfactual_plot(block):
9697
"""Test if there's a valid counterfactual with a Python model"""
9798
GOAL_VALUE = 1000
9899
goal = np.array([[GOAL_VALUE]])
@@ -110,7 +111,16 @@ def test_counterfactual_plot():
110111
goal=goal,
111112
model=model)
112113

113-
result.plot()
114+
result.plot(block=block)
115+
116+
117+
@pytest.mark.block_plots
118+
def test_counterfactual_plot_blocking():
119+
counterfactual_plot(True)
120+
121+
122+
def test_counterfactual_plot():
123+
counterfactual_plot(False)
114124

115125

116126
def test_counterfactual_v2():

tests/general/test_limeexplainer.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def test_normalized_weights():
9191
assert -3.0 < feature_importance.getScore() < 3.0
9292

9393

94-
def test_lime_plots():
94+
def lime_plots(block):
9595
"""Test normalized weights"""
9696
lime_explainer = LimeExplainer(normalise_weights=False, perturbations=2, samples=10)
9797
n_features = 15
@@ -100,10 +100,19 @@ def test_lime_plots():
100100
outputs = model.predict([features])[0].outputs
101101

102102
explanation = lime_explainer.explain(inputs=features, outputs=outputs, model=model)
103-
explanation.plot()
104-
explanation.plot(render_bokeh=True)
105-
explanation.plot(output_name="sum-but0")
106-
explanation.plot(output_name="sum-but0", render_bokeh=True)
103+
explanation.plot(block=block)
104+
explanation.plot(block=block, render_bokeh=True)
105+
explanation.plot(block=block, output_name="sum-but0")
106+
explanation.plot(block=block, output_name="sum-but0", render_bokeh=True)
107+
108+
109+
@pytest.mark.block_plots
110+
def test_lime_plots_blocking():
111+
lime_plots(True)
112+
113+
114+
def test_lime_plots():
115+
lime_plots(False)
107116

108117

109118
def test_lime_v2():

tests/general/test_shap.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
np.random.seed(0)
1010

1111
import pytest
12-
1312
from trustyai.explainers import SHAPExplainer
1413
from trustyai.model import feature, Model
1514
from trustyai.utils.data_conversions import numpy_to_prediction_object
@@ -51,7 +50,7 @@ def test_shap_arrow():
5150
assert answers[i] - 1e-2 <= feature_importance.getScore() <= answers[i] + 1e-2
5251

5352

54-
def test_shap_plots():
53+
def shap_plots(block):
5554
"""Test SHAP plots"""
5655
np.random.seed(0)
5756
data = pd.DataFrame(np.random.rand(101, 5))
@@ -64,10 +63,19 @@ def test_shap_plots():
6463
shap_explainer = SHAPExplainer(background=background)
6564
explanation = shap_explainer.explain(inputs=to_explain, outputs=model(to_explain), model=model)
6665

67-
explanation.plot()
68-
explanation.plot(render_bokeh=True)
69-
explanation.plot(output_name='output-0')
70-
explanation.plot(output_name='output-0', render_bokeh=True)
66+
explanation.plot(block=block)
67+
explanation.plot(block=block, render_bokeh=True)
68+
explanation.plot(block=block, output_name='output-0')
69+
explanation.plot(block=block, output_name='output-0', render_bokeh=True)
70+
71+
72+
@pytest.mark.block_plots
73+
def test_shap_plots_blocking():
74+
shap_plots(block=True)
75+
76+
77+
def test_shap_plots():
78+
shap_plots(block=False)
7179

7280

7381
def test_shap_as_df():

0 commit comments

Comments
 (0)