3
3
from typing import Optional , Tuple , List , Union
4
4
5
5
from jpype import _jclass
6
+
6
7
from org .kie .trustyai .explainability .model .domain import (
7
8
FeatureDomain ,
8
9
NumericalFeatureDomain ,
9
10
CategoricalFeatureDomain ,
11
+ CategoricalNumericalFeatureDomain ,
12
+ ObjectFeatureDomain ,
10
13
EmptyFeatureDomain ,
11
14
)
12
15
13
16
14
- def feature_domain (
15
- values : Optional [Union [Tuple , List [str ]]]
16
- ) -> Optional [FeatureDomain ]:
17
+ def feature_domain (values : Optional [Union [Tuple , List ]]) -> Optional [FeatureDomain ]:
17
18
r"""Create a Java :class:`FeatureDomain`. This represents the valid range of values for a
18
19
particular feature, which is useful when constraining a counterfactual explanation to ensure it
19
20
only recovers valid inputs. For example, if we had a feature that described a person's age, we
@@ -22,13 +23,18 @@ def feature_domain(
22
23
23
24
Parameters
24
25
----------
25
- values : Optional[Union[Tuple, List[str] ]]
26
+ values : Optional[Union[Tuple, List]]
26
27
The valid values of the feature. If `values` takes the form of:
27
28
28
29
* **A tuple of floats or integers:** The feature domain will be a continuous range from
29
30
``values[0]`` to ``values[1]``.
30
- * **A list of strings:** The feature domain will be categorical, where `values` contains
31
- all possible valid feature values.
31
+ * **A list of floats or integers:**: The feature domain will be a *numeric* categorical,
32
+ where `values` contains all possible valid feature values.
33
+ * **A list of strings:** The feature domain will be a *string* categorical, where `values`
34
+ contains all possible valid feature values.
35
+ * **A list of objects:** The feature domain will be an *object* categorical, where `values`
36
+ contains all possible valid feature values. These may present an issue if the objects
37
+ are not natively Java serializable.
32
38
33
39
Otherwise, the feature domain will be taken as `Empty`, which will mean it will be held
34
40
fixed during the counterfactual explanation.
@@ -43,12 +49,29 @@ def feature_domain(
43
49
if not values :
44
50
domain = EmptyFeatureDomain .create ()
45
51
else :
46
- if isinstance (values [0 ], (float , int )):
47
- domain = NumericalFeatureDomain .create (values [0 ], values [1 ])
48
- elif isinstance (values [0 ], str ):
49
- domain = CategoricalFeatureDomain .create (
50
- _jclass .JClass ("java.util.Arrays" ).asList (values )
52
+ if isinstance (values , tuple ):
53
+ assert isinstance (values [0 ], (float , int )) and isinstance (
54
+ values [1 ], (float , int )
55
+ )
56
+ assert len (values ) == 2 , (
57
+ "Tuples passed as domain values must only contain"
58
+ " two values that define the (minimum, maximum) of the domain"
51
59
)
60
+ domain = NumericalFeatureDomain .create (values [0 ], values [1 ])
61
+
62
+ elif isinstance (values , list ):
63
+ java_array = _jclass .JClass ("java.util.Arrays" ).asList (values )
64
+ if isinstance (values [0 ], bool ) and isinstance (values [1 ], bool ):
65
+ domain = ObjectFeatureDomain .create (java_array )
66
+ elif isinstance (values [0 ], (float , int )) and isinstance (
67
+ values [1 ], (float , int )
68
+ ):
69
+ domain = CategoricalNumericalFeatureDomain .create (java_array )
70
+ elif isinstance (values [0 ], str ):
71
+ domain = CategoricalFeatureDomain .create (java_array )
72
+ else :
73
+ domain = ObjectFeatureDomain .create (java_array )
74
+
52
75
else :
53
76
domain = EmptyFeatureDomain .create ()
54
77
return domain
0 commit comments