Skip to content

Commit ee2cb5f

Browse files
committed
black reformatting
1 parent 8f26fd2 commit ee2cb5f

File tree

1 file changed

+11
-7
lines changed

1 file changed

+11
-7
lines changed

src/sasctl/pzmm/write_json_files.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1393,7 +1393,7 @@ def check_for_data(
13931393
def stat_dataset_to_dataframe(
13941394
data: Union[DataFrame, List[list], Type["numpy.array"]],
13951395
target_value: Union[str, int, float] = None,
1396-
target_type: str = 'classification'
1396+
target_type: str = "classification",
13971397
) -> DataFrame:
13981398
"""
13991399
Convert the user supplied statistical dataset from either a pandas DataFrame,
@@ -1441,14 +1441,14 @@ def stat_dataset_to_dataframe(
14411441
if isinstance(data, pd.DataFrame):
14421442
if len(data.columns) == 2:
14431443
data.columns = ["actual", "predict"]
1444-
if target_type == 'classification':
1444+
if target_type == "classification":
14451445
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
14461446
elif len(data.columns) == 3:
14471447
data.columns = ["actual", "predict", "predict_proba"]
14481448
elif isinstance(data, list):
14491449
if len(data) == 2:
14501450
data = pd.DataFrame({"actual": data[0], "predict": data[1]})
1451-
if target_type == 'classification':
1451+
if target_type == "classification":
14521452
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
14531453
elif len(data) == 3:
14541454
data = pd.DataFrame(
@@ -1461,7 +1461,7 @@ def stat_dataset_to_dataframe(
14611461
elif isinstance(data, np.ndarray):
14621462
if len(data) == 2:
14631463
data = pd.DataFrame({"actual": data[0, :], "predict": data[1, :]})
1464-
if target_type == 'classification':
1464+
if target_type == "classification":
14651465
data["predict_proba"] = data["predict"].gt(target_value).astype(int)
14661466
elif len(data) == 3:
14671467
data = pd.DataFrame(
@@ -2372,7 +2372,7 @@ def generate_model_card(
23722372
)
23732373

23742374
# Generates dmcas_misc.json file
2375-
if target_type == 'classification':
2375+
if target_type == "classification":
23762376
cls.generate_misc(model_files)
23772377

23782378
@staticmethod
@@ -2782,7 +2782,11 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
27822782
roc_data["_FN_"],
27832783
]
27842784
for c_text, c_val, o_val, t_txt, t_val in zip(
2785-
correct_text, correctness_values, outcome_values, target_texts, target_values
2785+
correct_text,
2786+
correctness_values,
2787+
outcome_values,
2788+
target_texts,
2789+
target_values,
27862790
):
27872791
misc_data.append(
27882792
{
@@ -2794,7 +2798,7 @@ def generate_misc(cls, model_files: Union[str, Path, dict]):
27942798
"_cutoffSource_": "Default",
27952799
"_cutoff_": "0.5",
27962800
"TargetText": t_txt,
2797-
"Target": t_val
2801+
"Target": t_val,
27982802
},
27992803
"rowNumber": len(misc_data) + 1,
28002804
}

0 commit comments

Comments
 (0)