
Commit 05ddc81

completed adding batch scoring support to all return combinations
1 parent b838897 commit 05ddc81

File tree: 1 file changed (+101, −34 lines)

src/sasctl/pzmm/write_score_code.py

Lines changed: 101 additions & 34 deletions
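Every hunk below applies the same pattern: the generated score function checks input_array.shape[0] and returns plain scalars for a single input row, or a pandas DataFrame for a batch. A minimal sketch of the emitted code's shape — the score signature, the prediction argument, and the metric names here are illustrative assumptions, not lifted from this diff:

import pandas as pd

def score(input_array, prediction):
    # prediction: hypothetical list of [classification, probability] rows
    if input_array.shape[0] == 1:
        # Single-row scoring: return scalar values
        return prediction[0][0], float(prediction[0][1])
    else:
        # Batch scoring: one output row per input record
        return pd.DataFrame(prediction, columns=["EM_CLASSIFICATION", "EM_EVENTPROBABILITY"])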
@@ -1200,9 +1200,10 @@ def _no_targets_no_thresholds(
         cls.score_code += (
             f"{'':8}{metrics[i + 1]} = float(prediction[1][{i + 1}])\n"
         )
+        metric_list = '"' + '","'.join(metrics) + '"'
         cls.score_code += (
             f"{'':4}else:\n"
-            f"{'':8}output_table = pd.DataFrame(prediction[1:], columns=[{','.join(metrics)}])\n"
+            f"{'':8}output_table = pd.DataFrame(prediction[1:], columns=[{metric_list}])\n"
             f"{'':8}return output_table\n"
         )
         """
@@ -1218,10 +1219,10 @@ def _no_targets_no_thresholds(
         cls.score_code += f"{'':4}if input_array.shape[0] == 1"
         for i in range(len(metrics)):
             cls.score_code += f"{'':8}{metrics[i]} = prediction[{i}]\n"
-        #TODO: What is the use case for this? unsure of how batched scoring would work here
+        metric_list = '"' + '","'.join(metrics) + '"'
         cls.score_code += (
             f"{'':4}else:\n"
-            f"{'':8}output_table = pd.DataFrame(prediction, columns=[{','.join(metrics)}])\n"
+            f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])\n"
             f"{'':8}return output_table\n"
         )
         cls.score_code += f"\n{'':4}return {', '.join(metrics)}"
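Throughout the diff, the new metric_list helper quotes each column name before it is interpolated into the generated source. The previous ','.join(metrics) emitted bare identifiers, which the generated module would treat as undefined variables at score time. For illustration, with hypothetical metric names:

metrics = ["EM_CLASSIFICATION", "EM_EVENTPROBABILITY"]

# Old: bare names in the emitted source -> NameError when the score code runs
f"columns=[{','.join(metrics)}]"
# -> 'columns=[EM_CLASSIFICATION,EM_EVENTPROBABILITY]'

# New: quoted string literals
metric_list = '"' + '","'.join(metrics) + '"'
f"columns=[{metric_list}]"
# -> 'columns=["EM_CLASSIFICATION","EM_EVENTPROBABILITY"]'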
@@ -1378,11 +1379,12 @@ def _binary_target(
                 "score code should output the classification and probability for "
                 "the target event to occur."
             )
+            metric_list = '"' + '","drop","'.join(metrics) + '"'
             cls.score_code += (
                 f"{'':4}if input_array.shape[0] == 1:\n"
                 f"{'':8}return prediction[1][0], float(prediction[1][2])"
                 f"{'':4}else:\n"
-                f"{'':8}output_table = pd.DataFrame(prediction[1:], columns=[{',drop,'.join(metrics)}])\n"
+                f"{'':8}output_table = pd.DataFrame(prediction[1:], columns=[{metric_list}])\n"
                 f"{'':8}return output_table.drop('drop', axis=1)"
             )
             """
@@ -1442,11 +1444,12 @@ def _binary_target(
             """
         # Return classification and probability value
         elif sum(returns) == 1 and len(returns) == 2:
+            metric_list = '"' + '","'.join(metrics) + '"'
             cls.score_code += (
                 f"{'':4}if input_array.shape[0] == 1:\n"
                 f"{'':8}return prediction[0], prediction[1]\n"
                 f"{'':4}else:\n"
-                f"{'':8}return pd.DataFrame(prediction, columns=[{','.join(metrics)}])")
+                f"{'':8}return pd.DataFrame(prediction, columns=[{metric_list}])")
             """
             return prediction[0], prediction[1]
             """
@@ -1459,22 +1462,24 @@ def _binary_target(
             # Determine which return is the classification value
             class_index = [i for i, x in enumerate(returns) if x][0]
             if class_index == 0:
+                metric_list = '"' + '","'.join(metrics) + '","drop"'
                 cls.score_code += (
                     f"{'':4}if input_array.shape[0] == 1:\n"
                     f"{'':8}return prediction[0], prediction[1]\n"
                     f"{'':4}else:\n"
-                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{','.join(metrics)},drop])\n"
+                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])\n"
                     f"{'':8}return output_table.drop('drop', axis=1)")

                 """
                 return prediction[0], prediction[1]
                 """
             else:
+                metric_list = '"' + '","drop","'.join(metrics[::-1]) + '"'
                 cls.score_code += (
                     f"{'':4}if input_array.shape[0] == 1:\n"
                     f"{'':8}return prediction[{class_index}], prediction[0]\n"
                     f"{'':4}else:\n"
-                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{',drop,'.join(metrics[::-1])}])\n"
+                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])\n"
                     f"{'':8}return output_table.drop('drop', axis=1)")
                 """
                 return prediction[2], prediction[0]
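For the branches above, the raw prediction carries one more column than the requested metrics; the surplus column is named 'drop' in the generated DataFrame and removed immediately after. A minimal illustration with hypothetical rows and metric names:

import pandas as pd

# Hypothetical rows: classification, an unwanted probability, the event probability
prediction = [["A", 0.3, 0.7], ["B", 0.8, 0.2]]
output_table = pd.DataFrame(prediction, columns=["EM_CLASSIFICATION", "drop", "EM_EVENTPROBABILITY"])
output_table = output_table.drop("drop", axis=1)  # only the requested metrics remain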
@@ -1483,12 +1488,13 @@ def _binary_target(
                 cls._invalid_predict_config()
         elif len(metrics) == 3:
             if h2o_model:
+                metric_list = '"' + '","'.join(metrics) + '"'
                 cls.score_code += (
                     f"{'':4}if input_array.shape[0] == 1:\n"
                     f"{'':8}return prediction[1][0], float(prediction[1][1]), "
                     f"float(prediction[1][2])\n"
                     f"{'':4}else:\n"
-                    f"{'':8}return pd.DataFrame(prediction[1:], columns=[{','.join(metrics)}])"
+                    f"{'':8}return pd.DataFrame(prediction[1:], columns=[{metric_list}])"
                 )
                 """
                 return prediction[1][0], float(prediction[1][1]), float(prediction[1][2])
@@ -1525,6 +1531,7 @@ def _binary_target(
                     " types, the score code assumes the return order to be: "
                     "[classification, probability of event, probability of no event]."
                 )
+                metric_list = '"' + '","'.join(metrics[1:]) + '"'
                 cls.score_code += (
                     f"{'':4}if input_array.shape[0] == 1:\n"
                     f"{'':8}if prediction[0] > prediction[1]:\n"
@@ -1534,7 +1541,7 @@ def _binary_target(
                     f"{'':8}return {metrics[0]}, prediction[0], prediction[1]\n"
                     f"{'':4}else:\n"
                     f"{'':8}classifications = ['{target_values[0]}' if p[0] > p[1] else '{target_values[1]}' for p in prediction]\n"
-                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{','.join(metrics[1:])}])\n"
+                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])\n"
                     f"{'':8}output_table = output_table.insert(loc = 0, column = '{metrics[0]}', value = classifications)\n"
                     f"{'':8}return output_table"
                 )
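One caution on the insert-based batch branches (unchanged context in this hunk): pandas' DataFrame.insert adds the column in place and returns None, so the generated line output_table = output_table.insert(...) would bind None rather than the table. A sketch of the intended operation, with hypothetical values:

import pandas as pd

prediction = [[0.7, 0.3], [0.2, 0.8]]  # hypothetical [P_event, P_no_event] rows
classifications = ["1", "0"]           # hypothetical classifications

output_table = pd.DataFrame(prediction, columns=["EM_EVENTPROBABILITY", "EM_NONEVENTPROBABILITY"])
# insert() mutates in place and returns None, so its result must not be reassigned
output_table.insert(loc=0, column="EM_CLASSIFICATION", value=classifications)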
@@ -1551,25 +1558,27 @@ def _binary_target(
             # Determine which return is the classification value
             class_index = [i for i, x in enumerate(returns) if x][0]
             if class_index == 0:
+                metric_list = '"' + '","'.join(metrics[:2]) + '"'
                 cls.score_code += (
                     f"{'':4}if input_array.shape[0] == 1:\n"
                     f"{'':8}return prediction[0], prediction[1], 1 - prediction[1]\n"
                     f"{'':4}else:\n"
                     f"{'':8}complement = [1 - p[1] for p in prediction]\n"
-                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{','.join(metrics[:2])}])\n"
+                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])\n"
                     f"{'':8}output_table = output_table.insert(loc = 2, column = '{metrics[2]}', value = complement)\n"
                     f"{'':8}return output_table"
                 )
                 """
                 return prediction[0], prediction[1], 1 - prediction[1]
                 """
             else:
+                metric_list = '"' + '","'.join(metrics[1::-1]) + '"'
                 cls.score_code += (
                     f"{'':4}if input_array.shape[0] == 1:\n"
                     f"{'':8}return prediction[1], prediction[0], 1 - prediction[0]\n"
                     f"{'':4}else:\n"
                     f"{'':8}complement = [1 - p[0] for p in prediction]\n"
-                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{','.join(metrics[1::-1])}])\n"
+                    f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])\n"
                     f"{'':8}output_table = output_table.insert(loc = 2, column = '{metrics[2]}', value = complement)\n"
                     f"{'':8}return output_table"
                 )
@@ -1578,11 +1587,12 @@ def _binary_target(
                 """
         # Return all values from prediction method
         elif sum(returns) == 1 and len(returns) == 3:
+            metric_list = '"' + '","'.join(metrics) + '"'
            cls.score_code += (
                 f"{'':4}if input_array.shape[0] == 1:\n"
                 f"{'':8}return prediction[0], prediction[1], prediction[2]"
                 f"{'':4}else:\n"
-                f"{'':8}output_table = pd.DataFrame(prediction, columns=[{','.join(metrics)}])"
+                f"{'':8}output_table = pd.DataFrame(prediction, columns=[{metric_list}])"
             )
             """
             return prediction[0], prediction[1], prediction[2]
@@ -1628,10 +1638,14 @@ def _nonbinary_targets(
         # For h2o models with only one metric provided, return the classification
         if h2o_model:
             cls.score_code += (
-                f"{'':4}target_values = {target_values}\n{'':4}"
-                f"{metrics} = target_values[prediction[1][1:]."
-                f"index(max(prediction[1][1:]))]\n\n"
-                f"{'':4}return {metrics}"
+                f"{'':4}target_values = {target_values}\n"
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}{metrics} = target_values[prediction[1][1:]."
+                f"index(max(prediction[1][1:]))]\n"
+                f"{'':8}return {metrics}\n"
+                f"{'':4}else:\n"
+                f"{'':8}classifications = [target_values[np.argmax(p[1:])] for p in prediction[1:]]\n"
+                f"{'':8}return pd.DataFrame({{'{metrics}': classifications}})"
             )
             """
             target_values = [1, 2, 3]
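The new batch branch maps each probability row to the target value with the highest probability; np.argmax returns a scalar index, which indexes target_values directly. A worked example with hypothetical h2o-style output, where entry 0 is a header row (skipped via [1:]) and each data row is [classification, p1, p2, p3]:

import numpy as np
import pandas as pd

target_values = [1, 2, 3]
prediction = [["predict", "p1", "p2", "p3"],  # hypothetical header row
              ["2", 0.2, 0.5, 0.3],
              ["1", 0.6, 0.3, 0.1]]

classifications = [target_values[np.argmax(p[1:])] for p in prediction[1:]]
pd.DataFrame({"EM_CLASSIFICATION": classifications})  # rows: 2, then 1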
@@ -1641,16 +1655,25 @@ def _nonbinary_targets(
             """
         # One return that is the classification
         elif len(returns) == 1:
-            cls.score_code += f"{'':4}{metrics} = prediction\n\nreturn {metrics}"
+            cls.score_code += (
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}{metrics} = prediction\n"
+                f"{'':8}return {metrics}\n"
+                f"{'':4}else:\n"
+                f"{'':8}return pd.DataFrame({{'{metrics}': prediction}})"
+            )
             """
             classification_variable = prediction

             return classification_variable
             """
         elif len(returns) == len(target_values):
             cls.score_code += (
-                f"{'':4}target_values = {target_values}\n\n"
-                f"{'':4}return target_values[prediction.index(max(prediction))]"
+                f"{'':4}target_values = {target_values}\n"
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}return target_values[prediction.index(max(prediction))]\n"
+                f"{'':4}else:\n"
+                f"{'':8}classifications = [target_values[np.argmax(p)] for p in prediction]\n"
+                f"{'':8}return pd.DataFrame({{'{metrics}': classifications}})"
             )
             """
             target_values = [1, 2, 3]
@@ -1660,7 +1683,12 @@ def _nonbinary_targets(
         elif len(returns) == (len(target_values) + 1):
             # Determine which return is the classification value
             class_index = [i for i, x in enumerate(returns) if x][0]
-            cls.score_code += f"{'':4}return prediction[{class_index}]"
+            cls.score_code += (
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}return prediction[{class_index}]\n"
+                f"{'':4}else:\n"
+                f"{'':8}return pd.DataFrame({{'{metrics}': [p[{class_index}] for p in prediction]}})"
+            )
             """
             return prediction[1]
             """
@@ -1669,10 +1697,15 @@ def _nonbinary_targets(
     elif len(metrics) == 2:
         if h2o_model:
             cls.score_code += (
-                f"{'':4}target_values = {target_values}\n{'':4}"
-                f"{metrics} = target_values[prediction[1][1:]."
-                f"index(max(prediction[1][1:]))]\n\n"
-                f"{'':4}return {metrics}, max(prediction[1][1:])"
+                f"{'':4}target_values = {target_values}\n"
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}{metrics} = target_values[prediction[1][1:]."
+                f"index(max(prediction[1][1:]))]\n"
+                f"{'':8}return {metrics}, max(prediction[1][1:])\n"
+                f"{'':4}else:\n"
+                f"{'':8}classifications = [target_values[np.argmax(p[1:])] for p in prediction[1:]]\n"
+                f"{'':8}max_proba = [max(p[1:]) for p in prediction[1:]]\n"
+                f"{'':8}return pd.DataFrame({{'{metrics[0]}': classifications, '{metrics[1]}': max_proba}})"
             )
             """
             target_values = [1, 2, 3]
@@ -1682,9 +1715,14 @@ def _nonbinary_targets(
             """
         elif len(returns) == len(target_values):
             cls.score_code += (
-                f"{'':4}target_values = {target_values}\n\n"
-                f"{'':4}return target_values[prediction.index(max(prediction))], "
-                f"max(prediction)"
+                f"{'':4}target_values = {target_values}\n"
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}return target_values[prediction.index(max(prediction))], "
+                f"max(prediction)\n"
+                f"{'':4}else:\n"
+                f"{'':8}classifications = [target_values[np.argmax(p)] for p in prediction]\n"
+                f"{'':8}max_proba = [max(p) for p in prediction]\n"
+                f"{'':8}return pd.DataFrame({{'{metrics[0]}': classifications, '{metrics[1]}': max_proba}})"
             )
             """
             target_values = [1, 2, 3]
@@ -1695,8 +1733,13 @@ def _nonbinary_targets(
             # Determine which return is the classification value
             class_index = [i for i, x in enumerate(returns) if x][0]
             cls.score_code += (
-                f"{'':4}return prediction[{class_index}], "
-                f"max(prediction[:{class_index}] + prediction[{class_index + 1}:])"
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}return prediction[{class_index}], "
+                f"max(prediction[:{class_index}] + prediction[{class_index + 1}:])\n"
+                f"{'':4}else:\n"
+                f"{'':8}max_proba = [max(p[:{class_index}] + p[{class_index + 1}:]) for p in prediction]\n"
+                f"{'':8}classifications = [p[{class_index}] for p in prediction]\n"
+                f"{'':8}return pd.DataFrame({{'{metrics[0]}': classifications, '{metrics[1]}': max_proba}})"
             )
             """
             return prediction[1], max(prediction[:1] + prediction[2:])
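In this branch the classification sits at class_index inside each prediction row, and the batch path takes the largest of the remaining probabilities per row. A small worked example with hypothetical rows and class_index == 1:

prediction = [[0.3, "B", 0.7], [0.8, "A", 0.2]]  # hypothetical [proba, classification, proba] rows
class_index = 1

classifications = [p[class_index] for p in prediction]                        # ["B", "A"]
max_proba = [max(p[:class_index] + p[class_index + 1:]) for p in prediction]  # [0.7, 0.8]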
@@ -1707,13 +1750,26 @@ def _nonbinary_targets(
         if h2o_model:
             if len(metrics) == len(target_values):
                 h2o_returns = [f"prediction[1][{i+1}]" for i in range(len(metrics))]
-                cls.score_code += f"{'':4}return {', '.join(h2o_returns)}"
+                metric_list = '"' + '","'.join(metrics) + '"'
+                cls.score_code += (
+                    f"{'':4}if input_array.shape[0] == 1:\n"
+                    f"{'':8}return {', '.join(h2o_returns)}\n"
+                    f"{'':4}else:\n"
+                    f"{'':8}output_table = pd.DataFrame(prediction[1:], columns=[{metric_list}])\n"
+                    f"{'':8}output_table = output_table.drop('{metrics[0]}', axis=1)\n"
+                    f"{'':8}return output_table"
+                )
                 """
                 return prediction[1][1], prediction[1][2], prediction[1][3]
                 """
             elif len(metrics) == (len(target_values) + 1):
                 h2o_returns = [f"prediction[1][{i}]" for i in range(len(metrics))]
-                cls.score_code += f"{'':4}return {', '.join(h2o_returns)}"
+                metric_list = '"' + '","'.join(metrics) + '"'
+                cls.score_code += (
+                    f"{'':4}if input_array.shape[0] == 1:\n"
+                    f"{'':8}return {', '.join(h2o_returns)}\n"
+                    f"{'':4}else:\n"
+                    f"{'':8}return pd.DataFrame(prediction[1:], columns=[{metric_list}])"
+                )
                 """
                 return prediction[1][0], prediction[1][1], prediction[1][2]
                 """
@@ -1724,7 +1780,13 @@ def _nonbinary_targets(
             and sum(returns) == 1
         ):
             proba_returns = [f"prediction[{i}]" for i in range(len(returns))]
-            cls.score_code += f"{'':4}return {', '.join(proba_returns)}"
+            metric_list = '"' + '","'.join(metrics) + '"'
+            cls.score_code += (
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}return {', '.join(proba_returns)}\n"
+                f"{'':4}else:\n"
+                f"{'':8}return pd.DataFrame(prediction, columns=[{metric_list}])"
+            )
             """
             return prediction[0], prediction[1], prediction[2]
             """
@@ -1734,8 +1796,13 @@ def _nonbinary_targets(
             proba_returns = [f"prediction[{i}]" for i in range(len(returns))]
             cls.score_code += (
                 f"{'':4}target_values = {target_values}\n\n"
-                f"{'':4}return target_values[prediction.index(max(prediction))], "
-                f"{', '.join(proba_returns)}"
+                f"{'':4}if input_array.shape[0] == 1:\n"
+                f"{'':8}return target_values[prediction.index(max(prediction))], "
+                f"{', '.join(proba_returns)}\n"
+                f"{'':4}else:\n"
+                f"{'':8}classifications = [target_values[np.argmax(p)] for p in prediction]\n"
+                f"{'':8}output_table = pd.DataFrame(prediction, columns={metrics[1:]})\n"
+                f"{'':8}output_table.insert(loc=0, column='{metrics[0]}', value=classifications)\n"
+                f"{'':8}return output_table"
             )
             """
             target_values = [1, 2, 3]
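Taken together, each branch now emits the same contract from the generated module. A hypothetical call pattern (the names score_fn and inputs are illustrative, not from this diff):

single_row = score_fn(inputs.head(1))  # tuple of scalars, e.g. a classification and a probability
batch = score_fn(inputs)               # pandas DataFrame with one row per input record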
