16
16
import warnings
17
17
from pathlib import Path
18
18
from unittest .mock import patch
19
+ import math
19
20
20
21
import numpy as np
21
22
import pandas as pd
22
23
import pytest
23
24
from sklearn .model_selection import train_test_split
25
+ from sklearn import datasets
26
+ from sklearn .linear_model import LogisticRegression
24
27
from sklearn .tree import DecisionTreeClassifier
25
28
26
29
import sasctl .pzmm as pzmm
43
46
{"name" : "REASON_HomeImp" , "type" : "integer" },
44
47
]
45
48
49
class BadModel:
    """Bare placeholder class whose only member is ``attr``, set to ``None``."""

    attr = None
51
+
52
@pytest.fixture
def bad_model():
    """Provide a newly constructed ``BadModel`` instance."""
    instance = BadModel()
    return instance
55
+
56
+
57
@pytest.fixture
def train_data():
    """Return the Iris data set as ``(X, y)``.

    Returns
    -------
    tuple
        ``X`` is a DataFrame holding the four measurement columns
        (SepalLength, SepalWidth, PetalLength, PetalWidth) and ``y`` is the
        categorical ``Species`` Series whose categories are the target names
        reported by ``datasets.load_iris()``.
    """
    raw = datasets.load_iris()
    iris = pd.DataFrame(raw.data, columns=raw.feature_names)
    iris = iris.join(pd.DataFrame(raw.target))
    iris.columns = ["SepalLength", "SepalWidth", "PetalLength", "PetalWidth", "Species"]
    iris["Species"] = iris["Species"].astype("category")
    # Assigning to ``Series.cat.categories`` was deprecated in pandas 1.5 and
    # removed in 2.0; ``rename_categories`` relabels the integer codes with
    # the species names on every pandas version.
    iris["Species"] = iris["Species"].cat.rename_categories(list(raw.target_names))
    return iris.iloc[:, 0:4], iris["Species"]
67
+
68
+
69
@pytest.fixture
def sklearn_model(train_data):
    """Fit a ``LogisticRegression`` on the ``train_data`` fixture and return it.

    Both construction and fitting happen with all warnings suppressed.
    """
    features, target = train_data
    settings = {"multi_class": "multinomial", "solver": "lbfgs", "max_iter": 1000}
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        estimator = LogisticRegression(**settings)
        estimator.fit(features, target)
    return estimator
46
80
47
81
@pytest .fixture
48
82
def change_dir ():
@@ -849,3 +883,148 @@ def test_errors(self):
849
883
jf .assess_model_bias (
850
884
score_table , sensitive_values , actual_values
851
885
)
886
+
887
+
888
class TestModelCardGeneration(unittest.TestCase):
    """Exercises ``jf.generate_outcome_average`` on small in-memory frames."""

    def test_generate_outcome_average_interval(self):
        """An interval target yields an ``eventAverage`` of 2.0 for outputs 1, 2, 3."""
        frame = pd.DataFrame({"input": [3, 2, 1], "output": [1, 2, 3]})
        result = jf.generate_outcome_average(frame, ["input"], "interval")
        assert result == {"eventAverage": 2.0}

    def test_generate_outcome_average_classification(self):
        """A classification target yields a result containing ``eventPercentage``."""
        frame = pd.DataFrame({"input": [3, 2], "output": [0, 1]})
        result = jf.generate_outcome_average(frame, ["input"], "classification", 1)
        assert "eventPercentage" in result

    def test_generate_outcome_average_interval_non_numeric_output(self):
        """An interval target with string outputs raises ``ValueError``."""
        frame = pd.DataFrame(
            {"input": [3, 2, 1], "output": ["one", "two", "three"]}
        )
        with pytest.raises(ValueError):
            jf.generate_outcome_average(frame, ["input"], "interval")
