Skip to content

Commit 1563850

Browse files
authored
FAI-886: Unified input/output types, conversion functions, and docstrings (#116)
* unified input/output types, conversion functions, and docstrings * fixed many-x conversions for 1d arrays, added more test cases * fixing input -> output typo * fixed randozm -> random typo * added feature domains into data conversions * fixed missing feature_domains != none check in many_inputs conversion * linting * Removed double simple pred import
1 parent 1b4329c commit 1563850

File tree

12 files changed

+681
-384
lines changed

12 files changed

+681
-384
lines changed

docs/api.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,6 @@ Data Objects
3535
.. autosummary::
3636
:toctree: generated/
3737

38-
simple_prediction
39-
counterfactual_prediction
4038
Dataset
4139

4240
Model Classes

src/trustyai/explainers/counterfactuals.py

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
"""Explainers.counterfactual module"""
22
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33
# pylint: disable = unused-argument
4-
from typing import Optional, List, Union
4+
from typing import Optional
55
import matplotlib.pyplot as plt
66
import matplotlib as mpl
77
import pandas as pd
8-
import numpy as np
98
import uuid as _uuid
109

1110
from trustyai import _default_initializer # pylint: disable=unused-import
@@ -17,10 +16,17 @@
1716

1817
from trustyai.model import (
1918
counterfactual_prediction,
20-
Dataset,
2119
PredictionInput,
2220
)
2321

22+
from trustyai.utils.data_conversions import (
23+
prediction_object_to_numpy,
24+
prediction_object_to_pandas,
25+
OneInputUnionType,
26+
OneOutputUnionType,
27+
data_conversion_docstring,
28+
)
29+
2430
from org.kie.trustyai.explainability.local.counterfactual import (
2531
CounterfactualExplainer as _CounterfactualExplainer,
2632
CounterfactualResult as _CounterfactualResult,
@@ -29,9 +35,6 @@
2935
)
3036
from org.kie.trustyai.explainability.model import (
3137
DataDistribution,
32-
Feature,
33-
Output,
34-
PredictionOutput,
3538
PredictionProvider,
3639
)
3740
from org.optaplanner.core.config.solver.termination import TerminationConfig
@@ -57,7 +60,7 @@ def proposed_features_array(self):
5760
"""Return the proposed feature values found from the counterfactual explanation
5861
as a Numpy array.
5962
"""
60-
return Dataset.prediction_object_to_numpy(
63+
return prediction_object_to_numpy(
6164
[PredictionInput([entity.as_feature() for entity in self._result.entities])]
6265
)
6366

@@ -66,7 +69,7 @@ def proposed_features_dataframe(self):
6669
"""Return the proposed feature values found from the counterfactual explanation
6770
as a Pandas DataFrame.
6871
"""
69-
return Dataset.prediction_object_to_pandas(
72+
return prediction_object_to_pandas(
7073
[PredictionInput([entity.as_feature() for entity in self._result.entities])]
7174
)
7275

@@ -171,10 +174,11 @@ def __init__(self, steps=10_000):
171174
self._explainer = _CounterfactualExplainer(self._cf_config)
172175

173176
# pylint: disable=too-many-arguments
177+
@data_conversion_docstring("one_input", "one_output")
174178
def explain(
175179
self,
176-
inputs: Union[np.ndarray, pd.DataFrame, List[Feature], PredictionInput],
177-
goal: Union[np.ndarray, pd.DataFrame, List[Output], PredictionOutput],
180+
inputs: OneInputUnionType,
181+
goal: OneOutputUnionType,
178182
model: PredictionProvider,
179183
data_distribution: Optional[DataDistribution] = None,
180184
uuid: Optional[_uuid.UUID] = None,
@@ -185,27 +189,14 @@ def explain(
185189
186190
Parameters
187191
----------
188-
inputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Feature`], or :class:`PredictionInput`
189-
List of input features, as a:
190-
191-
* Numpy array of shape ``[1, n_features]``
192-
* Pandas DataFrame with 1 row and ``n_features`` columns
193-
* A List of TrustyAI :class:`Feature`, as created by the :func:`~feature` function
194-
* A TrustyAI :class:`PredictionInput`
195-
196-
goal : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Output`], or :class:`PredictionOutput`
192+
inputs : {}
193+
List of input features, as a: {}
194+
goal : {}
197195
The desired model outputs to be searched for in the counterfactual explanation.
198-
These can take the form of a:
199-
200-
* Numpy array of shape ``[1, n_outputs]``
201-
* Pandas DataFrame with 1 row and ``n_outputs`` columns
202-
* A List of TrustyAI :class:`Output`, as created by the :func:`~output` function
203-
* A TrustyAI :class:`PredictionOutput`
204-
196+
These can take the form of a: {}
205197
model : :obj:`~trustyai.model.PredictionProvider`
206198
The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model` or
207199
:class:`~trustyai.model.ArrowModel`.
208-
209200
data_distribution : Optional[:class:`DataDistribution`]
210201
The :class:`DataDistribution` to use when sampling the inputs.
211202
uuid : Optional[:class:`_uuid.UUID`]

src/trustyai/explainers/lime.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,14 @@
1919
output_html,
2020
feature_html,
2121
)
22+
23+
from trustyai.utils.data_conversions import (
24+
OneInputUnionType,
25+
data_conversion_docstring,
26+
OneOutputUnionType,
27+
)
2228
from .explanation_results import SaliencyResults
23-
from trustyai.model import simple_prediction, PredUnionType
29+
from trustyai.model import simple_prediction
2430

2531
from org.kie.trustyai.explainability.local.lime import (
2632
LimeConfig as _LimeConfig,
@@ -230,6 +236,7 @@ def __init__(self, samples=10, **kwargs):
230236
process.
231237
* trackCounterfactuals : bool
232238
(default=False) Keep track of produced byproduct counterfactuals during LIME run.
239+
233240
"""
234241
self._jrandom = Random()
235242
self._jrandom.setSeed(kwargs.get("seed", 0))
@@ -250,30 +257,22 @@ def __init__(self, samples=10, **kwargs):
250257

251258
self._explainer = _LimeExplainer(self._lime_config)
252259

260+
@data_conversion_docstring("one_input", "one_output")
253261
def explain(
254-
self, inputs: PredUnionType, outputs: PredUnionType, model: PredictionProvider
262+
self,
263+
inputs: OneInputUnionType,
264+
outputs: OneOutputUnionType,
265+
model: PredictionProvider,
255266
) -> LimeResults:
256267
"""Produce a LIME explanation.
257268
258269
Parameters
259270
----------
260-
inputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Feature`], or :class:`PredictionInput`
261-
The input features to the model, as a:
262-
263-
* Numpy array of shape ``[1, n_features]``
264-
* Pandas DataFrame with 1 row and ``n_features`` columns
265-
* A List of TrustyAI :class:`Feature`, as created by the :func:`~feature` function
266-
* A TrustyAI :class:`PredictionInput`
267-
268-
outputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Output`], or :class:`PredictionOutput`
271+
inputs : {}
272+
The input features to the model, as a: {}
273+
outputs : {}
269274
The corresponding model outputs for the provided features, that is,
270-
``outputs = model(input_features)``. These can take the form of a:
271-
272-
* Numpy array of shape ``[1, n_outputs]``
273-
* Pandas DataFrame with 1 row and ``n_outputs`` columns
274-
* A List of TrustyAI :class:`Output`, as created by the :func:`~output` function
275-
* A TrustyAI :class:`PredictionOutput`
276-
275+
``outputs = model(input_features)``. These can take the form of a: {}
277276
model : :obj:`~trustyai.model.PredictionProvider`
278277
The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model`
279278
or :class:`~trustyai.model.ArrowModel`.

src/trustyai/explainers/shap.py

Lines changed: 24 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Explainers.shap module"""
22
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33
# pylint: disable = unused-argument, consider-using-f-string, invalid-name
4-
from typing import Dict, Optional, List, Union
4+
from typing import Dict, Optional
55
import matplotlib.pyplot as plt
66
import matplotlib as mpl
77
from bokeh.models import ColumnDataSource, HoverTool
@@ -21,13 +21,15 @@
2121
output_html,
2222
feature_html,
2323
)
24-
2524
from trustyai.model import (
26-
feature,
27-
Dataset,
28-
PredictionInput,
2925
simple_prediction,
30-
PredUnionType,
26+
)
27+
from trustyai.utils.data_conversions import (
28+
OneInputUnionType,
29+
OneOutputUnionType,
30+
ManyInputsUnionType,
31+
many_inputs_convert,
32+
data_conversion_docstring,
3133
)
3234

3335
from org.kie.trustyai.explainability.local.shap import (
@@ -434,20 +436,19 @@ class SHAPExplainer:
434436
the outputs, as compared to the background inputs?*
435437
"""
436438

439+
@data_conversion_docstring("many_inputs")
437440
def __init__(
438441
self,
439-
background: Union[np.ndarray, pd.DataFrame, List[PredictionInput]],
442+
background: ManyInputsUnionType,
440443
link_type: Optional[_ShapConfig.LinkType] = None,
441444
**kwargs,
442445
):
443446
r"""Initialize the :class:`SHAPExplainer`.
444447
445448
Parameters
446449
----------
447-
background : :class:`numpy.array`, :class:`Pandas.DataFrame`
448-
or List[:class:`PredictionInput]
449-
The set of background datapoints as an array, dataframe of shape
450-
``[n_datapoints, n_features]``, or list of TrustyAI PredictionInputs.
450+
background : {}
451+
The set of background datapoints as a: {}
451452
link_type : :obj:`~_ShapConfig.LinkType`
452453
A choice of either ``trustyai.explainers._ShapConfig.LinkType.IDENTITY``
453454
or ``trustyai.explainers._ShapConfig.LinkType.LOGIT``. If the model output is a
@@ -464,10 +465,11 @@ def __init__(
464465
(default=20) The number of batches passed to the PredictionProvider at once.
465466
When using :class:`~Model` with `arrow=False` this parameter has no effect.
466467
If `arrow=True`, `batch_sizes` of around
467-
:math:`\frac{2000}{\mathtt{len(background)}}` can produce significant
468+
:math:`\frac{{2000}}{{\mathtt{{len(background)}}}}` can produce significant
468469
performance gains.
469470
* trackCounterfactuals : bool
470471
(default=False) Keep track of produced byproduct counterfactuals during SHAP run.
472+
471473
Returns
472474
-------
473475
:class:`~SHAPResults`
@@ -477,19 +479,9 @@ def __init__(
477479
link_type = _ShapConfig.LinkType.IDENTITY
478480
self._jrandom = Random()
479481
self._jrandom.setSeed(kwargs.get("seed", 0))
482+
self.background = many_inputs_convert(background)
480483
perturbation_context = PerturbationContext(self._jrandom, 0)
481484

482-
if isinstance(background, np.ndarray):
483-
self.background = Dataset.numpy_to_prediction_object(background, feature)
484-
elif isinstance(background, pd.DataFrame):
485-
self.background = Dataset.df_to_prediction_object(background, feature)
486-
elif isinstance(background[0], PredictionInput):
487-
self.background = background
488-
else:
489-
raise AttributeError(
490-
"Unsupported background type: {}".format(type(background))
491-
)
492-
493485
self._configbuilder = (
494486
_ShapConfig.builder()
495487
.withLink(link_type)
@@ -503,32 +495,22 @@ def __init__(
503495
self._config = self._configbuilder.build()
504496
self._explainer = _ShapKernelExplainer(self._config)
505497

498+
@data_conversion_docstring("one_input", "one_output")
506499
def explain(
507-
self, inputs: PredUnionType, outputs: PredUnionType, model: PredictionProvider
500+
self,
501+
inputs: OneInputUnionType,
502+
outputs: OneOutputUnionType,
503+
model: PredictionProvider,
508504
) -> SHAPResults:
509505
"""Produce a SHAP explanation.
510506
511507
Parameters
512508
----------
513-
inputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Feature`], or :class:`PredictionInput`
514-
The input features to the model, as a:
515-
516-
* Numpy array of shape ``[1, n_features]``
517-
* Pandas DataFrame with 1 row and ``n_features`` columns
518-
* A List of TrustyAI :class:`Feature`, as created by the :func:`~feature` function
519-
* A TrustyAI :class:`PredictionInput`
520-
521-
outputs : :class:`numpy.ndarray`, :class:`pandas.DataFrame`, List[:class:`Output`], or :class:`PredictionOutput`
509+
inputs : {}
510+
The input features to the model, as a: {}
511+
outputs : {}
522512
The corresponding model outputs for the provided features, that is,
523-
``outputs = model(input_features)``. These can take the form of a:
524-
525-
* Numpy array of shape ``[1, n_outputs]``
526-
* Pandas DataFrame with 1 row and ``n_outputs`` columns
527-
* A List of TrustyAI :class:`Output`, as created by the :func:`~output` function
528-
* A TrustyAI :class:`PredictionOutput`
529-
model : :obj:`~trustyai.model.PredictionProvider`
530-
The TrustyAI PredictionProvider, as generated by :class:`~trustyai.model.Model` or
531-
:class:`~trustyai.model.ArrowModel`.
513+
``outputs = model(input_features)``. These can take the form of a: {}
532514
533515
Returns
534516
-------

0 commit comments

Comments
 (0)