Skip to content

Commit f6366c4

Browse files
authored
FAI-911: Improve API to declare selections in Python fairness (#132)
* Add data_conversion support for favorable types * Fix formatting
1 parent 583bda5 commit f6366c4

File tree

1 file changed

+22
-11
lines changed

1 file changed

+22
-11
lines changed

src/trustyai/metrics/fairness/group.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,20 @@
66
from jpype import JInt
77
from org.kie.trustyai.explainability.metrics import FairnessMetrics
88

9-
from trustyai.model import Output, Value, PredictionProvider, Model
10-
from trustyai.utils.data_conversions import pandas_to_trusty
9+
from trustyai.model import Value, PredictionProvider, Model
10+
from trustyai.utils.data_conversions import (
11+
pandas_to_trusty,
12+
OneOutputUnionType,
13+
one_output_convert,
14+
)
1115

1216
ColumSelector = Union[List[int], List[str]]
1317

1418

1519
def _column_selector_to_index(columns: ColumSelector, dataframe: pd.DataFrame):
20+
if len(columns) == 0:
21+
raise ValueError("Must specify at least one column")
22+
1623
if isinstance(columns[0], str): # passing column
1724
columns = dataframe.columns.get_indexer(columns)
1825
indices = [JInt(c) for c in columns] # Java casting
@@ -22,14 +29,15 @@ def _column_selector_to_index(columns: ColumSelector, dataframe: pd.DataFrame):
2229
def statistical_parity_difference(
    privileged: pd.DataFrame,
    unprivileged: pd.DataFrame,
    favorable: OneOutputUnionType,
    outputs: Optional[List[int]] = None,
) -> float:
    """Calculate Statistical Parity Difference between privileged and unprivileged dataframes.

    :param privileged: observations belonging to the privileged group
    :param unprivileged: observations belonging to the unprivileged group
    :param favorable: the favorable output selection; accepts any of the forms
        supported by :func:`one_output_convert` (see ``OneOutputUnionType``)
    :param outputs: optional indices of the output columns within the
        dataframes, forwarded to :func:`pandas_to_trusty`
    :return: the statistical parity difference as computed by the underlying
        Java ``FairnessMetrics.groupStatisticalParityDifference``
    """
    # Normalize the user-supplied favorable selection into a prediction
    # object; its `.outputs` attribute is what the Java metric expects.
    favorable_prediction_object = one_output_convert(favorable)
    return FairnessMetrics.groupStatisticalParityDifference(
        pandas_to_trusty(privileged, outputs),
        pandas_to_trusty(unprivileged, outputs),
        favorable_prediction_object.outputs,
    )
3442

3543

@@ -39,31 +47,33 @@ def statistical_parity_difference_model(
3947
model: Union[PredictionProvider, Model],
4048
privilege_columns: ColumSelector,
4149
privilege_values: List[Any],
42-
favorable: List[Output],
50+
favorable: OneOutputUnionType,
4351
) -> float:
4452
"""Calculate Statistical Parity Difference using a samples dataframe and a model"""
53+
favorable_prediction_object = one_output_convert(favorable)
4554
_privilege_values = [Value(v) for v in privilege_values]
4655
_jsamples = pandas_to_trusty(samples, no_outputs=True)
4756
return FairnessMetrics.groupStatisticalParityDifference(
4857
_jsamples,
4958
model,
5059
_column_selector_to_index(privilege_columns, samples),
5160
_privilege_values,
52-
favorable,
61+
favorable_prediction_object.outputs,
5362
)
5463

5564

5665
def disparate_impact_ratio(
    privileged: pd.DataFrame,
    unprivileged: pd.DataFrame,
    favorable: OneOutputUnionType,
    outputs: Optional[List[int]] = None,
) -> float:
    """Calculate Disparate Impact Ratio between privileged and unprivileged dataframes.

    :param privileged: observations belonging to the privileged group
    :param unprivileged: observations belonging to the unprivileged group
    :param favorable: the favorable output selection; accepts any of the forms
        supported by :func:`one_output_convert` (see ``OneOutputUnionType``)
    :param outputs: optional indices of the output columns within the
        dataframes, forwarded to :func:`pandas_to_trusty`
    :return: the disparate impact ratio as computed by the underlying Java
        ``FairnessMetrics.groupDisparateImpactRatio``
    """
    # Normalize the user-supplied favorable selection into a prediction
    # object; its `.outputs` attribute is what the Java metric expects.
    favorable_prediction_object = one_output_convert(favorable)
    return FairnessMetrics.groupDisparateImpactRatio(
        pandas_to_trusty(privileged, outputs),
        pandas_to_trusty(unprivileged, outputs),
        favorable_prediction_object.outputs,
    )
6878

6979

@@ -73,17 +83,18 @@ def disparate_impact_ratio_model(
7383
model: Union[PredictionProvider, Model],
7484
privilege_columns: ColumSelector,
7585
privilege_values: List[Any],
76-
favorable: List[Output],
86+
favorable: OneOutputUnionType,
7787
) -> float:
7888
"""Calculate Disparate Impact Ratio using a samples dataframe and a model"""
89+
favorable_prediction_object = one_output_convert(favorable)
7990
_privilege_values = [Value(v) for v in privilege_values]
8091
_jsamples = pandas_to_trusty(samples, no_outputs=True)
8192
return FairnessMetrics.groupDisparateImpactRatio(
8293
_jsamples,
8394
model,
8495
_column_selector_to_index(privilege_columns, samples),
8596
_privilege_values,
86-
favorable,
97+
favorable_prediction_object.outputs,
8798
)
8899

89100

@@ -92,7 +103,7 @@ def average_odds_difference(
92103
test: pd.DataFrame,
93104
truth: pd.DataFrame,
94105
privilege_columns: ColumSelector,
95-
privilege_values: List[Any],
106+
privilege_values: OneOutputUnionType,
96107
positive_class: List[Any],
97108
outputs: Optional[List[int]] = None,
98109
) -> float:

0 commit comments

Comments
 (0)