Skip to content

Commit b1f6f64

Browse files
authored
Added arrow and non-arrow contexts to automatically switch transmission based on usage (#119)
1 parent 5b25bee commit b1f6f64

File tree

9 files changed

+164
-89
lines changed

9 files changed

+164
-89
lines changed

src/trustyai/explainers/counterfactuals.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Explainers.counterfactual module"""
22
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33
# pylint: disable = unused-argument
4-
from typing import Optional
4+
from typing import Optional, Union
55
import matplotlib.pyplot as plt
66
import matplotlib as mpl
77
import pandas as pd
@@ -17,6 +17,7 @@
1717
from trustyai.model import (
1818
counterfactual_prediction,
1919
PredictionInput,
20+
Model,
2021
)
2122

2223
from trustyai.utils.data_conversions import (
@@ -179,7 +180,7 @@ def explain(
179180
self,
180181
inputs: OneInputUnionType,
181182
goal: OneOutputUnionType,
182-
model: PredictionProvider,
183+
model: Union[PredictionProvider, Model],
183184
data_distribution: Optional[DataDistribution] = None,
184185
uuid: Optional[_uuid.UUID] = None,
185186
timeout: Optional[float] = None,
@@ -215,6 +216,8 @@ def explain(
215216
uuid=uuid,
216217
timeout=timeout,
217218
)
218-
return CounterfactualResult(
219-
self._explainer.explainAsync(_prediction, model).get()
220-
)
219+
220+
with Model.NonArrowTransmission(model):
221+
return CounterfactualResult(
222+
self._explainer.explainAsync(_prediction, model).get()
223+
)

src/trustyai/explainers/lime.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Explainers.lime module"""
22
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33
# pylint: disable = unused-argument, duplicate-code, consider-using-f-string, invalid-name
4-
from typing import Dict
4+
from typing import Dict, Union
55

66
import bokeh.models
77
import matplotlib.pyplot as plt
@@ -27,7 +27,7 @@
2727
)
2828

2929
from .explanation_results import SaliencyResults
30-
from trustyai.model import simple_prediction
30+
from trustyai.model import simple_prediction, Model
3131