905
+
906
+
907
class TestGetSelectionStatisticValue(unittest.TestCase):
    """Tests for ``jf.get_selection_statistic_value``.

    The same fit-statistic payload is supplied three ways: as an in-memory
    dict, as a ``Path`` to a directory containing ``dmcas_fitstat.json``, and
    as the plain string form of that directory.
    """

    # Shared, read-only fixture data; no test in this class mutates it.
    model_file_dict = {
        "dmcas_fitstat.json": {
            "data": [
                {
                    "dataMap": {
                        "_GINI_": 1,
                        "_C_": 2,
                        "_TAU_": None,
                        "_DataRole_": "TRAIN",
                    }
                }
            ]
        }
    }

    @classmethod
    def setUpClass(cls):
        """Create the temporary directory and write the fit-statistic file.

        Done here rather than in the class body so that nothing touches the
        file system at import/collection time and the directory can be
        removed deterministically in ``tearDownClass``.
        """
        cls.tmp_dir = tempfile.TemporaryDirectory()
        with open(Path(cls.tmp_dir.name) / "dmcas_fitstat.json", "w+") as f:
            f.write(json.dumps(cls.model_file_dict["dmcas_fitstat.json"]))

    @classmethod
    def tearDownClass(cls):
        """Remove the temporary directory created in ``setUpClass``."""
        cls.tmp_dir.cleanup()

    def test_get_statistic_dict_default(self):
        """Dict input, no statistic named: expects 1 (the ``_GINI_`` entry's value)."""
        selection_statistic = jf.get_selection_statistic_value(self.model_file_dict)
        assert selection_statistic == 1

    def test_get_statistic_dict_custom(self):
        """Dict input with ``_C_`` requested: expects 2."""
        selection_statistic = jf.get_selection_statistic_value(
            self.model_file_dict, "_C_"
        )
        assert selection_statistic == 2

    def test_get_blank_statistic_dict(self):
        """Dict input with ``_TAU_`` (stored as ``None``): expects ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            jf.get_selection_statistic_value(self.model_file_dict, "_TAU_")

    def test_get_statistics_path_default(self):
        """``Path`` input, no statistic named: expects 1."""
        selection_statistic = jf.get_selection_statistic_value(
            Path(self.tmp_dir.name)
        )
        assert selection_statistic == 1

    def test_get_statistics_path_custom(self):
        """``Path`` input with ``_C_`` requested: expects 2."""
        selection_statistic = jf.get_selection_statistic_value(
            Path(self.tmp_dir.name), "_C_"
        )
        assert selection_statistic == 2

    def test_get_blank_statistic_path(self):
        """``Path`` input with ``_TAU_`` (stored as ``None``): expects ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            jf.get_selection_statistic_value(Path(self.tmp_dir.name), "_TAU_")

    def test_get_statistics_str_default(self):
        """String directory input, no statistic named: expects 1."""
        selection_statistic = jf.get_selection_statistic_value(self.tmp_dir.name)
        assert selection_statistic == 1

    def test_get_statistics_str_custom(self):
        """String directory input with ``_C_`` requested: expects 2."""
        selection_statistic = jf.get_selection_statistic_value(
            self.tmp_dir.name, "_C_"
        )
        assert selection_statistic == 2

    def test_get_blank_statistic_str(self):
        """String directory input with ``_TAU_`` (stored as ``None``): expects ``RuntimeError``."""
        with pytest.raises(RuntimeError):
            jf.get_selection_statistic_value(self.tmp_dir.name, "_TAU_")
961
+
962
+
963
class TestUpdateModelProperties(unittest.TestCase):
    """Tests for ``jf.update_model_properties``.

    Each scenario runs twice: once against an in-memory dict and once against
    a temporary directory holding ``ModelProperties.json``.
    """

    def setUp(self):
        """Build a fresh properties dict and a matching file on disk."""
        self.model_file_dict = {"ModelProperties.json": {"example": "property"}}
        self.tmp_dir = tempfile.TemporaryDirectory()
        target = Path(self.tmp_dir.name) / "ModelProperties.json"
        with open(target, "w+") as handle:
            handle.write(json.dumps(self.model_file_dict["ModelProperties.json"]))

    def tearDown(self):
        """Delete the per-test temporary directory."""
        self.tmp_dir.cleanup()

    def _read_properties_file(self):
        """Load and return the JSON content of the on-disk properties file."""
        with open(Path(self.tmp_dir.name) / "ModelProperties.json", "r") as handle:
            return json.load(handle)

    def test_update_model_properties_dict(self):
        """New keys are added to the dict and the existing key is kept."""
        jf.update_model_properties(
            self.model_file_dict, {"new": "arg", "newer": "thing"}
        )
        stored = self.model_file_dict["ModelProperties.json"]
        assert stored["example"] == "property"
        assert stored["new"] == "arg"
        assert stored["newer"] == "thing"

    def test_update_model_properties_dict_overwrite(self):
        """An existing key in the dict is overwritten by the update."""
        jf.update_model_properties(
            self.model_file_dict, {"new": "arg", "example": "thing"}
        )
        stored = self.model_file_dict["ModelProperties.json"]
        assert stored["example"] == "thing"
        assert stored["new"] == "arg"

    def test_update_model_properties_dict_number(self):
        """An integer value is expected to be stored in the dict as ``'1'``."""
        jf.update_model_properties(self.model_file_dict, {"number": 1})
        assert self.model_file_dict["ModelProperties.json"]["number"] == "1"

    def test_update_model_properties_dict_round_number(self):
        """A 15-decimal float is expected in the dict as a 14-decimal string."""
        jf.update_model_properties(
            self.model_file_dict, {"number": 0.123456789012345}
        )
        stored = self.model_file_dict["ModelProperties.json"]
        assert stored["number"] == "0.12345678901234"

    def test_update_model_properties_str(self):
        """New keys are added to the file and the existing key is kept."""
        jf.update_model_properties(
            self.tmp_dir.name, {"new": "arg", "newer": "thing"}
        )
        on_disk = self._read_properties_file()
        assert on_disk["example"] == "property"
        assert on_disk["new"] == "arg"
        assert on_disk["newer"] == "thing"

    def test_update_model_properties_str_overwrite(self):
        """An existing key in the file is overwritten by the update."""
        jf.update_model_properties(
            self.tmp_dir.name, {"new": "arg", "example": "thing"}
        )
        on_disk = self._read_properties_file()
        assert on_disk["example"] == "thing"
        assert on_disk["new"] == "arg"

    def test_update_model_properties_str_number(self):
        """An integer value is expected to be stored in the file as ``'1'``."""
        jf.update_model_properties(self.tmp_dir.name, {"number": 1})
        on_disk = self._read_properties_file()
        assert on_disk["number"] == "1"

    def test_update_model_properties_str_round_number(self):
        """A 15-decimal float is expected in the file as a 14-decimal string."""
        jf.update_model_properties(self.tmp_dir.name, {"number": 0.123456789012345})
        on_disk = self._read_properties_file()
        assert on_disk["number"] == "0.12345678901234"
0 commit comments