Skip to content

Commit 8f26fd2

Browse files
committed
Update write_json_files to fix model card issues
1 parent a7b49ac commit 8f26fd2

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -965,6 +965,7 @@ def format_max_differences(
965965
maxdiff_df = maxdiff_df.rename(
966966
columns={"Value": "maxdiff", "Base": "BASE", "Compare": "COMPARE"}
967967
)
968+
maxdiff_df["maxdiff"] = maxdiff_df["maxdiff"].apply(str)
968969

969970
maxdiff_df["VLABEL"] = ""
970971
maxdiff_df["_DATAROLE_"] = datarole
@@ -2257,6 +2258,7 @@ def generate_model_card(
22572258
interval_vars: Optional[list] = [],
22582259
class_vars: Optional[list] = [],
22592260
selection_statistic: str = None,
2261+
training_table_name: str = None,
22602262
server: str = "cas-shared-default",
22612263
caslib: str = "Public",
22622264
):
@@ -2336,7 +2338,7 @@ def generate_model_card(
23362338

23372339
# Upload training table to CAS. The location of the training table is returned.
23382340
training_table = cls.upload_training_data(
2339-
conn, model_prefix, train_data, server, caslib
2341+
conn, model_prefix, train_data, training_table_name, server, caslib
23402342
)
23412343

23422344
# Generates the event percentage for Classification targets, and the event average
@@ -2378,6 +2380,7 @@ def upload_training_data(
23782380
conn,
23792381
model_prefix: str,
23802382
train_data: pd.DataFrame,
2383+
train_data_name: str,
23812384
server: str = "cas-shared-default",
23822385
caslib: str = "Public",
23832386
):
@@ -2404,15 +2407,18 @@ def upload_training_data(
24042407
Returns a string that represents the location of the training table within CAS.
24052408
"""
24062409
# Upload raw training data to caslib so that data can be analyzed
2407-
train_data_name = model_prefix + "_train_data"
2410+
if not train_data_name:
2411+
train_data_name = model_prefix + "_train_data"
24082412
upload_train_data = conn.upload(
24092413
train_data, casout={"name": train_data_name, "caslib": caslib}, promote=True
24102414
)
24112415

24122416
if upload_train_data.status is not None:
2413-
raise RuntimeError(
2414-
f"A table with the name {train_data_name} already exists in the specified caslib. Please "
2415-
"either delete/rename the old table or give a new name to the current table."
2417+
# raise RuntimeError(
2418+
warnings.warn(
2419+
f"A table with the name {train_data_name} already exists in the specified caslib. If this "
2420+
f"is not intentional, please either rename the training data file or remove the duplicate from "
2421+
f"the caslib."
24162422
)
24172423

24182424
return server + "/" + caslib + "/" + train_data_name.upper()
@@ -2762,6 +2768,9 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
27622768
roc_table = json.load(roc_file)
27632769
correct_text = ["CORRECT", "INCORRECT", "CORRECT", "INCORRECT"]
27642770
outcome_values = ["1", "0", "0", "1"]
2771+
target_texts = ["Event", "Event", "NEvent", "NEvent"]
2772+
target_values = ["1", "1", "0", "0"]
2773+
27652774
misc_data = list()
27662775
# Iterates through ROC table to get TRAIN, TEST, and VALIDATE data with a cutoff of .5
27672776
for i in range(50, 300, 100):
@@ -2772,8 +2781,8 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
27722781
roc_data["_TN_"],
27732782
roc_data["_FN_"],
27742783
]
2775-
for c_text, c_val, o_val in zip(
2776-
correct_text, correctness_values, outcome_values
2784+
for c_text, c_val, o_val, t_txt, t_val in zip(
2785+
correct_text, correctness_values, outcome_values, target_texts, target_values
27772786
):
27782787
misc_data.append(
27792788
{
@@ -2784,6 +2793,8 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
27842793
"_DataRole_": roc_data["_DataRole_"],
27852794
"_cutoffSource_": "Default",
27862795
"_cutoff_": "0.5",
2796+
"TargetText": t_txt,
2797+
"Target": t_val
27872798
},
27882799
"rowNumber": len(misc_data) + 1,
27892800
}

0 commit comments

Comments
 (0)