@@ -58,6 +58,7 @@ class NpEncoder(json.JSONEncoder):
58
58
MAXDIFFERENCES = "maxDifferences.json"
59
59
GROUPMETRICS = "groupMetrics.json"
60
60
VARIMPORTANCES = 'dmcas_relativeimportance.json'
61
+ MISC = 'dmcas_misc.json'
61
62
62
63
63
64
def _flatten (nested_list : Iterable ) -> Generator [Any , None , None ]:
@@ -1174,7 +1175,8 @@ def calculate_model_statistics(
1174
1175
train_data : Union [DataFrame , List [list ], Type ["numpy.array" ]] = None ,
1175
1176
test_data : Union [DataFrame , List [list ], Type ["numpy.array" ]] = None ,
1176
1177
json_path : Union [str , Path , None ] = None ,
1177
- target_type : str = "classification"
1178
+ target_type : str = "classification" ,
1179
+ cutoff : Optional [float ] = None
1178
1180
) -> Union [dict , None ]:
1179
1181
"""
1180
1182
Calculates fit statistics (including ROC and Lift curves) from datasets and then
@@ -2345,6 +2347,12 @@ def generate_model_card(
2345
2347
class_vars ,
2346
2348
caslib
2347
2349
)
2350
+
2351
+ # Generates dmcas_misc.json file
2352
+ cls .generate_misc (
2353
+ conn ,
2354
+ model_files
2355
+ )
2348
2356
2349
2357
@staticmethod
2350
2358
def upload_training_data (
@@ -2675,4 +2683,75 @@ def generate_variable_importance(
2675
2683
print (
2676
2684
f"{ VARIMPORTANCES } was successfully written and saved to "
2677
2685
f"{ Path (model_files ) / VARIMPORTANCES } "
2686
+ )
2687
+
2688
+ @classmethod
2689
+ def generate_misc (
2690
+ cls ,
2691
+ model_files : Union [str , Path , dict ]
2692
+ ):
2693
+ """
2694
+ Generates the dmcas_relativeimportance.json file, which is used to determine variable importance
2695
+
2696
+ Parameters
2697
+ ----------
2698
+ conn
2699
+ A SWAT connection used to connect to the user's CAS server
2700
+ model_files : string, Path, or dict
2701
+ Either the directory location of the model files (string or Path object), or
2702
+ a dictionary containing the contents of all the model files.
2703
+ """
2704
+ if isinstance (model_files , dict ):
2705
+ if ROC not in model_files :
2706
+ raise RuntimeError (
2707
+ "The ModelProperties.json file must be generated before the model card data "
2708
+ "can be generated."
2709
+ )
2710
+ roc_table = model_files [ROC ]
2711
+ else :
2712
+ if not Path .exists (Path (model_files ) / ROC ):
2713
+ raise RuntimeError (
2714
+ "The ModelProperties.json file must be generated before the model card data "
2715
+ "can be generated."
2716
+ )
2717
+ with open (Path (model_files ) / ROC , 'r' ) as roc_file :
2718
+ roc_table = json .load (roc_file )
2719
+ correct_text = ["CORRECT" , "INCORRECT" , "CORRECT" , "INCORRECT" ]
2720
+ outcome_values = ['1' , '0' , '0' , '1' ]
2721
+ misc_data = list ()
2722
+ # Iterates through ROC table to get TRAIN, TEST, and VALIDATE data with a cutoff of .5
2723
+ for i in range (50 , 300 , 100 ):
2724
+ roc_data = roc_table ['data' ][i ]['dataMap' ]
2725
+ correctness_values = [roc_data ['_TP_' ], roc_data ['_FP_' ], roc_data ['_TN_' ], roc_data ['_FN_' ]]
2726
+ for (c_text , c_val , o_val ) in zip (correct_text , correctness_values , outcome_values ):
2727
+ misc_data .append ({
2728
+ "CorrectText" : c_text ,
2729
+ "Outcome" : o_val ,
2730
+ "_Count_" : c_val ,
2731
+ "_DataRole_" : roc_data ['_DataRole_' ],
2732
+ "_cutoffSource_" : "Default" ,
2733
+ "_cutoff_" : "0.5"
2734
+ })
2735
+
2736
+ json_template_path = (
2737
+ Path (__file__ ).resolve ().parent / f"template_files/{ MISC } "
2738
+ )
2739
+ with open (json_template_path , 'r' ) as f :
2740
+ misc_json = json .load (f )
2741
+ misc_json ['data' ] = misc_data
2742
+
2743
+ if isinstance (model_files , dict ):
2744
+ model_files [MISC ] = json .dumps (misc_json , indent = 4 , cls = NpEncoder )
2745
+ if cls .notebook_output :
2746
+ print (
2747
+ f"{ MISC } was successfully written and saved to "
2748
+ f"model files dictionary."
2749
+ )
2750
+ else :
2751
+ with open (Path (model_files ) / MISC , 'w' ) as json_file :
2752
+ json_file .write (json .dumps (misc_json , indent = 4 , cls = NpEncoder ))
2753
+ if cls .notebook_output :
2754
+ print (
2755
+ f"{ MISC } was successfully written and saved to "
2756
+ f"{ Path (model_files ) / MISC } "
2678
2757
)
0 commit comments