
Commit 5d577d8

Updated unit tests to correctly test score code changes
1 parent e72132f commit 5d577d8

File tree

2 files changed: +80 -63 lines changed

src/sasctl/pzmm/write_score_code.py

Lines changed: 36 additions & 21 deletions
@@ -282,7 +282,7 @@ def score(var1, var2, var3, var4):
             predict_method[1],
             target_values=target_values,
             predict_threshold=predict_threshold,
-            target_index=target_index
+            target_index=target_index,
         )

         if missing_values:
@@ -1090,7 +1090,7 @@ def _predictions_to_metrics(
         target_values: Optional[List[str]] = None,
         predict_threshold: Optional[float] = None,
         h2o_model: Optional[bool] = False,
-        target_index: Optional[int] = 1
+        target_index: Optional[int] = 1,
     ) -> None:
         """
         Using the provided arguments, write in to the score code the method for handling
@@ -1136,7 +1136,12 @@ def _predictions_to_metrics(
         # Binary classification model
         elif len(target_values) == 2:
             cls._binary_target(
-                metrics, target_values, predict_returns, predict_threshold, target_index, h2o_model
+                metrics,
+                target_values,
+                predict_returns,
+                predict_threshold,
+                target_index,
+                h2o_model,
             )
         # Multiclass classification model
         elif len(target_values) > 2:
@@ -1187,10 +1192,11 @@ def _no_targets_no_thresholds(
             )
             """
             if input_array.shape[0] == 1:
-                Classification = prediction[0]
+                Classification = prediction[1][0]
                 return Classification
             else:
-                output_table = pd.DataFrame({'Classification': prediction})
+                output_table = prediction.drop(prediction.columns[1:], axis=1)
+                output_table.columns = ['Classification']
                 return output_table
             """
         else:
@@ -1203,11 +1209,10 @@ def _no_targets_no_thresholds(
             )
             """
             if input_array.shape[0] == 1:
-                Classification = prediction[1][0]
+                Classification = prediction[0][0]
                 return Classification
             else:
-                output_table = prediction.drop(prediction.columns[1:], axis=1)
-                output_table.columns = ['Classification']
+                output_table = pd.DataFrame({'Classification': prediction})
                 return output_table
             """
         else:
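Taken together, the two hunks above swap the single-row and batch handling between the H2O and non-H2O branches of _no_targets_no_thresholds: the H2O branch now reads the label from prediction[1][0] and trims the batch DataFrame down to its label column, while the plain branch wraps the raw prediction sequence in a DataFrame. A minimal sketch of the shapes the generated code appears to assume (sample values and column names are invented for illustration):

import pandas as pd

# Assumed H2O single-row shape: a header row plus one data row, so
# prediction[1][0] is the predicted label.
prediction = [["predict", "p0", "p1"], ["B", 0.2, 0.8]]
classification = prediction[1][0]  # "B"

# Assumed H2O batch shape: a DataFrame; dropping every column after the
# first keeps only the label column.
prediction = pd.DataFrame({"predict": ["A", "B"], "p0": [0.7, 0.2], "p1": [0.3, 0.8]})
output_table = prediction.drop(prediction.columns[1:], axis=1)
output_table.columns = ["Classification"]

# Non-H2O batch case: a flat sequence of labels is wrapped directly.
prediction = ["A", "B"]
output_table = pd.DataFrame({"Classification": prediction})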
@@ -1271,7 +1276,7 @@ def _binary_target(
         returns: List[Any],
         threshold: Optional[float] = None,
         h2o_model: Optional[bool] = None,
-        target_index: Optional[int] = 1
+        target_index: Optional[int] = 1,
     ) -> None:
         """
         Handle binary model prediction outputs.
@@ -1317,13 +1322,13 @@ def _binary_target(
         if h2o_model:
             cls.score_code += (
                 f"{'':4}if input_array.shape[0] == 1:\n"
-                f"{'':8}if prediction[1][{target_index} + 1] > {threshold}:\n"
+                f"{'':8}if prediction[1][{target_index+1}] > {threshold}:\n"
                 f"{'':12}{metrics} = \"{target_values[target_index]}\"\n"
                 f"{'':8}else:\n"
                 f"{'':12}{metrics} = \"{target_values[abs(target_index-1)]}\"\n"
                 f"{'':8}return {metrics}\n"
                 f"{'':4}else:\n"
-                f"{'':8}output_table = pd.DataFrame({{'{metrics}': np.where(prediction[prediction.columns[{target_index+1}]] > {threshold}, '{target_values[target_index]}', '{target_values[abs(target_index-1)]}')}})"
+                f"{'':8}output_table = pd.DataFrame({{'{metrics}': np.where(prediction[prediction.columns[{target_index+1}]] > {threshold}, '{target_values[target_index]}', '{target_values[abs(target_index-1)]}')}})\n"
                 f"{'':8}return output_table"
             )
             """
@@ -1385,8 +1390,8 @@ def _binary_target(
                 f"{'':8}return {metrics}\n"
                 f"{'':4}else:\n"
                 f"{'':8}target_values = {target_values}\n"
-                f"{'':8}prediction = pd.DataFrame(prediction)"
-                f"{'':8}output_table = pd.DataFrame({{'{metrics}': np.where(prediction[prediction.columns[{target_index}]] > {threshold}, '{target_values[target_index]}', '{target_values[abs(target_index-1)]}')}})"
+                f"{'':8}prediction = pd.DataFrame(prediction)\n"
+                f"{'':8}output_table = pd.DataFrame({{'{metrics}': np.where(prediction[prediction.columns[{target_index}]] > {threshold}, '{target_values[target_index]}', '{target_values[abs(target_index-1)]}')}})\n"
                 f"{'':8}return output_table"
             )
             """
@@ -1404,6 +1409,7 @@ def _binary_target(
             """
         # Classification and probability returned; return classification value
         elif len(returns) > 1 and sum(returns) == 1:
+            # TODO: Either figure out how to handle threshold or add warning
             # Determine which return is the classification value
             class_index = [i for i, x in enumerate(returns) if x][0]
             cls.score_code += (
@@ -1424,7 +1430,6 @@ def _binary_target(
             else:
                 cls._invalid_predict_config()
         elif len(metrics) == 2:
-            # TODO: change to align with other cases and assign target_values to classification column
             # H2O models with two metrics are assumed to be classification + probability
             if h2o_model:
                 warn(
@@ -1434,10 +1439,16 @@ def _binary_target(
                 )
                 cls.score_code += (
                     f"{'':4}if input_array.shape[0] == 1:\n"
-                    f"{'':8}return prediction[1][0], float(prediction[1][2])\n"
+                    f"{'':8}if prediction[1][{target_index+1}] > {threshold}:\n"
+                    f"{'':12}{metrics[0]} = '{target_values[target_index]}'\n"
+                    f"{'':8}else:\n"
+                    f"{'':12}{metrics[0]} = '{target_values[abs(target_index-1)]}'\n"
+                    f"{'':8}return {metrics[0]}, float(prediction[1][{target_index+1}])\n"
                     f"{'':4}else:\n"
                     f"{'':8}output_table = prediction.drop(prediction.columns[{abs(target_index-1)+1}], axis=1)\n"
+                    f"{'':8}classifications = np.where(prediction[prediction.columns[{target_index+1}]] > {threshold}, '{target_values[target_index]}', '{target_values[abs(target_index-1)]}')\n"
                     f"{'':8}output_table.columns = {metrics}\n"
+                    f"{'':8}output_table['{metrics[0]}'] = classifications\n"
                     f"{'':8}return output_table"
                 )
                 """
@@ -1494,7 +1505,7 @@ def _binary_target(
                 f"{'':8}return {metrics[0]}, prediction[0][{target_index}]\n"
                 f"{'':4}else:\n"
                 f"{'':8}df = pd.DataFrame(prediction)\n"
-                f"{'':8}proba = df[0]\n"
+                f"{'':8}proba = df[{target_index}]\n"
                 f"{'':8}classifications = np.where(df[{target_index}] > {threshold}, '{target_values[target_index]}', '{target_values[abs(target_index-1)]}')\n"
                 f"{'':8}return pd.DataFrame({{'{metrics[0]}': classifications, '{metrics[1]}': proba}})"
             )
@@ -1511,6 +1522,7 @@ def _binary_target(
                 classifications = np.where(df[1] > .5, 'B', 'A')
                 return pd.DataFrame({'Classification': classifications, 'Probability': proba})
             """
+        # TODO: Potentially add threshold
         # Return classification and probability value
         elif sum(returns) == 1 and len(returns) == 2:
             cls.score_code += (
@@ -1536,10 +1548,11 @@ def _binary_target(
             if class_index == 0:
                 cls.score_code += (
                     f"{'':4}if input_array.shape[0] == 1:\n"
-                    f"{'':8}return prediction[0][0], prediction[0][{target_index} + 1]\n"
+                    f"{'':8}return prediction[0][0], prediction[0][{target_index+1}]\n"
                     f"{'':4}else:\n"
+                    f"{'':8}prediction = pd.DataFrame(prediction)\n"
                     f"{'':8}output_table = prediction.drop(prediction.columns[{abs(target_index-1)+1}], axis=1)\n"
-                    f"{'':8}output_table.columns = {metrics}"
+                    f"{'':8}output_table.columns = {metrics}\n"
                     f"{'':8}return output_table"
                 )

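The inserted prediction = pd.DataFrame(prediction) line is what makes the following .drop call legal: the raw predict output is typically a list or ndarray, which has neither .drop nor .columns. A hedged mini-example with class_index == 0 and target_index = 1:

import pandas as pd

# Raw predict output: (label, P(class 0), P(class 1)) per row.
prediction = [("A", 0.7, 0.3), ("B", 0.2, 0.8)]
prediction = pd.DataFrame(prediction)  # lists have no .drop or .columns
# Drop column abs(target_index - 1) + 1 == 1, keeping label and p1.
output_table = prediction.drop(prediction.columns[1], axis=1)
output_table.columns = ["Classification", "Probability"]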
@@ -1556,9 +1569,10 @@ def _binary_target(
                     f"{'':4}if input_array.shape[0] == 1:\n"
                     f"{'':8}return prediction[0][{class_index}], prediction[0][{target_index}]\n"
                     f"{'':4}else:\n"
+                    f"{'':8}prediction = pd.DataFrame(prediction)\n"
                     f"{'':8}output_table = prediction.drop(prediction.columns[{abs(target_index-1)}], axis=1)\n"
                     f"{'':8}output_table = output_table[output_table.columns[::-1]]\n"
-                    f"{'':8}output_table.columns = {metrics}"
+                    f"{'':8}output_table.columns = {metrics}\n"
                     f"{'':8}return output_table"
                 )
                 """
@@ -1622,6 +1636,7 @@ def _binary_target(
                 return output_table
             """
         elif sum(returns) == 0 and len(returns) == 2:
+            # TODO: Make decision on whether ordering should follow given pattern or reflect input ordering
             warn(
                 "Due to the ambiguity of the provided metrics and prediction return"
                 " types, the score code assumes the return order to be: "
@@ -1638,8 +1653,8 @@ def _binary_target(
                 f"{'':8}return {metrics[0]}, prediction[0][{target_index}], prediction[0][{abs(target_index-1)}]\n"
                 f"{'':4}else:\n"
                 f"{'':8}output_table = pd.DataFrame(prediction, columns={metrics[1:]})\n"
-                f"{'':8}classifications = np.where(prediction[prediction.columns[{target_index}]] > {threshold}, '{target_values[target_index]}', '{target_values[abs(target_index-1)]}')"
-                f"{'':8}output_table.insert(loc=0, column='{metrics[0]}', value=classifications)"
+                f"{'':8}classifications = np.where(output_table[output_table.columns[{target_index}]] > {threshold}, '{target_values[target_index]}', '{target_values[abs(target_index-1)]}')\n"
+                f"{'':8}output_table.insert(loc=0, column='{metrics[0]}', value=classifications)\n"
                 f"{'':8}return output_table"
             )
             """
