Skip to content

Commit fd67eee

Browse files
authored
FAI-889: Allow non-string categorical feature domains (#118)
* expanded feature domain flexibility * removed vestigial debugging print
1 parent 47a8ffb commit fd67eee

File tree

2 files changed

+66
-13
lines changed

2 files changed

+66
-13
lines changed

src/trustyai/model/domain.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,18 @@
33
from typing import Optional, Tuple, List, Union
44

55
from jpype import _jclass
6+
67
from org.kie.trustyai.explainability.model.domain import (
78
FeatureDomain,
89
NumericalFeatureDomain,
910
CategoricalFeatureDomain,
11+
CategoricalNumericalFeatureDomain,
12+
ObjectFeatureDomain,
1013
EmptyFeatureDomain,
1114
)
1215

1316

14-
def feature_domain(
15-
values: Optional[Union[Tuple, List[str]]]
16-
) -> Optional[FeatureDomain]:
17+
def feature_domain(values: Optional[Union[Tuple, List]]) -> Optional[FeatureDomain]:
1718
r"""Create a Java :class:`FeatureDomain`. This represents the valid range of values for a
1819
particular feature, which is useful when constraining a counterfactual explanation to ensure it
1920
only recovers valid inputs. For example, if we had a feature that described a person's age, we
@@ -22,13 +23,18 @@ def feature_domain(
2223
2324
Parameters
2425
----------
25-
values : Optional[Union[Tuple, List[str]]]
26+
values : Optional[Union[Tuple, List]]
2627
The valid values of the feature. If `values` takes the form of:
2728
2829
* **A tuple of floats or integers:** The feature domain will be a continuous range from
2930
``values[0]`` to ``values[1]``.
30-
* **A list of strings:** The feature domain will be categorical, where `values` contains
31-
all possible valid feature values.
31+
* **A list of floats or integers:** The feature domain will be a *numeric* categorical,
32+
where `values` contains all possible valid feature values.
33+
* **A list of strings:** The feature domain will be a *string* categorical, where `values`
34+
contains all possible valid feature values.
35+
* **A list of objects:** The feature domain will be an *object* categorical, where `values`
36+
contains all possible valid feature values. These may present an issue if the objects
37+
are not natively Java serializable.
3238
3339
Otherwise, the feature domain will be taken as `Empty`, which will mean it will be held
3440
fixed during the counterfactual explanation.
@@ -43,12 +49,29 @@ def feature_domain(
4349
if not values:
4450
domain = EmptyFeatureDomain.create()
4551
else:
46-
if isinstance(values[0], (float, int)):
47-
domain = NumericalFeatureDomain.create(values[0], values[1])
48-
elif isinstance(values[0], str):
49-
domain = CategoricalFeatureDomain.create(
50-
_jclass.JClass("java.util.Arrays").asList(values)
52+
if isinstance(values, tuple):
53+
assert isinstance(values[0], (float, int)) and isinstance(
54+
values[1], (float, int)
55+
)
56+
assert len(values) == 2, (
57+
"Tuples passed as domain values must only contain"
58+
" two values that define the (minimum, maximum) of the domain"
5159
)
60+
domain = NumericalFeatureDomain.create(values[0], values[1])
61+
62+
elif isinstance(values, list):
63+
java_array = _jclass.JClass("java.util.Arrays").asList(values)
64+
if isinstance(values[0], bool) and isinstance(values[1], bool):
65+
domain = ObjectFeatureDomain.create(java_array)
66+
elif isinstance(values[0], (float, int)) and isinstance(
67+
values[1], (float, int)
68+
):
69+
domain = CategoricalNumericalFeatureDomain.create(java_array)
70+
elif isinstance(values[0], str):
71+
domain = CategoricalFeatureDomain.create(java_array)
72+
else:
73+
domain = ObjectFeatureDomain.create(java_array)
74+
5275
else:
5376
domain = EmptyFeatureDomain.create()
5477
return domain

tests/general/test_conversions.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,37 @@ def test_numeric_domain_tuple():
4343
assert jdomain.getUpperBound() == 1000.0
4444

4545

46+
def test_categorical_numeric_domain_list():
47+
"""Test create numeric domain from list"""
48+
domain = [0, 1000]
49+
jdomain = feature_domain(domain)
50+
assert jdomain.getCategories().size() == 2
51+
assert jdomain.getCategories().containsAll(domain)
52+
53+
domain = [0.0, 1000.0]
54+
jdomain = feature_domain(domain)
55+
assert jdomain.getCategories().size() == 2
56+
assert jdomain.getCategories().containsAll(domain)
57+
58+
59+
def test_categorical_object_domain_list():
60+
"""Test create object domain from list"""
61+
domain = [True, False]
62+
jdomain = feature_domain(domain)
63+
assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain"
64+
assert jdomain.getCategories().size() == 2
65+
assert jdomain.getCategories().containsAll(domain)
66+
67+
68+
def test_categorical_object_domain_list_2():
69+
"""Test create object domain from list"""
70+
domain = [b"test", b"test2"]
71+
jdomain = feature_domain(domain)
72+
assert str(jdomain.getClass().getSimpleName()) == "ObjectFeatureDomain"
73+
assert jdomain.getCategories().size() == 2
74+
assert jdomain.getCategories().containsAll(domain)
75+
76+
4677
def test_empty_domain():
4778
"""Test empty domain"""
4879
domain = feature_domain(None)
@@ -51,7 +82,7 @@ def test_empty_domain():
5182

5283
def test_categorical_domain_tuple():
5384
"""Test create categorical domain from tuple and list"""
54-
domain = ("foo", "bar", "baz")
85+
domain = ["foo", "bar", "baz"]
5586
jdomain = feature_domain(domain)
5687
assert jdomain.getCategories().size() == 3
5788
assert jdomain.getCategories().containsAll(list(domain))
@@ -61,7 +92,6 @@ def test_categorical_domain_tuple():
6192
assert jdomain.getCategories().size() == 3
6293
assert jdomain.getCategories().containsAll(domain)
6394

64-
6595
def test_feature_function():
6696
"""Test helper method to create features"""
6797
f1 = feature(name="f-1", value=1.0, dtype="number")

0 commit comments

Comments
 (0)