Skip to content

Commit 7b14935

Browse files
committed
Update pzmm_binary_classification_model_import notebook to include model card generation
1 parent fc85adc commit 7b14935

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

examples/pzmm_binary_classification_model_import.ipynb

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -740,7 +740,7 @@
740740
],
741741
"source": [
742742
"import getpass\n",
743-
"def write_model_stats(x_train, y_train, test_predict, test_proba, y_test, model, path):\n",
743+
"def write_model_stats(x_train, y_train, test_predict, test_proba, y_test, model, path, prefix):\n",
744744
" # Calculate train predictions\n",
745745
" train_predict = model.predict(x_train)\n",
746746
" train_proba = model.predict_proba(x_train)\n",
@@ -757,6 +757,20 @@
757757
" test_data=test_data, \n",
758758
" json_path=path\n",
759759
" )\n",
760+
"\n",
761+
" full_training_data = pd.concat([y_train.reset_index(drop=True), x_train.reset_index(drop=True)], axis=1)\n",
762+
"\n",
763+
" pzmm.JSONFiles.generate_model_card(\n",
764+
" model_prefix=prefix,\n",
765+
" model_files = path,\n",
766+
" algorithm = str(type(model).__name__),\n",
767+
" train_data = full_training_data,\n",
768+
" train_predictions=train_predict,\n",
769+
" target_type='classification',\n",
770+
" target_value=1,\n",
771+
" interval_vars=predictor_columns,\n",
772+
" selection_statistic='_RASE_',\n",
773+
" )\n",
760774
" \n",
761775
"username = getpass.getpass()\n",
762776
"password = getpass.getpass()\n",
@@ -766,8 +780,8 @@
766780
"\n",
767781
"test_predict = [y_dtc_predict, y_rfc_predict, y_gbc_predict]\n",
768782
"test_proba = [y_dtc_proba, y_rfc_proba, y_gbc_proba]\n",
769-
"for (mod, pred, proba, path) in zip(model, test_predict, test_proba, zip_folder):\n",
770-
" write_model_stats(x_train, y_train, pred, proba, y_test, mod, path)"
783+
"for (mod, pred, proba, path, prefix) in zip(model, test_predict, test_proba, zip_folder, model_prefix):\n",
784+
" write_model_stats(x_train, y_train, pred, proba, y_test, mod, path, prefix)"
771785
]
772786
},
773787
{

0 commit comments

Comments
 (0)