1
1
"""Explainers.countefactual module"""
2
2
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
3
3
# pylint: disable = unused-argument
4
- from typing import Optional , Union
4
+ from typing import Optional , Union , List
5
5
import matplotlib .pyplot as plt
6
6
import matplotlib as mpl
7
7
import pandas as pd
20
20
Model ,
21
21
)
22
22
23
+
23
24
from trustyai .utils .data_conversions import (
24
25
prediction_object_to_numpy ,
25
26
prediction_object_to_pandas ,
26
27
OneInputUnionType ,
27
28
OneOutputUnionType ,
28
29
data_conversion_docstring ,
30
+ one_input_convert ,
29
31
)
30
32
31
33
from org .kie .trustyai .explainability .local .counterfactual import (
38
40
DataDistribution ,
39
41
PredictionProvider ,
40
42
)
43
+
44
+ from org .kie .trustyai .explainability .model .domain import FeatureDomain
45
+
41
46
from org .optaplanner .core .config .solver .termination import TerminationConfig
42
47
from java .lang import Long
43
48
@@ -181,11 +186,12 @@ def explain(
181
186
inputs : OneInputUnionType ,
182
187
goal : OneOutputUnionType ,
183
188
model : Union [PredictionProvider , Model ],
189
+ feature_domains : List [FeatureDomain ] = None ,
184
190
data_distribution : Optional [DataDistribution ] = None ,
185
191
uuid : Optional [_uuid .UUID ] = None ,
186
192
timeout : Optional [float ] = None ,
187
193
) -> CounterfactualResult :
188
- """Request for a counterfactual explanation given a list of features, goals and a
194
+ r """Request for a counterfactual explanation given a list of features, goals and a
189
195
:class:`~PredictionProvider`
190
196
191
197
Parameters
@@ -197,6 +203,15 @@ def explain(
197
203
These can take the form of a: {}
198
204
model : :obj:`~trustyai.model.PredictionProvider`
199
205
The TrustyAI model as generated by :class:`~trustyai.model.Model` or a Java :class:`PredictionProvider`
206
+ feature_domains : List[:class:`FeatureDomain`]
207
+ A list of feature domains (each created by :func:`~trustyai.model.feature_domain()`)
208
+ that define the valid domain of the input features. The ith element of the list defines
209
+ the domain of the ith input feature. If the ith element of this list is ``None``, the
210
+ no domain information will be added to the ith feature. If the ith feature had no
211
+ previously-supplied domain information, it will be taken to be constrained and
212
+ non-variable. If ``feature_domains=None``, no domain information will be added to any
213
+ of the features, thus preserving existing domains if they've been manually added
214
+ previously or holding undomained features constrained.
200
215
data_distribution : Optional[:class:`DataDistribution`]
201
216
The :class:`DataDistribution` to use when sampling the inputs.
202
217
uuid : Optional[:class:`_uuid.UUID`]
@@ -210,7 +225,7 @@ def explain(
210
225
Object containing the results of the counterfactual explanation.
211
226
"""
212
227
_prediction = counterfactual_prediction (
213
- input_features = inputs ,
228
+ input_features = one_input_convert ( inputs , feature_domains = feature_domains ) ,
214
229
outputs = goal ,
215
230
data_distribution = data_distribution ,
216
231
uuid = uuid ,
0 commit comments