Commit aea7a68

Added a few more tests for assess_bias and also some slight optimizations
1 parent 3105ac6 commit aea7a68

File tree

2 files changed: +231 -37 lines

src/sasctl/pzmm/write_json_files.py
tests/unit/test_write_json_files.py

src/sasctl/pzmm/write_json_files.py

Lines changed: 83 additions & 37 deletions
@@ -21,7 +21,7 @@
 # Package Imports
 from sasctl.pzmm.write_score_code import ScoreCode as sc
 from ..core import current_session
-from ..utils.decorators import deprecated
+from ..utils.decorators import deprecated, experimental
 from ..utils.misc import check_if_jupyter
 
 try:
@@ -945,6 +945,20 @@ def assess_model_bias(
     def format_max_differences(
         maxdiff_dfs: List[DataFrame], datarole: str = "TEST"
     ) -> DataFrame:
+        """
+        Converts a list of max differences DataFrames into a single DataFrame.
+
+        Parameters
+        ----------
+        maxdiff_dfs : List[DataFrame]
+            A list of max_differences DataFrames returned by CAS.
+        datarole : str, optional
+            The data being used to assess bias (e.g. 'TEST', 'VALIDATION'). The default value is 'TEST'.
+
+        Returns
+        -------
+        DataFrame
+            A single DataFrame containing all max differences data.
+        """
         maxdiff_df = pd.concat(maxdiff_dfs)
         maxdiff_df = maxdiff_df.rename(
             columns={"Value": "maxdiff", "Base": "BASE", "Compare": "COMPARE"}
@@ -965,6 +979,28 @@ def format_group_metrics(
         pred_values: str = None,
         datarole: str = "TEST",
     ) -> DataFrame:
+        """
+        Converts a list of group metrics DataFrames into a single DataFrame.
+
+        Parameters
+        ----------
+        groupmetrics_dfs : List[DataFrame]
+            List of group metrics DataFrames generated by the CAS action.
+        prob_values : list of strings, required for classification problems, otherwise not used
+            A list of variable names containing the predicted probability values in the score table. The first
+            element should represent the predicted probability of the target class. Required for classification
+            problems. The default value is None.
+        pred_values : str, required for regression problems, otherwise not used
+            Variable name containing the predicted values in score_table. The variable name must follow SAS naming
+            conventions (no spaces and the name cannot begin with a number or symbol). Required for regression
+            problems. The default value is None.
+        datarole : str, optional
+            The data being used to assess bias (e.g. 'TEST', 'VALIDATION'). The default value is 'TEST'.
+
+        Returns
+        -------
+        DataFrame
+            A single DataFrame containing formatted data for group metrics.
+        """
         # adding group metrics dataframes and adding values/ formatting
         groupmetrics_df = pd.concat(groupmetrics_dfs)
         groupmetrics_df = groupmetrics_df.rename(
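
Note: the docstring's classification-versus-regression switch in one short sketch, modeled on the new unit tests below. The column names "P_1", "P_0", and "Pred" are illustrative, not from the commit.

```python
import pandas as pd

from sasctl.pzmm.write_json_files import JSONFiles as jf

# One group-metrics frame per level of the sensitive variable.
gm = pd.DataFrame({
    "Group": ["A"], "N": [10], "MISCEVENT": [0.1], "MISCEVENTKS": [0.2],
    "cutoffKS": [0.5], "PREDICTED": [1], "maxKS": [100],
    "P_1": [0.7], "P_0": [0.3],  # hypothetical predicted-probability columns
})

# Classification: pass the probability column names via prob_values.
clf = jf.format_group_metrics([gm], prob_values=["P_1", "P_0"])

# Regression: pass a single predicted-value column via pred_values instead.
gm_reg = gm.drop(columns=["P_1", "P_0"]).assign(Pred=[0.4])
reg = jf.format_group_metrics([gm_reg], pred_values="Pred")
```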
@@ -1006,6 +1042,7 @@ def format_group_metrics(
         return groupmetrics_df
 
     # TODO: Add doc_string and unit tests
+    @experimental
     @classmethod
     def bias_dataframes_to_json(
         cls,
@@ -1018,6 +1055,42 @@ def bias_dataframes_to_json(
         pred_values: str = None,
         json_path: Union[str, Path, None] = None,
     ):
+        """
+        Properly formats data from the FairAITools CAS action set into JSON-readable formats.
+
+        Parameters
+        ----------
+        groupmetrics : DataFrame
+            A DataFrame containing the group metrics data.
+        maxdifference : DataFrame
+            A DataFrame containing the max difference data.
+        n_sensitivevariables : int
+            The total number of sensitive variables.
+        actual_values : str
+            Variable name containing the actual values in score_table. The variable name must follow SAS naming
+            conventions (no spaces and the name cannot begin with a number or symbol).
+        prob_values : list of strings, required for classification problems, otherwise not used
+            A list of variable names containing the predicted probability values in the score table. The first
+            element should represent the predicted probability of the target class. Required for classification
+            problems. The default value is None.
+        levels : list of strings, required for classification problems, otherwise not used
+            List of classes of a nominal target in the order they were passed in prob_values. Levels must be passed
+            as strings. The default value is None.
+        pred_values : str, required for regression problems, otherwise not used
+            Variable name containing the predicted values in score_table. The variable name must follow SAS naming
+            conventions (no spaces and the name cannot begin with a number or symbol). Required for regression
+            problems. The default value is None.
+        json_path : str or Path, optional
+            Location for the output JSON files. If a path is passed, the JSON files are written to that directory
+            and the function returns None. Otherwise, the function returns the JSON strings in a dictionary
+            (dict["maxDifferences.json"] and dict["groupMetrics.json"]). The default value is None.
+
+        Returns
+        -------
+        dict
+            Dictionary whose keys are the output file names and whose values are the corresponding JSON strings.
+        """
         folder = "reg_jsons" if prob_values is None else "clf_jsons"
 
         dfs = (maxdifference, groupmetrics)
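
Note: a hedged end-to-end sketch of the call, chained off the two formatting helpers above. Column names such as "BAD" and "P_BAD1" are illustrative, and whether these minimal frames satisfy every internal expectation of the method is untested; treat this as the call shape described by the docstring, not a verified run.

```python
import pandas as pd

from sasctl.pzmm.write_json_files import JSONFiles as jf

md_df = jf.format_max_differences(
    [pd.DataFrame({"Value": [0.1], "Base": ["A"], "Compare": ["C"]})]
)
gm_src = pd.DataFrame({
    "Group": ["A"], "N": [10], "MISCEVENT": [0.1], "MISCEVENTKS": [0.2],
    "cutoffKS": [0.5], "PREDICTED": [1], "maxKS": [100],
    "P_BAD1": [0.7], "P_BAD0": [0.3],  # hypothetical probability columns
})
gm_df = jf.format_group_metrics([gm_src], prob_values=["P_BAD1", "P_BAD0"])

out = jf.bias_dataframes_to_json(
    groupmetrics=gm_df,
    maxdifference=md_df,
    n_sensitivevariables=1,
    actual_values="BAD",
    prob_values=["P_BAD1", "P_BAD0"],
    levels=["1", "0"],
    json_path=None,  # None: return the JSON strings in a dict
)
# Per the docstring: out["maxDifferences.json"] and out["groupMetrics.json"].
```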
@@ -1061,20 +1134,20 @@
                 "type": "num",
                 "values": [prob_label],
             }
-            json_dict[1]["parameterMap"] = cls.add_dict_key(
-                dict=json_dict[1]["parameterMap"],
-                pos=i + 3,
-                new_key=prob_label,
-                new_value=paramdict,
-            )
+            json_dict[1]["parameterMap"][prob_label] = paramdict
+            # cls.add_dict_key(
+            #     dict=json_dict[1]["parameterMap"],
+            #     pos=i + 3,
+            #     new_key=prob_label,
+            #     new_value=paramdict,
+            # )
 
         else:
             json_dict[1]["parameterMap"]["predict"]["label"] = pred_values
             json_dict[1]["parameterMap"]["predict"]["parameter"] = pred_values
             json_dict[1]["parameterMap"]["predict"]["values"] = [pred_values]
-            json_dict[1]["parameterMap"] = cls.rename_dict_key(
-                json_dict[1]["parameterMap"], pred_values, "predict"
-            )
+            json_dict[1]["parameterMap"][pred_values] = json_dict[1]["parameterMap"]["predict"]
+            del json_dict[1]["parameterMap"]["predict"]
 
         if json_path:
             for i, name in enumerate([MAXDIFFERENCES, GROUPMETRICS]):
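
Note: the simplification above leans on the fact that dicts preserve insertion order since Python 3.7, so plain assignment appends a key and assign-then-delete moves a key to the end. The removed helpers differed only in where the key landed inside parameterMap, which presumably affects only the order of entries in the serialized JSON, not their content. A standalone illustration (plain dicts, names not taken from the module):

```python
# Standalone sketch of the two replacements; no sasctl code involved.
param_map = {"LEVEL": {}, "BASE": {}, "predict": {"label": "old"}}

# add_dict_key(pos=i + 3, ...) inserted at a fixed position; direct
# assignment appends at the end instead, yielding the same mapping.
param_map["P_1"] = {"type": "num"}

# rename_dict_key kept the renamed key in place; assign-then-delete
# moves it to the end, again with the same key/value content.
param_map["EM_PREDICTION"] = param_map["predict"]
del param_map["predict"]

assert param_map["EM_PREDICTION"] == {"label": "old"}
assert list(param_map) == ["LEVEL", "BASE", "P_1", "EM_PREDICTION"]
```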
@@ -1091,34 +1164,7 @@
             GROUPMETRICS: json.dumps(json_dict[1], indent=4, cls=NpEncoder),
         }
 
-    # TODO: Add doc_string and unit tests
-    @staticmethod
-    def add_dict_key(
-        dict: dict, pos: int, new_key: Union[str, int, float, bool], new_value
-    ):
-        result = {}
-        for i, k in enumerate(dict.keys()):
-            if i == pos:
-                result[new_key] = new_value
-                result[k] = dict[k]
-            else:
-                result[k] = dict[k]
-        return result
 
-    # TODO: Add doc_string and unit tests
-    @staticmethod
-    def rename_dict_key(
-        dict: dict,
-        new_key: Union[str, int, float, bool],
-        old_key: Union[str, int, float, bool],
-    ) -> dict:
-        result = {}
-        for k, v in dict.items():
-            if k == old_key:
-                result[new_key] = v
-            else:
-                result.update({k: v})
-        return result
 
     @classmethod
     def calculate_model_statistics(

tests/unit/test_write_json_files.py

Lines changed: 148 additions & 0 deletions
@@ -22,6 +22,7 @@
 
 import sasctl.pzmm as pzmm
 from sasctl.pzmm.write_json_files import JSONFiles as jf
+from sasctl.core import Session
 
 # Example input variable list from hmeq dataset (generated by mlflow_model.py)
 input_dict = [
@@ -644,3 +645,150 @@ def test_create_requirements_json(change_dir):
     unittest.TestCase().assertCountEqual(
         json.loads(json_dict["requirements.json"]), expected
     )
+
+
+class TestAssessBiasHelpers(unittest.TestCase):
+    md_1 = pd.DataFrame({"Value": [0], "Base": ["A"], "Compare": ["C"]})
+    md_2 = pd.DataFrame({"Value": [1], "Base": ["B"], "Compare": ["C"]})
+    gm_1 = pd.DataFrame({
+        "Group": ["A"],
+        "N": [0],
+        "MISCEVENT": ["E"],
+        "MISCEVENTKS": ["F"],
+        "cutoffKS": [0.5],
+        "PREDICTED": [1],
+        "maxKS": [100],
+    })
+    gm_2 = pd.DataFrame({
+        "Group": ["B"],
+        "N": [1],
+        "MISCEVENT": ["G"],
+        "MISCEVENTKS": ["H"],
+        "cutoffKS": [0.2],
+        "PREDICTED": [0],
+        "maxKS": [500],
+    })
+
+    def test_max_differences(self):
+        md_2_copy = self.md_2.copy()
+        md_2_copy = md_2_copy.set_index(pd.Index([1]))
+        dfs = [self.md_1.copy(), md_2_copy]
+        datarole = "role"
+        return_table = jf.format_max_differences(dfs, datarole)
+        pd.testing.assert_frame_equal(
+            return_table,
+            pd.DataFrame({
+                "BASE": ["A", "B"],
+                "COMPARE": ["C", "C"],
+                "VLABEL": ["", ""],
+                "_DATAROLE_": ["role", "role"],
+                "maxdiff": [0, 1],
+            }),
+        )
+
+        return_table = jf.format_max_differences(dfs)
+        pd.testing.assert_frame_equal(
+            return_table,
+            pd.DataFrame({
+                "BASE": ["A", "B"],
+                "COMPARE": ["C", "C"],
+                "VLABEL": ["", ""],
+                "_DATAROLE_": ["TEST", "TEST"],
+                "maxdiff": [0, 1],
+            }),
+        )
+
+    def test_group_metrics(self):
+        gm_2_copy = self.gm_2.copy()
+        gm_2_copy = gm_2_copy.set_index(pd.Index([1]))
+        dfs = [self.gm_1.copy(), gm_2_copy.copy()]
+        prob_values = ["VarA", "VarB"]
+        for i in range(len(dfs)):
+            dfs[i][prob_values[0]] = [i]
+            dfs[i][prob_values[1]] = [i + 2]
+
+        gm = jf.format_group_metrics(dfs, prob_values)
+
+        pd.testing.assert_frame_equal(
+            gm,
+            pd.DataFrame({
+                "LEVEL": ["A", "B"],
+                "VLABEL": ["", ""],
+                "VarA": [0, 1],
+                "VarB": [2, 3],
+                "_DATAROLE_": ["TEST", "TEST"],
+                "_avgyhat_": [1, 0],
+                "_ks_": [100, 500],
+                "_kscut_": [0.5, 0.2],
+                "_misccutoff_": ["E", "G"],
+                "_miscks_": ["F", "H"],
+                "_nobs_": [0, 1],
+            }),
+        )
+
+        dfs_1 = [self.gm_1.copy(), gm_2_copy.copy()]
+        prob_values = ["VarA", "VarB"]
+        for i in range(len(dfs_1)):
+            dfs_1[i][prob_values[0]] = [i]
+            dfs_1[i][prob_values[1]] = [i + 2]
+
+        gm_1 = jf.format_group_metrics(dfs_1, prob_values, datarole="NEW")
+
+        pd.testing.assert_frame_equal(
+            gm_1,
+            pd.DataFrame({
+                "LEVEL": ["A", "B"],
+                "VLABEL": ["", ""],
+                "VarA": [0, 1],
+                "VarB": [2, 3],
+                "_DATAROLE_": ["NEW", "NEW"],
+                "_avgyhat_": [1, 0],
+                "_ks_": [100, 500],
+                "_kscut_": [0.5, 0.2],
+                "_misccutoff_": ["E", "G"],
+                "_miscks_": ["F", "H"],
+                "_nobs_": [0, 1],
+            }),
+        )
+
+        dfs_2 = [self.gm_1.copy(), gm_2_copy.copy()]
+        pred_value = "Pred"
+        for i in range(len(dfs_2)):
+            dfs_2[i][pred_value] = [i]
+
+        gm_2 = jf.format_group_metrics(dfs_2, pred_values=pred_value, datarole="NEW")
+
+        pd.testing.assert_frame_equal(
+            gm_2,
+            pd.DataFrame({
+                "LEVEL": ["A", "B"],
+                "Pred": [0, 1],
+                "VLABEL": ["", ""],
+                "_DATAROLE_": ["NEW", "NEW"],
+                "_avgyhat_": [1, 0],
+                "_ks_": [100, 500],
+                "_kscut_": [0.5, 0.2],
+                "_misccutoff_": ["E", "G"],
+                "_miscks_": ["F", "H"],
+                "_nobs_": [0, 1],
+            }),
+        )
+
+
+class TestAssessBias(unittest.TestCase):
+    def test_errors(self):
+        with unittest.mock.patch("sasctl.core.Session._get_authorization_token"):
+            with unittest.mock.patch("sasctl.core.Session.as_swat") as swat:
+                with Session("host", "username", "password"):
+                    score_table = pd.DataFrame({"1nvalid": ["no."]})
+                    sensitive_values = "s"
+                    actual_values = "a"
+                    with pytest.raises(SyntaxError):
+                        jf.assess_model_bias(score_table, sensitive_values, actual_values)
+
+                    score_table = pd.DataFrame({"valid": ["yes"]})
+                    with pytest.raises(ValueError):
+                        jf.assess_model_bias(score_table, sensitive_values, actual_values)
+
+                    swat.side_effect = ImportError("oops")
+                    with pytest.raises(RuntimeError):
+                        jf.assess_model_bias(score_table, sensitive_values, actual_values)
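
Note: to run just the new cases locally, standard pytest selection by class name works; shown here through pytest.main, with the CLI equivalent being `pytest tests/unit/test_write_json_files.py -k AssessBias`.

```python
import pytest

# Selects both TestAssessBiasHelpers and TestAssessBias by substring match.
pytest.main(["tests/unit/test_write_json_files.py", "-k", "AssessBias"])
```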