3232
from org.kie.trustyai.explainability.local.lime import (
3333
LimeConfig as _LimeConfig,
@@ -42,6 +42,7 @@
4242

4343
from java.util import Random
4444

45+
4546
LimeConfig = _LimeConfig
4647

4748

@@ -263,7 +264,7 @@ def explain(
263264
self,
264265
inputs: OneInputUnionType,
265266
outputs: OneOutputUnionType,
266-
model: PredictionProvider,
267+
model: Union[PredictionProvider, Model],
267268
) -> LimeResults:
268269
"""Produce a LIME explanation.
269270
@@ -284,4 +285,6 @@ def explain(
284285
Object containing the results of the LIME explanation.
285286
"""
286287
_prediction = simple_prediction(inputs, outputs)
287-
return LimeResults(self._explainer.explainAsync(_prediction, model).get())
288+
289+
with Model.ArrowTransmission(model, inputs):
290+
return LimeResults(self._explainer.explainAsync(_prediction, model).get())

src/trustyai/explainers/shap.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Explainers.shap module"""
22
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33
# pylint: disable = unused-argument, consider-using-f-string, invalid-name
4-
from typing import Dict, Optional
4+
from typing import Dict, Optional, Union
55
import matplotlib.pyplot as plt
66
import matplotlib as mpl
77
from bokeh.models import ColumnDataSource, HoverTool
@@ -21,9 +21,7 @@
2121
output_html,
2222
feature_html,
2323
)
24-
from trustyai.model import (
25-
simple_prediction,
26-
)
24+
from trustyai.model import simple_prediction, Model
2725
from trustyai.utils.data_conversions import (
2826
OneInputUnionType,
2927
OneOutputUnionType,
@@ -54,6 +52,8 @@
5452

5553

5654
# pylint: disable=invalid-name
55+
56+
5757
class SHAPResults(SaliencyResults):
5858
"""Wraps SHAP results. This object is returned by the :class:`~SHAPExplainer`,
5959
and provides a variety of methods to visualize and interact with the explanation.
@@ -654,7 +654,7 @@ def explain(
654654
self,
655655
inputs: OneInputUnionType,
656656
outputs: OneOutputUnionType,
657-
model: PredictionProvider,
657+
model: Union[PredictionProvider, Model],
658658
) -> SHAPResults:
659659
"""Produce a SHAP explanation.
660660
@@ -674,6 +674,8 @@ def explain(
674674
Object containing the results of the SHAP explanation.
675675
"""
676676
_prediction = simple_prediction(inputs, outputs)
677-
return SHAPResults(
678-
self._explainer.explainAsync(_prediction, model).get(), self.background
679-
)
677+
678+
with Model.ArrowTransmission(model, inputs):
679+
return SHAPResults(
680+
self._explainer.explainAsync(_prediction, model).get(), self.background
681+
)

src/trustyai/model/__init__.py

Lines changed: 109 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -295,7 +295,7 @@ class Model:
295295
"""
296296

297297
def __init__(
298-
self, predict_fun, dataframe_input=False, output_names=None, arrow=False
298+
self, predict_fun, dataframe_input=False, output_names=None, disable_arrow=False
299299
):
300300
"""
301301
Wrap the model as a TrustyAI :obj:`PredictionProvider` Java class.
@@ -311,39 +311,75 @@ def __init__(
311311
output_names : List[String]:
312312
If the model outputs a numpy array, you can specify the names of the model outputs
313313
here.
314-
arrow: bool
315-
Whether to use Apache arrow to speed up data transfer between Java and Python.
316-
In general, set this to ``true`` whenever LIME or SHAP explanations are needed,
317-
and ``false`` for counterfactuals.
314+
disable_arrow: bool
315+
If true, Apache Arrow will not be used to accelerate data transfer between Java
316+
and Python. If false, Arrow will be automatically used in situations where it is
317+
advantageous to do so.
318318
"""
319-
self.arrow = arrow
319+
self.disable_arrow = disable_arrow
320320
self.predict_fun = predict_fun
321321
self.output_names = output_names
322+
self.dataframe_input = dataframe_input
322323

323-
if arrow:
324-
self.prediction_provider = None
325-
if not dataframe_input:
326-
self.prediction_provider_arrow = PredictionProviderArrow(
327-
lambda x: self._cast_outputs_to_dataframe(predict_fun(x.values))
328-
)
329-
else:
330-
self.prediction_provider_arrow = PredictionProviderArrow(
331-
lambda x: self._cast_outputs_to_dataframe(predict_fun(x))
324+
self.prediction_provider_arrow = None
325+
self.prediction_provider_normal = None
326+
self.prediction_provider = None
327+
328+
# set model to use non-arrow by default, as this requires no dataset information
329+
self._set_nonarrow()
330+
331+
def _set_arrow(self, paradigm_input: PredictionInput):
332+
"""
333+
Ready the model for arrow-based prediction communication.
334+
335+
Parameters
336+
----------
337+
paradigm_input: A single :obj:`PredictionInput` by which to establish the arrow schema.
338+
All subsequent :obj:`PredictionInput`s communicated must have this schema.
339+
"""
340+
if self.disable_arrow:
341+
self._set_nonarrow()
342+
else:
343+
if self.prediction_provider_arrow is None:
344+
raw_ppa = self._get_arrow_prediction_provider()
345+
self.prediction_provider_arrow = raw_ppa.get_as_prediction_provider(
346+
paradigm_input
332347
)
348+
self.prediction_provider = self.prediction_provider_arrow
349+
350+
def _set_nonarrow(self):
351+
"""
352+
Ready the model for non-arrow prediction communication.
353+
"""
354+
if self.prediction_provider_normal is None:
355+
self.prediction_provider_normal = self._get_nonarrow_prediction_provider()
356+
self.prediction_provider = self.prediction_provider_normal
357+
358+
def _get_arrow_prediction_provider(self):
359+
if not self.dataframe_input:
360+
ppa = PredictionProviderArrow(
361+
lambda x: self._cast_outputs_to_dataframe(self.predict_fun(x.values))
362+
)
333363
else:
334-
self.prediction_provider_arrow = None
335-
if dataframe_input:
336-
self.prediction_provider = PredictionProvider(
337-
lambda x: self._cast_outputs(
338-
predict_fun(prediction_object_to_pandas(x))
339-
)
364+
ppa = PredictionProviderArrow(
365+
lambda x: self._cast_outputs_to_dataframe(self.predict_fun(x))
366+
)
367+
return ppa
368+
369+
def _get_nonarrow_prediction_provider(self):
370+
if self.dataframe_input:
371+
ppn = PredictionProvider(
372+
lambda x: self._cast_outputs(
373+
self.predict_fun(prediction_object_to_pandas(x))
340374
)
341-
else:
342-
self.prediction_provider = PredictionProvider(
343-
lambda x: self._cast_outputs(
344-
predict_fun(prediction_object_to_numpy(x))
345-
)
375+
)
376+
else:
377+
ppn = PredictionProvider(
378+
lambda x: self._cast_outputs(
379+
self.predict_fun(prediction_object_to_numpy(x))
346380
)
381+
)
382+
return ppn
347383

348384
def _cast_outputs(self, output_array):
349385
return df_to_prediction_object(
@@ -388,12 +424,8 @@ def predictAsync(self, inputs: List[PredictionInput]) -> CompletableFuture:
388424
:obj:`CompletableFuture`
389425
A Java :obj:`CompletableFuture` containing the model outputs.
390426
"""
391-
if self.arrow and self.prediction_provider is None:
392-
self.prediction_provider = (
393-
self.prediction_provider_arrow.get_as_prediction_provider(inputs[0])
394-
)
395-
out = self.prediction_provider.predictAsync(inputs)
396-
return out
427+
428+
return self.prediction_provider.predictAsync(inputs)
397429

398430
def __call__(self, inputs):
399431
"""
@@ -405,6 +437,51 @@ def __call__(self, inputs):
405437
"""
406438
return self.predict_fun(inputs)
407439

440+
class ArrowTransmission:
441+
"""
442+
Context class to ensure all predictAsync calls within the context use arrow.
443+
444+
Parameters
445+
----------
446+
model: The TrustyAI :obj:`Model` or PredictionProvider
447+
paradigm_input: A single :obj:`PredictionInput` by which to establish the arrow schema.
448+
All subsequent :obj:`PredictionInput`s communicated must have this schema.
449+
"""
450+
451+
def __init__(self, model, paradigm_input: OneInputUnionType):
452+
self.model = model
453+
self.model_is_python = isinstance(model, Model)
454+
self.paradigm_input = one_input_convert(paradigm_input)
455+
self.previous_model_state = None
456+
457+
def __enter__(self):
458+
if self.model_is_python:
459+
self.previous_model_state = self.model.prediction_provider
460+
self.model._set_arrow(self.paradigm_input)
461+
462+
def __exit__(self, exit_type, value, traceback):
463+
if self.model_is_python:
464+
self.model.prediction_provider = self.previous_model_state
465+
466+
class NonArrowTransmission:
467+
"""
468+
Context class to ensure all predictAsync calls within the context DO NOT use arrow.
469+
"""
470+
471+
def __init__(self, model):
472+
self.model = model
473+
self.model_is_python = isinstance(model, Model)
474+
self.previous_model_state = None
475+
476+
def __enter__(self):
477+
if self.model_is_python:
478+
self.previous_model_state = self.model.prediction_provider
479+
self.model._set_nonarrow()
480+
481+
def __exit__(self, exit_type, value, traceback):
482+
if self.model_is_python:
483+
self.model.prediction_provider = self.previous_model_state
484+
408485

409486
@_jcustomizer.JImplementationFor("org.kie.trustyai.explainability.model.Output")
410487
# pylint: disable=no-member

tests/general/test_limeexplainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_lime_v2():
112112
model_weights = np.random.rand(5)
113113
predict_function = lambda x: np.dot(x.values, model_weights)
114114

115-
model = Model(predict_function, dataframe_input=True, arrow=True)
115+
model = Model(predict_function, dataframe_input=True)
116116
explainer = LimeExplainer(samples=100, perturbations=2, seed=23, normalise_weights=False)
117117
explanation = explainer.explain(inputs=data, outputs=model(data), model=model)
118118
for score in explanation.as_dataframe()["output-0_score"]:

tests/general/test_model.py

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
"""Test model provider interface"""
33

44
from common import *
5-
from trustyai.model import Model, feature
6-
from trustyai.utils.data_conversions import numpy_to_prediction_object
5+
from trustyai.model import Model, Dataset, feature
76

87
import pytest
98

9+
from trustyai.utils.data_conversions import numpy_to_prediction_object
10+
1011

1112
def test_basic_model():
1213
"""Test basic model"""
@@ -18,41 +19,30 @@ def test_basic_model():
1819

1920

2021
def test_cast_output():
21-
np2np = Model(lambda x: np.sum(x, 1), output_names=['sum'])
22-
np2df = Model(lambda x: pd.DataFrame(x))
23-
df2np = Model(lambda x: x.sum(1).values, dataframe_input=True, output_names=['sum'])
24-
df2df = Model(lambda x: x, dataframe_input=True)
25-
pis = numpy_to_prediction_object(np.arange(0., 125.).reshape(25, 5), feature)
26-
27-
output_val = np2np.predictAsync(pis).get()
28-
assert len(output_val) == 25
22+
np2np = Model(lambda x: np.sum(x, 1), output_names=['sum'], disable_arrow=True)
23+
np2df = Model(lambda x: pd.DataFrame(x), disable_arrow=True)
24+
df2np = Model(lambda x: x.sum(1).values,
25+
dataframe_input=True,
26+
output_names=['sum'],
27+
disable_arrow=True)
28+
df2df = Model(lambda x: x, dataframe_input=True, disable_arrow=True)
2929

30-
output_val = np2df.predictAsync(pis).get()
31-
assert len(output_val) == 25
32-
33-
output_val = df2np.predictAsync(pis).get()
34-
assert len(output_val) == 25
30+
pis = numpy_to_prediction_object(np.arange(0., 125.).reshape(25, 5), feature)
3531

36-
output_val = df2df.predictAsync(pis).get()
37-
assert len(output_val) == 25
32+
for m in [np2np, np2df, df2df, df2np]:
33+
output_val = m.predictAsync(pis).get()
34+
assert len(output_val) == 25
3835

3936

4037
def test_cast_output_arrow():
41-
np2np = Model(lambda x: np.sum(x, 1), output_names=['sum'], arrow=True)
42-
np2df = Model(lambda x: pd.DataFrame(x), arrow=True)
43-
df2np = Model(lambda x: x.sum(1).values, dataframe_input=True, output_names=['sum'], arrow=True)
44-
df2df = Model(lambda x: x, dataframe_input=True, arrow=True)
38+
np2np = Model(lambda x: np.sum(x, 1), output_names=['sum'])
39+
np2df = Model(lambda x: pd.DataFrame(x))
40+
df2np = Model(lambda x: x.sum(1).values, dataframe_input=True, output_names=['sum'])
41+
df2df = Model(lambda x: x, dataframe_input=True)
4542
pis = numpy_to_prediction_object(np.arange(0., 125.).reshape(25, 5), feature)
4643

47-
output_val = np2np.predictAsync(pis).get()
48-
assert len(output_val) == 25
49-
50-
output_val = np2df.predictAsync(pis).get()
51-
assert len(output_val) == 25
52-
53-
output_val = df2np.predictAsync(pis).get()
54-
assert len(output_val) == 25
55-
56-
output_val = df2df.predictAsync(pis).get()
57-
assert len(output_val) == 25
44+
for m in [np2np, np2df, df2df, df2np]:
45+
m._set_arrow(pis[0])
46+
output_val = m.predictAsync(pis).get()
47+
assert len(output_val) == 25
5848

0 commit comments

Comments
 (0)