Skip to content

Commit e7c3791

Browse files
committed
added dmcas_misc file generation capability
1 parent 7b14935 commit e7c3791

File tree

1 file changed

+80
-1
lines changed

1 file changed

+80
-1
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ class NpEncoder(json.JSONEncoder):
5858
MAXDIFFERENCES = "maxDifferences.json"
5959
GROUPMETRICS = "groupMetrics.json"
6060
VARIMPORTANCES = 'dmcas_relativeimportance.json'
61+
MISC = 'dmcas_misc.json'
6162

6263

6364
def _flatten(nested_list: Iterable) -> Generator[Any, None, None]:
@@ -1174,7 +1175,8 @@ def calculate_model_statistics(
11741175
train_data: Union[DataFrame, List[list], Type["numpy.array"]] = None,
11751176
test_data: Union[DataFrame, List[list], Type["numpy.array"]] = None,
11761177
json_path: Union[str, Path, None] = None,
1177-
target_type: str = "classification"
1178+
target_type: str = "classification",
1179+
cutoff: Optional[float] = None
11781180
) -> Union[dict, None]:
11791181
"""
11801182
Calculates fit statistics (including ROC and Lift curves) from datasets and then
@@ -2345,6 +2347,12 @@ def generate_model_card(
23452347
class_vars,
23462348
caslib
23472349
)
2350+
2351+
# Generates dmcas_misc.json file
2352+
cls.generate_misc(
2353+
conn,
2354+
model_files
2355+
)
23482356

23492357
@staticmethod
23502358
def upload_training_data(
@@ -2675,4 +2683,75 @@ def generate_variable_importance(
26752683
print(
26762684
f"{VARIMPORTANCES} was successfully written and saved to "
26772685
f"{Path(model_files) / VARIMPORTANCES}"
2686+
)
2687+
2688+
@classmethod
def generate_misc(
    cls,
    model_files: Union[str, Path, dict]
):
    """
    Generates the dmcas_misc.json file from the model's ROC table.

    For each sampled ROC row, four records are emitted (true positive, false
    positive, true negative, false negative), each carrying the row's count,
    its data role, and a fixed cutoff of 0.5.

    Parameters
    ----------
    model_files : string, Path, or dict
        Either the directory location of the model files (string or Path object), or
        a dictionary containing the contents of all the model files. When a
        dictionary is given, the ROC entry may be either an already-parsed
        mapping or serialized JSON text.

    Raises
    ------
    RuntimeError
        If the ROC file is not present in the directory or dictionary.

    Notes
    -----
    This method takes no connection argument; only ``model_files`` is accepted.
    """
    missing_roc_message = (
        f"The {ROC} file must be generated before the {MISC} file "
        "can be generated."
    )

    if isinstance(model_files, dict):
        if ROC not in model_files:
            raise RuntimeError(missing_roc_message)
        roc_table = model_files[ROC]
        # This method stores serialized JSON text in the dictionary, so accept
        # the ROC entry in that form as well as an already-parsed mapping.
        if isinstance(roc_table, (str, bytes, bytearray)):
            roc_table = json.loads(roc_table)
    else:
        roc_path = Path(model_files) / ROC
        if not roc_path.exists():
            raise RuntimeError(missing_roc_message)
        with open(roc_path, 'r') as roc_file:
            roc_table = json.load(roc_file)

    # (CorrectText, Outcome, ROC count key) for each confusion-matrix cell.
    confusion_cells = (
        ("CORRECT", '1', '_TP_'),
        ("INCORRECT", '0', '_FP_'),
        ("CORRECT", '0', '_TN_'),
        ("INCORRECT", '1', '_FN_'),
    )

    # Rows 50, 150 and 250 are sampled to get TRAIN, TEST, and VALIDATE data with
    # a cutoff of .5.
    # NOTE(review): this presumes 100 ROC rows per data role with the 0.5 cutoff
    # at offset 50 -- confirm against the code that writes the ROC table.
    # Slicing (rather than indexing) skips rows that do not exist, so a table
    # with fewer partitions no longer raises IndexError.
    misc_data = []
    for roc_row in roc_table['data'][50:300:100]:
        roc_data = roc_row['dataMap']
        misc_data.extend(
            {
                "CorrectText": correct_text,
                "Outcome": outcome,
                "_Count_": roc_data[count_key],
                "_DataRole_": roc_data['_DataRole_'],
                "_cutoffSource_": "Default",
                "_cutoff_": "0.5"
            }
            for correct_text, outcome, count_key in confusion_cells
        )

    json_template_path = (
        Path(__file__).resolve().parent / f"template_files/{MISC}"
    )
    with open(json_template_path, 'r') as f:
        misc_json = json.load(f)
    misc_json['data'] = misc_data

    # Serialize once; both output targets receive identical text.
    misc_text = json.dumps(misc_json, indent=4, cls=NpEncoder)

    if isinstance(model_files, dict):
        model_files[MISC] = misc_text
        destination = "model files dictionary."
    else:
        misc_path = Path(model_files) / MISC
        with open(misc_path, 'w') as json_file:
            json_file.write(misc_text)
        destination = f"{misc_path}"

    if cls.notebook_output:
        print(
            f"{MISC} was successfully written and saved to "
            f"{destination}"
        )

0 commit comments

Comments
 (0)