Skip to content

Commit 8514698

Browse files
authored
FAI-900: Added feature domain argument to counterfactuals (#128)
* Added feature domain argument to cfs * fixed incorrect imports in cf test
1 parent e011e67 commit 8514698

File tree

3 files changed

+79
-9
lines changed

3 files changed

+79
-9
lines changed

src/trustyai/explainers/counterfactuals.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Explainers.countefactual module"""
22
# pylint: disable = import-error, too-few-public-methods, wrong-import-order, line-too-long,
33
# pylint: disable = unused-argument
4-
from typing import Optional, Union
4+
from typing import Optional, Union, List
55
import matplotlib.pyplot as plt
66
import matplotlib as mpl
77
import pandas as pd
@@ -20,12 +20,14 @@
2020
Model,
2121
)
2222

23+
2324
from trustyai.utils.data_conversions import (
2425
prediction_object_to_numpy,
2526
prediction_object_to_pandas,
2627
OneInputUnionType,
2728
OneOutputUnionType,
2829
data_conversion_docstring,
30+
one_input_convert,
2931
)
3032

3133
from org.kie.trustyai.explainability.local.counterfactual import (
@@ -38,6 +40,9 @@
3840
DataDistribution,
3941
PredictionProvider,
4042
)
43+
44+
from org.kie.trustyai.explainability.model.domain import FeatureDomain
45+
4146
from org.optaplanner.core.config.solver.termination import TerminationConfig
4247
from java.lang import Long
4348

@@ -181,11 +186,12 @@ def explain(
181186
inputs: OneInputUnionType,
182187
goal: OneOutputUnionType,
183188
model: Union[PredictionProvider, Model],
189+
feature_domains: List[FeatureDomain] = None,
184190
data_distribution: Optional[DataDistribution] = None,
185191
uuid: Optional[_uuid.UUID] = None,
186192
timeout: Optional[float] = None,
187193
) -> CounterfactualResult:
188-
"""Request for a counterfactual explanation given a list of features, goals and a
194+
r"""Request for a counterfactual explanation given a list of features, goals and a
189195
:class:`~PredictionProvider`
190196
191197
Parameters
@@ -197,6 +203,15 @@ def explain(
197203
These can take the form of a: {}
198204
model : :obj:`~trustyai.model.PredictionProvider`
199205
The TrustyAI model as generated by :class:`~trustyai.model.Model` or a Java :class:`PredictionProvider`
206+
feature_domains : List[:class:`FeatureDomain`]
207+
A list of feature domains (each created by :func:`~trustyai.model.feature_domain()`)
208+
that define the valid domain of the input features. The ith element of the list defines
209+
the domain of the ith input feature. If the ith element of this list is ``None``, the
210+
no domain information will be added to the ith feature. If the ith feature had no
211+
previously-supplied domain information, it will be taken to be constrained and
212+
non-variable. If ``feature_domains=None``, no domain information will be added to any
213+
of the features, thus preserving existing domains if they've been manually added
214+
previously or holding undomained features constrained.
200215
data_distribution : Optional[:class:`DataDistribution`]
201216
The :class:`DataDistribution` to use when sampling the inputs.
202217
uuid : Optional[:class:`_uuid.UUID`]
@@ -210,7 +225,7 @@ def explain(
210225
Object containing the results of the counterfactual explanation.
211226
"""
212227
_prediction = counterfactual_prediction(
213-
input_features=inputs,
228+
input_features=one_input_convert(inputs, feature_domains=feature_domains),
214229
outputs=goal,
215230
data_distribution=data_distribution,
216231
uuid=uuid,

src/trustyai/utils/data_conversions.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
"""Data Converters between Python and Java"""
22
# pylint: disable = import-error, line-too-long, trailing-whitespace, unused-import, cyclic-import
33
# pylint: disable = consider-using-f-string, invalid-name, wrong-import-order
4-
4+
import warnings
55
from typing import Union, List
66

77
import trustyai.model
@@ -12,7 +12,10 @@
1212
PredictionInput,
1313
PredictionOutput,
1414
)
15-
from org.kie.trustyai.explainability.model.domain import FeatureDomain
15+
from org.kie.trustyai.explainability.model.domain import (
16+
FeatureDomain,
17+
EmptyFeatureDomain,
18+
)
1619

1720
import pandas as pd
1821
import numpy as np
@@ -146,17 +149,29 @@ def domain_insertion(
146149
):
147150
"""Given a PredictionInput and a corresponding list of feature domains, where
148151
`len(feature_domains) == len(PredictionInput.getFeatures()`, return a PredictionInput
149-
where the ith feature has the ith domain. If the ith domain is `None`, the feature
150-
is constrained."""
151-
assert len(undomained_input.getFeatures()) == len(feature_domains)
152+
where the ith feature has the ith domain. If the ith domain is `None`, no new domain
153+
information will be added to the feature, thus keeping previous domain information or
154+
keeping it fixed if none has been supplied"""
155+
assert len(undomained_input.getFeatures()) == len(
156+
feature_domains
157+
), "input has {} features, but {} feature domains were passed".format(
158+
len(undomained_input.getFeatures()), len(feature_domains)
159+
)
152160

153161
domained_features = []
154162
for i, f in enumerate(undomained_input.getFeatures()):
155163
if feature_domains[i] is None:
156164
domained_features.append(
157-
Feature(f.getName(), f.getType(), f.getValue(), True, None)
165+
Feature(f.getName(), f.getType(), f.getValue(), True, f.getDomain())
158166
)
159167
else:
168+
if not isinstance(f.getDomain(), EmptyFeatureDomain):
169+
warning_msg = (
170+
"The supplied feature domain at position {} is specifying a new "
171+
"domain to previously domain'ed {}, this will overwrite the "
172+
"previous domain with the new one.".format(i, f.toString())
173+
)
174+
warnings.warn(warning_msg)
160175
domained_features.append(
161176
Feature(
162177
f.getName(), f.getType(), f.getValue(), False, feature_domains[i]

tests/general/test_counterfactualexplainer.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
output, Model, feature,
1313
)
1414
from trustyai.utils import TestModels
15+
from trustyai.model.domain import feature_domain
16+
from trustyai.utils.data_conversions import one_input_convert
1517

1618
jrandom = Random()
1719
jrandom.setSeed(0)
@@ -140,3 +142,41 @@ def test_counterfactual_v2():
140142
result_output = model(explanation.proposed_features_dataframe)
141143
assert result_output < .01
142144
assert result_output > -.01
145+
146+
147+
def test_counterfactual_with_domain_argument():
148+
"""Test passing domains to counterfactuals"""
149+
np.random.seed(0)
150+
data = np.random.rand(1, 5)
151+
model_weights = np.random.rand(5)
152+
model = Model(lambda x: np.dot(x, model_weights))
153+
explainer = CounterfactualExplainer(steps=10_000)
154+
explanation = explainer.explain(
155+
inputs=data,
156+
goal=np.array([0]),
157+
feature_domains=[feature_domain((-10, 10)) for _ in range(5)],
158+
model=model)
159+
result_output = model(explanation.proposed_features_dataframe)
160+
assert result_output < .01
161+
assert result_output > -.01
162+
163+
164+
def test_counterfactual_with_domain_argument_overwrite():
165+
"""Test that passing domains to counterfactuals with already-domained features throws
166+
a warning"""
167+
np.random.seed(0)
168+
data = np.random.rand(1, 5)
169+
domained_inputs = one_input_convert(data, [feature_domain((-10, 10)) for _ in range(5)])
170+
model_weights = np.random.rand(5)
171+
model = Model(lambda x: np.dot(x, model_weights))
172+
explainer = CounterfactualExplainer(steps=10_000)
173+
174+
with pytest.warns(UserWarning):
175+
explainer.explain(
176+
inputs=domained_inputs,
177+
goal=np.array([0]),
178+
feature_domains=[feature_domain((-10, 10)) for _ in range(5)],
179+
model=model
180+
)
181+
182+

0 commit comments

Comments
 (0)