
Commit 3913c30

Added some tests for model cards + fixed some model card errors

1 parent c20c162

File tree

2 files changed (+206, -28 lines)

src/sasctl/pzmm/write_json_files.py

Lines changed: 27 additions & 28 deletions
@@ -2374,7 +2374,7 @@ def generate_outcome_average(
     ):
         """
         Generates the outcome average of the training data. For Interval targets, the event average
-        is generated. For Classification targets, the event average is returned.
+        is generated. For Classification targets, the event percentage is returned.
 
         Parameters
         ----------
@@ -2395,17 +2395,23 @@ def generate_outcome_average(
         dict
             Returns a dictionary with a key-value pair that represents the outcome average.
         """
+        import numbers
         output_var = train_data.drop(input_variables, axis=1)
         if target_type == "classification":
             value_counts = output_var[output_var.columns[0]].value_counts()
             return {'eventPercentage': value_counts[target_value]/sum(value_counts)}
         elif target_type == "interval":
-            return {'eventAverage': sum(value_counts[value_counts.columns[0]]) / len(value_counts)}
+            if not isinstance(output_var[output_var.columns[0]].iloc[0], numbers.Number):
+                raise ValueError("Detected output column is not numeric. Please ensure that " +
+                                 "the correct output column is being passed, and that no extra columns " +
+                                 "are in front of the output column. This function assumes that the first " +
+                                 "non-input column is the output column.")
+            return {'eventAverage': sum(output_var[output_var.columns[0]]) / len(output_var)}
 
     @staticmethod
     def get_selection_statistic_value(
-        model_files,
-        selection_statistic
+        model_files: Union[str, Path, dict],
+        selection_statistic: str = "_GINI_"
     ):
         """
         Finds the value of the chosen selection statistic in dmcas_fitstat.json, which should have been
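
For illustration, here is the patched interval logic as a minimal standalone sketch. The outcome_average free function and the sample frame are assumptions for the example; in sasctl this logic is a static method on pzmm's JSONFiles class, not a free function.

import numbers

import pandas as pd


def outcome_average(train_data, input_variables, target_type, target_value=None):
    # Drop the input columns; the first remaining column is assumed to be the target.
    output_var = train_data.drop(input_variables, axis=1)
    target = output_var[output_var.columns[0]]
    if target_type == "classification":
        counts = target.value_counts()
        # Fraction of rows whose target equals the event level.
        return {"eventPercentage": counts[target_value] / sum(counts)}
    elif target_type == "interval":
        if not isinstance(target.iloc[0], numbers.Number):
            raise ValueError("Detected output column is not numeric.")
        # Arithmetic mean of the numeric target column.
        return {"eventAverage": sum(target) / len(target)}


df = pd.DataFrame({"input": [3, 2, 1], "output": [1, 2, 3]})
print(outcome_average(df, ["input"], "interval"))  # {'eventAverage': 2.0}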
@@ -2493,10 +2499,11 @@ def update_model_properties(
             )
         with open(Path(model_files) / PROP, 'r+') as properties_json:
             model_properties = json.load(properties_json)
-            if not isinstance(update_dict[key], str):
-                model_files[PROP][key] = str(round(update_dict[key], 14))
-            else:
-                model_files[PROP][key] = update_dict[key]
+            for key in update_dict:
+                if not isinstance(update_dict[key], str):
+                    model_properties[key] = str(round(update_dict[key], 14))
+                else:
+                    model_properties[key] = update_dict[key]
             properties_json.seek(0)
             properties_json.write(json.dumps(model_properties, indent=4, cls=NpEncoder))
             properties_json.truncate()
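
The replaced block had two bugs: `key` was never bound by any loop, and the writes targeted `model_files[PROP][key]` instead of the freshly loaded `model_properties`. Below is a minimal sketch of the corrected read-modify-rewrite pattern; the helper name and hard-coded file name are hypothetical, for illustration only.

import json
from pathlib import Path


def update_properties_file(model_dir, update_dict):
    # "r+" lets a single handle both load the JSON and rewrite it in place.
    with open(Path(model_dir) / "ModelProperties.json", "r+") as properties_json:
        model_properties = json.load(properties_json)
        for key in update_dict:
            if not isinstance(update_dict[key], str):
                # Numerics are rounded to 14 decimal places and stored as strings.
                model_properties[key] = str(round(update_dict[key], 14))
            else:
                model_properties[key] = update_dict[key]
        properties_json.seek(0)     # rewind before overwriting
        properties_json.write(json.dumps(model_properties, indent=4))
        properties_json.truncate()  # discard stale bytes if the new JSON is shorter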
@@ -2537,14 +2544,6 @@ def generate_variable_importance(
         caslib: str, optional
             The caslib the training data will be stored on. The default value is "Public"
         """
-        try:
-            sess = current_session()
-            conn = sess.as_swat()
-        except ImportError:
-            raise RuntimeError(
-                "The `swat` package is required to generate fit statistics, ROC, and "
-                "Lift charts with the calculate_model_statistics function."
-            )
         # Remove target variable from training data by selecting only input variable columns
         x_train_data = train_data[interval_vars + class_vars]
         # Upload scored training data to run variable importance on
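
The deleted guard was where this spot obtained its CAS connection; generate_variable_importance still presumes an active sasctl session, with the connection presumably established elsewhere in the function or by the caller. A minimal sketch of how that connection is typically obtained in sasctl; the host and credentials are placeholders.

from sasctl import Session, current_session

# Placeholder host and credentials; any reachable SAS Viya deployment works.
with Session("example.sas.com", "username", "password"):
    conn = current_session().as_swat()  # raises if the optional `swat` package is absent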
@@ -2573,12 +2572,12 @@ def generate_variable_importance(
             "name": 'BIN',
             "inputs": [{"name": var} for var in interval_vars],
             "targets": [{"name": "Prediction"}],
-            "discretize":{
-                "method":method,
-                "arguments":{
-                    "minNBins":1,
-                    "maxNBins":8,
-                    "treeCrit":treeCrit,
+            "discretize": {
+                "method": method,
+                "arguments": {
+                    "minNBins": 1,
+                    "maxNBins": 8,
+                    "treeCrit": treeCrit,
                     "contingencyTblOpts":{"inputsMethod": 'BUCKET', "inputsNLevels": 100},
                     "overrides": {"minNObsInBin": 5, "binMissing": True, "noDataLowerUpperBound": True}
                 }
@@ -2589,12 +2588,12 @@ def generate_variable_importance(
             "name": 'BIN_NOM',
             "inputs": [{"name": var} for var in class_vars],
             "targets": [{"name": "Prediction"}],
-            "catTrans":{
-                "method":method,
-                "arguments":{
-                    "minNBins":1,
-                    "maxNBins":8,
-                    "treeCrit":treeCrit,
+            "catTrans": {
+                "method": method,
+                "arguments": {
+                    "minNBins": 1,
+                    "maxNBins": 8,
+                    "treeCrit": treeCrit,
                     "overrides": {"minNObsInBin": 5, "binMissing": True}
                 }
             }
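
The BIN and BIN_NOM dictionaries are CAS request packages for decision-tree binning of the interval and nominal inputs. Below is a sketch of how such a package might be submitted through the swat dataPreprocess action set; the connection details, variable list, method, criterion, and table name are all assumptions for illustration, not the exact call sasctl makes.

import swat

# Hypothetical CAS connection and binning settings.
conn = swat.CAS("cas-host.example.com", 5570)
conn.loadactionset("dataPreprocess")  # binning lives in the dataPreprocess action set

interval_vars = ["LOAN", "MORTDUE", "VALUE"]  # hypothetical interval inputs
request_package = {
    "name": "BIN",
    "inputs": [{"name": var} for var in interval_vars],
    "targets": [{"name": "Prediction"}],
    "discretize": {
        "method": "DTREE",  # hypothetical method value
        "arguments": {"minNBins": 1, "maxNBins": 8, "treeCrit": "GAIN"},
    },
}
result = conn.dataPreprocess.transform(
    table={"name": "scored_train_data"},  # hypothetical uploaded table
    requestPackages=[request_package],
)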

tests/unit/test_write_json_files.py

Lines changed: 179 additions & 0 deletions
@@ -16,11 +16,14 @@
 import warnings
 from pathlib import Path
 from unittest.mock import patch
+import math
 
 import numpy as np
 import pandas as pd
 import pytest
 from sklearn.model_selection import train_test_split
+from sklearn import datasets
+from sklearn.linear_model import LogisticRegression
 from sklearn.tree import DecisionTreeClassifier
 
 import sasctl.pzmm as pzmm
@@ -43,6 +46,37 @@
     {"name": "REASON_HomeImp", "type": "integer"},
 ]
 
+class BadModel:
+    attr = None
+
+@pytest.fixture
+def bad_model():
+    return BadModel()
+
+
+@pytest.fixture
+def train_data():
+    """Returns the Iris data set as (X, y)"""
+    raw = datasets.load_iris()
+    iris = pd.DataFrame(raw.data, columns=raw.feature_names)
+    iris = iris.join(pd.DataFrame(raw.target))
+    iris.columns = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth", "Species"]
+    iris["Species"] = iris["Species"].astype("category")
+    iris.Species.cat.categories = raw.target_names
+    return iris.iloc[:, 0:4], iris["Species"]
+
+
+@pytest.fixture
+def sklearn_model(train_data):
+    """Returns a simple Scikit-Learn model"""
+    X, y = train_data
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        model = LogisticRegression(
+            multi_class="multinomial", solver="lbfgs", max_iter=1000
+        )
+        model.fit(X, y)
+    return model
 
 @pytest.fixture
 def change_dir():
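
One caveat in the train_data fixture: assigning to iris.Species.cat.categories works on pandas releases contemporary with this commit, but that setter is deprecated and removed in pandas 2.0. A self-contained, forward-compatible equivalent using rename_categories:

from sklearn import datasets
import pandas as pd

raw = datasets.load_iris()
species = pd.Series(raw.target).astype("category")
# rename_categories replaces the removed setter for .cat.categories
species = species.cat.rename_categories(raw.target_names)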
@@ -849,3 +883,148 @@ def test_errors(self):
             jf.assess_model_bias(
                 score_table, sensitive_values, actual_values
             )
+
+
+class TestModelCardGeneration(unittest.TestCase):
+    def test_generate_outcome_average_interval(self):
+        df = pd.DataFrame({"input": [3, 2, 1], "output": [1, 2, 3]})
+        assert (
+            jf.generate_outcome_average(df, ["input"], "interval") ==
+            {'eventAverage': 2.0}
+        )
+
+    def test_generate_outcome_average_classification(self):
+        df = pd.DataFrame({"input": [3, 2], "output": [0, 1]})
+        event_percentage = jf.generate_outcome_average(df, ["input"], "classification", 1)
+        assert('eventPercentage' in event_percentage)
+
+    def test_generate_outcome_average_interval_non_numeric_output(self):
+        df = pd.DataFrame({"input": [3, 2, 1], "output": ["one", "two", "three"]})
+        with pytest.raises(ValueError):
+            jf.generate_outcome_average(df, ["input"], "interval")
+
+
+class TestGetSelectionStatisticValue(unittest.TestCase):
+    model_file_dict = {
+        "dmcas_fitstat.json": {
+            "data": [
+                {
+                    "dataMap": {
+                        "_GINI_": 1,
+                        "_C_": 2,
+                        "_TAU_": None,
+                        "_DataRole_": "TRAIN"
+                    }
+                }
+            ]
+        }
+    }
+    tmp_dir = tempfile.TemporaryDirectory()
+    with open(Path(tmp_dir.name) / "dmcas_fitstat.json", "w+") as f:
+        f.write(json.dumps(model_file_dict['dmcas_fitstat.json']))
+
+    def test_get_statistic_dict_default(self):
+        selection_statistic = jf.get_selection_statistic_value(self.model_file_dict)
+        assert(selection_statistic == 1)
+
+    def test_get_statistic_dict_custom(self):
+        selection_statistic = jf.get_selection_statistic_value(self.model_file_dict, "_C_")
+        assert(selection_statistic == 2)
+
+    def test_get_blank_statistic_dict(self):
+        with pytest.raises(RuntimeError):
+            jf.get_selection_statistic_value(self.model_file_dict, "_TAU_")
+
+    def test_get_statistics_path_default(self):
+        selection_statistic = jf.get_selection_statistic_value(Path(self.tmp_dir.name))
+        assert(selection_statistic == 1)
+
+    def test_get_statistics_path_custom(self):
+        selection_statistic = jf.get_selection_statistic_value(Path(self.tmp_dir.name), "_C_")
+        assert(selection_statistic == 2)
+
+    def test_get_blank_statistic_path(self):
+        with pytest.raises(RuntimeError):
+            jf.get_selection_statistic_value(Path(self.tmp_dir.name), "_TAU_")
+
+    def test_get_statistics_str_default(self):
+        selection_statistic = jf.get_selection_statistic_value(self.tmp_dir.name)
+        assert (selection_statistic == 1)
+
+    def test_get_statistics_str_custom(self):
+        selection_statistic = jf.get_selection_statistic_value(self.tmp_dir.name, "_C_")
+        assert (selection_statistic == 2)
+
+    def test_get_blank_statistic_str(self):
+        with pytest.raises(RuntimeError):
+            jf.get_selection_statistic_value(self.tmp_dir.name, "_TAU_")
+
+
+class TestUpdateModelProperties(unittest.TestCase):
+    def setUp(self):
+        self.model_file_dict = {
+            "ModelProperties.json":
+                {
+                    "example": "property"
+                }
+        }
+        self.tmp_dir = tempfile.TemporaryDirectory()
+        with open(Path(self.tmp_dir.name) / "ModelProperties.json", "w+") as f:
+            f.write(json.dumps(self.model_file_dict['ModelProperties.json']))
+
+    def tearDown(self):
+        self.tmp_dir.cleanup()
+
+    def test_update_model_properties_dict(self):
+        update_dict = {'new': 'arg', 'newer': 'thing'}
+        jf.update_model_properties(self.model_file_dict, update_dict)
+        assert(self.model_file_dict['ModelProperties.json']['example'] == 'property')
+        assert(self.model_file_dict['ModelProperties.json']['new'] == 'arg')
+        assert(self.model_file_dict['ModelProperties.json']['newer'] == 'thing')
+
+    def test_update_model_properties_dict_overwrite(self):
+        update_dict = {'new': 'arg', 'example': 'thing'}
+        jf.update_model_properties(self.model_file_dict, update_dict)
+        assert (self.model_file_dict['ModelProperties.json']['example'] == 'thing')
+        assert (self.model_file_dict['ModelProperties.json']['new'] == 'arg')
+
+    def test_update_model_properties_dict_number(self):
+        update_dict = {"number": 1}
+        jf.update_model_properties(self.model_file_dict, update_dict)
+        assert (self.model_file_dict['ModelProperties.json']['number'] == '1')
+
+    def test_update_model_properties_dict_round_number(self):
+        update_dict = {'number': 0.123456789012345}
+        jf.update_model_properties(self.model_file_dict, update_dict)
+        assert (self.model_file_dict['ModelProperties.json']['number'] == '0.12345678901234')
+
+    def test_update_model_properties_str(self):
+        update_dict = {'new': 'arg', 'newer': 'thing'}
+        jf.update_model_properties(self.tmp_dir.name, update_dict)
+        with open(Path(self.tmp_dir.name) / 'ModelProperties.json', 'r') as f:
+            model_properties = json.load(f)
+        assert(model_properties['example'] == 'property')
+        assert(model_properties['new'] == 'arg')
+        assert(model_properties['newer'] == 'thing')
+
+    def test_update_model_properties_str_overwrite(self):
+        update_dict = {'new': 'arg', 'example': 'thing'}
+        jf.update_model_properties(self.tmp_dir.name, update_dict)
+        with open(Path(self.tmp_dir.name) / 'ModelProperties.json', 'r') as f:
+            model_properties = json.load(f)
+        assert (model_properties['example'] == 'thing')
+        assert (model_properties['new'] == 'arg')
+
+    def test_update_model_properties_str_number(self):
+        update_dict = {"number": 1}
+        jf.update_model_properties(self.tmp_dir.name, update_dict)
+        with open(Path(self.tmp_dir.name) / 'ModelProperties.json', 'r') as f:
+            model_properties = json.load(f)
+        assert (model_properties['number'] == '1')
+
+    def test_update_model_properties_str_round_number(self):
+        update_dict = {'number': 0.123456789012345}
+        jf.update_model_properties(self.tmp_dir.name, update_dict)
+        with open(Path(self.tmp_dir.name) / 'ModelProperties.json', 'r') as f:
+            model_properties = json.load(f)
+        assert (model_properties['number'] == '0.12345678901234')
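
The round-number assertions pin down how update_model_properties stores numerics: non-string values are rounded to 14 decimal places and stringified, so the 15-digit input above loses its final digit. A two-line check of that behavior:

value = 0.123456789012345
print(str(round(value, 14)))  # prints 0.12345678901234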
