Skip to content

Commit cdb3a52

Browse files
authored
Merge pull request #188 from sassoftware/model_cards
Model cards
2 parents 28a3e22 + fa499fc commit cdb3a52

File tree

4 files changed

+854
-40
lines changed

4 files changed

+854
-40
lines changed

examples/pzmm_binary_classification_model_import.ipynb

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@
740740
],
741741
"source": [
742742
"import getpass\n",
743-
"def write_model_stats(x_train, y_train, test_predict, test_proba, y_test, model, path):\n",
743+
"def write_model_stats(x_train, y_train, test_predict, test_proba, y_test, model, path, prefix):\n",
744744
" # Calculate train predictions\n",
745745
" train_predict = model.predict(x_train)\n",
746746
" train_proba = model.predict_proba(x_train)\n",
@@ -757,6 +757,20 @@
757757
" test_data=test_data, \n",
758758
" json_path=path\n",
759759
" )\n",
760+
"\n",
761+
" full_training_data = pd.concat([y_train.reset_index(drop=True), x_train.reset_index(drop=True)], axis=1)\n",
762+
"\n",
763+
" pzmm.JSONFiles.generate_model_card(\n",
764+
" model_prefix=prefix,\n",
765+
" model_files = path,\n",
766+
" algorithm = str(type(model).__name__),\n",
767+
" train_data = full_training_data,\n",
768+
" train_predictions=train_predict,\n",
769+
" target_type='classification',\n",
770+
" target_value=1,\n",
771+
" interval_vars=predictor_columns,\n",
772+
" selection_statistic='_RASE_',\n",
773+
" )\n",
760774
" \n",
761775
"username = getpass.getpass()\n",
762776
"password = getpass.getpass()\n",
@@ -766,8 +780,8 @@
766780
"\n",
767781
"test_predict = [y_dtc_predict, y_rfc_predict, y_gbc_predict]\n",
768782
"test_proba = [y_dtc_proba, y_rfc_proba, y_gbc_proba]\n",
769-
"for (mod, pred, proba, path) in zip(model, test_predict, test_proba, zip_folder):\n",
770-
" write_model_stats(x_train, y_train, pred, proba, y_test, mod, path)"
783+
"for (mod, pred, proba, path, prefix) in zip(model, test_predict, test_proba, zip_folder, model_prefix):\n",
784+
" write_model_stats(x_train, y_train, pred, proba, y_test, mod, path, prefix)"
771785
]
772786
},
773787
{
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
{
2+
"creationTimeStamp" : "0001-01-01T00:00:00Z",
3+
"modifiedTimeStamp" : "0001-01-01T00:00:00Z",
4+
"revision" : 0,
5+
"name" : "dmcas_relativeimportance",
6+
"version" : 0,
7+
"order" : 0,
8+
"parameterMap" : {
9+
"LABEL" : {
10+
"label" : "Variable Label",
11+
"length" : 256,
12+
"order" : 1,
13+
"parameter" : "LABEL",
14+
"preformatted" : false,
15+
"type" : "char",
16+
"values" : [ "LABEL" ]
17+
},
18+
"LEVEL" : {
19+
"label" : "Variable Level",
20+
"length" : 10,
21+
"order" : 5,
22+
"parameter" : "LEVEL",
23+
"preformatted" : false,
24+
"type" : "char",
25+
"values" : [ "LEVEL" ]
26+
},
27+
"ROLE" : {
28+
"label" : "Role",
29+
"length" : 32,
30+
"order" : 4,
31+
"parameter" : "ROLE",
32+
"preformatted" : false,
33+
"type" : "char",
34+
"values" : [ "ROLE" ]
35+
},
36+
"RelativeImportance" : {
37+
"label" : "Relative Importance",
38+
"length" : 8,
39+
"order" : 3,
40+
"parameter" : "RelativeImportance",
41+
"preformatted" : false,
42+
"type" : "num",
43+
"values" : [ "RelativeImportance" ]
44+
},
45+
"Variable" : {
46+
"label" : "Variable Name",
47+
"length" : 255,
48+
"order" : 2,
49+
"parameter" : "Variable",
50+
"preformatted" : false,
51+
"type" : "char",
52+
"values" : [ "Variable" ]
53+
}
54+
},
55+
"data" : [],
56+
"xInteger" : false,
57+
"yInteger" : false
58+
}

0 commit comments

Comments
 (0)