Skip to content

Commit 850f054

Browse files
committed
fixed some bugs to allow for model card files to generate correctly.
1 parent 53cb8bb commit 850f054

File tree

1 file changed

+18
-16
lines changed

1 file changed

+18
-16
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 18 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,7 +2208,7 @@ def generate_model_card(
22082208
algorithm: str,
22092209
train_data: pd.DataFrame,
22102210
train_predictions: Union[pd.Series, list],
2211-
target_type: str = "Interval",
2211+
target_type: str = "interval",
22122212
target_value: Union[str, int, float, None] = None,
22132213
interval_vars: Optional[list] = [],
22142214
class_vars: Optional[list] = [],
@@ -2237,10 +2237,10 @@ def generate_model_card(
22372237
train_predictions : pandas.Series, list
22382238
List of predictions made by the model on the training data.
22392239
target_type : string
2240-
Type the model is targeting. Currently supports "Classification" and "Interval" types.
2240+
Type the model is targeting. Currently supports "classification" and "interval" types.
22412241
The default value is "Interval".
22422242
target_value : string, int, float, optional
2243-
Value the model is targeting for Classification models. This argument is not needed for
2243+
Value the model is targeting for classification models. This argument is not needed for
22442244
Interval models. The default value is None.
22452245
interval_vars : list, optional
22462246
A list of interval variables. The default value is an empty list.
@@ -2255,14 +2255,14 @@ def generate_model_card(
22552255
caslib: str, optional
22562256
The caslib the training data will be stored on. The default value is "Public"
22572257
"""
2258-
if not target_value and target_type == "Classification":
2258+
if not target_value and target_type == "classification":
22592259
raise RuntimeError(
2260-
"For the model card data to be properly generated on a Classification "
2260+
"For the model card data to be properly generated on a classification "
22612261
"model, a target value is required."
22622262
)
2263-
if target_type not in ["Classification", "Interval"]:
2263+
if target_type not in ["classification", "interval"]:
22642264
raise RuntimeError(
2265-
"Only Classification and Interval target types are currently accepted."
2265+
"Only classification and interval target types are currently accepted."
22662266
)
22672267
if selection_statistic not in cls.valid_params:
22682268
raise RuntimeError(
@@ -2396,10 +2396,10 @@ def generate_outcome_average(
23962396
Returns a dictionary with a key value pair that represents the outcome average.
23972397
"""
23982398
output_var = train_data.drop(input_variables, axis=1)
2399-
if target_type == "Classification":
2399+
if target_type == "classification":
24002400
value_counts = output_var[output_var.columns[0]].value_counts()
24012401
return {'eventPercentage': value_counts[target_value]/sum(value_counts)}
2402-
elif target_type == "Interval":
2402+
elif target_type == "interval":
24032403
return {'eventAverage': sum(value_counts[value_counts.columns[0]]) / len(value_counts)}
24042404

24052405
@staticmethod
@@ -2480,8 +2480,8 @@ def update_model_properties(
24802480
"The ModelProperties.json file must be generated before the model card data "
24812481
"can be generated."
24822482
)
2483-
for key, value in update_dict:
2484-
model_files[PROP][key] = value
2483+
for key in update_dict:
2484+
model_files[PROP][key] = update_dict[key]
24852485
else:
24862486
if not Path.exists(Path(model_files) / PROP):
24872487
raise RuntimeError(
@@ -2490,8 +2490,8 @@ def update_model_properties(
24902490
)
24912491
with open(Path(model_files) / PROP, 'r+') as properties_json:
24922492
model_properties = json.load(properties_json)
2493-
for key, value in update_dict:
2494-
model_properties[key] = value
2493+
for key in update_dict:
2494+
model_properties[key] = update_dict[key]
24952495
properties_json.seek(0)
24962496
properties_json.write(json.dumps(model_properties, indent=4, cls=NpEncoder))
24972497
properties_json.truncate()
@@ -2595,7 +2595,7 @@ def generate_variable_importance(
25952595
}
25962596
})
25972597
var_data = conn.dataPreprocess.transform(
2598-
table={"name": "test_data", "caslib": caslib},
2598+
table={"name": "train_data", "caslib": caslib},
25992599
requestPackages=request_packages,
26002600
evaluationStats=True,
26012601
percentileMaxIterations=10,
@@ -2623,7 +2623,10 @@ def generate_variable_importance(
26232623
},
26242624
"rowNumber": index+1
26252625
})
2626-
with open('./dmcas_relativeimportance.json', 'r') as f:
2626+
json_template_path = (
2627+
Path(__file__).resolve().parent / f"template_files/{VARIMPORTANCES}"
2628+
)
2629+
with open(json_template_path, 'r') as f:
26272630
relative_importance_json = json.load(f)
26282631
relative_importance_json['data'] = relative_importances
26292632

@@ -2641,5 +2644,4 @@ def generate_variable_importance(
26412644
print(
26422645
f"{VARIMPORTANCES} was successfully written and saved to "
26432646
f"{Path(model_files) / VARIMPORTANCES}"
2644-
26452647
)

0 commit comments

Comments
 (0)