
Commit 1262201

Merge pull request #210 from sassoftware/assessment_fix
2 parents 9d3bcab + 294ed96 commit 1262201

File tree

3 files changed: +19 -23 lines


examples/pzmm_binary_classification_model_import.ipynb

Lines changed: 2 additions & 3 deletions
@@ -717,7 +717,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 17,
+    "execution_count": null,
     "metadata": {
      "Collapsed": "false"
     },
@@ -751,8 +751,7 @@
    " \n",
    " # Calculate the model statistics, ROC chart, and Lift chart; then write to json files\n",
    " pzmm.JSONFiles.calculate_model_statistics(\n",
-    "     target_value=1, \n",
-    "     prob_value=0.5, \n",
+    "     target_value=1,\n",
    "     train_data=train_data, \n",
    "     test_data=test_data, \n",
    "     json_path=path\n",

examples/pzmm_generate_complete_model_card.ipynb

Lines changed: 2 additions & 3 deletions
@@ -874,8 +874,7 @@
    "source": [
     "# Step 10: Write model statistics files\n",
     "pzmm.JSONFiles.calculate_model_statistics(\n",
-    "    target_value=1, \n",
-    "    prob_value=0.5, \n",
+    "    target_value=1,\n",
     "    train_data=train_scored[[target, ti, t1]], \n",
     "    test_data=test_scored[[target, ti, t1]],\n",
     "    validate_data=test_scored[[target, ti, t1]],\n",
@@ -1786,7 +1785,7 @@
    ],
    "metadata": {
     "kernelspec": {
-     "display_name": "Python 3",
+     "display_name": ".venv",
     "language": "python",
     "name": "python3"
    },

src/sasctl/pzmm/write_json_files.py

Lines changed: 15 additions & 17 deletions
@@ -165,7 +165,7 @@ def write_var_json(

     @staticmethod
     def generate_variable_properties(
-        input_data: Union[DataFrame, Series]
+        input_data: Union[DataFrame, Series],
     ) -> List[dict]:
         """
         Generate a list of dictionaries of variable properties given an input dataframe.
@@ -1192,7 +1192,6 @@ def bias_dataframes_to_json(
     def calculate_model_statistics(
         cls,
         target_value: Union[str, int, float],
-        prob_value: Union[int, float, None] = None,
         validate_data: Union[DataFrame, List[list], Type["numpy.ndarray"]] = None,
         train_data: Union[DataFrame, List[list], Type["numpy.ndarray"]] = None,
         test_data: Union[DataFrame, List[list], Type["numpy.ndarray"]] = None,
@@ -1211,8 +1210,7 @@ def calculate_model_statistics(
         Datasets must contain the actual and predicted values and may optionally contain
         the predicted probabilities. If no probabilities are provided, a dummy
         probability dataset is generated based on the predicted values and normalized by
-        the target value. If a probability threshold value is not provided, the
-        threshold value is set at 0.5.
+        the target value.

         Datasets can be provided in the following forms, with the assumption that data
         is ordered as `actual`, `predict`, and `probability` respectively:
@@ -1229,9 +1227,6 @@ def calculate_model_statistics(
         ----------
         target_value : str, int, or float
             Target event value for model prediction events.
-        prob_value : int or float, optional
-            The threshold value for model predictions to indicate an event occurred. The
-            default value is 0.5.
         validate_data : pandas.DataFrame, list of list, or numpy.ndarray, optional
             Dataset pertaining to the validation data. The default value is None.
         train_data : pandas.DataFrame, list of list, or numpy.ndarray, optional
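
To illustrate the docstring above, a hypothetical input dataset in the expected actual/predict/probability order could look like the following; the column names are illustrative only, since data may also be passed as a list of lists or a numpy array:

    import pandas as pd

    # Hypothetical dataset ordered as actual, predict, probability (column names illustrative).
    validate_data = pd.DataFrame(
        {
            "actual": [1, 0, 1, 0],                     # observed target values
            "predict": [1, 0, 0, 0],                    # predicted target values
            "predict_proba": [0.83, 0.12, 0.46, 0.05],  # optional predicted probabilities
        }
    )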
@@ -1284,30 +1279,33 @@ def calculate_model_statistics(
                 continue

             data = cls.stat_dataset_to_dataframe(data, target_value, target_type)
+            data["predict_proba2"] = 1 - data["predict_proba"]

             conn.upload(
                 data,
-                casout={"name": "assess_dataset", "replace": True, "caslib": "Public"},
+                casout={"caslib": "Public", "name": "assess_dataset", "replace": True},
             )
+
             if target_type == "classification":
                 conn.percentile.assess(
                     table={"name": "assess_dataset", "caslib": "Public"},
-                    response="predict",
-                    pVar="predict_proba",
-                    event=str(target_value),
-                    pEvent=str(prob_value) if prob_value else str(0.5),
-                    inputs="actual",
+                    inputs="predict_proba",
+                    response="actual",
+                    event="1",
+                    pvar="predict_proba2",
+                    pevent="0",
+                    includeLift=True,
                     fitStatOut={"name": "FitStat", "replace": True, "caslib": "Public"},
                     rocOut={"name": "ROC", "replace": True, "caslib": "Public"},
                     casout={"name": "Lift", "replace": True, "caslib": "Public"},
                 )
             else:
                 conn.percentile.assess(
                     table={"name": "assess_dataset", "caslib": "Public"},
-                    response="predict",
-                    inputs="actual",
-                    fitStatOut={"name": "FitStat", "replace": True, "caslib": "Public"},
-                    casout={"name": "Lift", "replace": True, "caslib": "Public"},
+                    response="actual",
+                    inputs="predict",
+                    fitStatOut={"caslib": "Public", "name": "FitStat", "replace": True},
+                    casout={"caslib": "Public", "name": "Lift", "replace": True},
                 )

             fitstat_dict = (
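
A minimal sketch of the reworked classification assessment, assuming an active swat CAS session conn and a DataFrame data already shaped by stat_dataset_to_dataframe (that is, with actual, predict, and predict_proba columns), mirroring the diff above:

    # Sketch only: conn is an active swat.CAS session; data has "actual", "predict",
    # and "predict_proba" columns as produced by stat_dataset_to_dataframe.
    data["predict_proba2"] = 1 - data["predict_proba"]  # complement probability for the non-event level

    conn.upload(
        data,
        casout={"caslib": "Public", "name": "assess_dataset", "replace": True},
    )

    conn.percentile.assess(
        table={"name": "assess_dataset", "caslib": "Public"},
        inputs="predict_proba",   # assess the event probability directly
        response="actual",        # observed target values as the response
        event="1",
        pvar="predict_proba2",
        pevent="0",
        includeLift=True,
        fitStatOut={"name": "FitStat", "replace": True, "caslib": "Public"},
        rocOut={"name": "ROC", "replace": True, "caslib": "Public"},
        casout={"name": "Lift", "replace": True, "caslib": "Public"},
    )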
