Skip to content

Commit a7b49ac

Browse files
committed
Update write_json_files to work better with prediction models
1 parent f279234 commit a7b49ac

File tree

1 file changed

+13
-8
lines changed

1 file changed

+13
-8
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 13 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1260,7 +1260,7 @@ def calculate_model_statistics(
12601260
if not partition:
12611261
continue
12621262

1263-
data = cls.stat_dataset_to_dataframe(data, target_value)
1263+
data = cls.stat_dataset_to_dataframe(data, target_value, target_type)
12641264

12651265
conn.upload(
12661266
data,
@@ -1392,6 +1392,7 @@ def check_for_data(
13921392
def stat_dataset_to_dataframe(
13931393
data: Union[DataFrame, List[list], Type["numpy.array"]],
13941394
target_value: Union[str, int, float] = None,
1395+
target_type: str = 'classification'
13951396
) -> DataFrame:
13961397
"""
13971398
Convert the user supplied statistical dataset from either a pandas DataFrame,
@@ -1439,13 +1440,15 @@ def stat_dataset_to_dataframe(
14391440
if isinstance(data, pd.DataFrame):
14401441
if len(data.columns) == 2:
14411442
data.columns = ["actual", "predict"]
1442-
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
1443+
if target_type == 'classification':
1444+
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
14431445
elif len(data.columns) == 3:
14441446
data.columns = ["actual", "predict", "predict_proba"]
14451447
elif isinstance(data, list):
14461448
if len(data) == 2:
14471449
data = pd.DataFrame({"actual": data[0], "predict": data[1]})
1448-
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
1450+
if target_type == 'classification':
1451+
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
14491452
elif len(data) == 3:
14501453
data = pd.DataFrame(
14511454
{
@@ -1457,7 +1460,8 @@ def stat_dataset_to_dataframe(
14571460
elif isinstance(data, np.ndarray):
14581461
if len(data) == 2:
14591462
data = pd.DataFrame({"actual": data[0, :], "predict": data[1, :]})
1460-
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
1463+
if target_type == 'classification':
1464+
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
14611465
elif len(data) == 3:
14621466
data = pd.DataFrame(
14631467
{"actual": data[0], "predict": data[1], "predict_proba": data[2]}
@@ -2366,7 +2370,8 @@ def generate_model_card(
23662370
)
23672371

23682372
# Generates dmcas_misc.json file
2369-
cls.generate_misc(model_files)
2373+
if target_type == 'classification':
2374+
cls.generate_misc(model_files)
23702375

23712376
@staticmethod
23722377
def upload_training_data(
@@ -2617,7 +2622,7 @@ def generate_variable_importance(
26172622
if target_type == "classification":
26182623
method = "DTREE"
26192624
treeCrit = "Entropy"
2620-
elif target_type == "interval":
2625+
elif target_type == "prediction":
26212626
method = "RTREE"
26222627
treeCrit = "RSS"
26232628
else:
@@ -2743,14 +2748,14 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
27432748
if isinstance(model_files, dict):
27442749
if ROC not in model_files:
27452750
raise RuntimeError(
2746-
"The ModelProperties.json file must be generated before the model card data "
2751+
"The dmcas_roc.json file must be generated before the model card data "
27472752
"can be generated."
27482753
)
27492754
roc_table = model_files[ROC]
27502755
else:
27512756
if not Path.exists(Path(model_files) / ROC):
27522757
raise RuntimeError(
2753-
"The ModelProperties.json file must be generated before the model card data "
2758+
"The dmcas_roc.json file must be generated before the model card data "
27542759
"can be generated."
27552760
)
27562761
with open(Path(model_files) / ROC, "r") as roc_file:

0 commit comments

Comments
 (0)