Skip to content

Commit 45767d6

Browse files
committed
add benchmark filtering for problem type
1 parent 3fdc255 commit 45767d6

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

docs/benchmarks/ebm-benchmark.ipynb

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -912,7 +912,7 @@
912912
"print(f'Results (pre-filtered) count: {results_df.shape[0]}')\n",
913913
"\n",
914914
"# Optionally filter out results we want to replace\n",
915-
"#results_df = results_df[~(results_df['method'] == 'ebm')]\n",
915+
"#results_df = results_df[results_df['method'] != 'ebm']\n",
916916
"#results_df = results_df[~((results_df['method'] == 'ebm') & (results_df['meta'] == '{}'))]\n",
917917
"print(f'Results (post-filtered) count: {results_df.shape[0]}')"
918918
]
@@ -949,18 +949,19 @@
949949
"metadata": {},
950950
"outputs": [],
951951
"source": [
952-
"# Optionally filter out any incomplete datasets\n",
953-
"#results_df = results_df[~(results_df['task'] == 'Devnagari-Script')]\n",
954-
"print(f'Final count: {results_df.shape[0]}')\n",
955-
"\n",
956-
"\n",
957952
"types_df = results_df[results_df['name'].isin(['auc', 'ovo_auc', 'nrmse'])]\n",
958953
"task_to_type = types_df.groupby('task')['name'].first().map({'auc': 'binary', 'ovo_auc': 'multiclass', 'nrmse': 'regression'})\n",
959954
"results_df['type'] = results_df['task'].map(task_to_type).fillna('')\n",
960955
"\n",
961956
"flip = ['r2', 'auc', 'precision', 'recall', 'accuracy', 'bal_acc', 'ovo_auc', 'ovr_auc', 'mprecision', 'mrecall', 'maccuracy', 'mbal_acc']\n",
962957
"condition = results_df['name'].isin(flip)\n",
963-
"results_df.loc[condition, 'num_val'] = -results_df.loc[condition, 'num_val']"
958+
"results_df.loc[condition, 'num_val'] = -results_df.loc[condition, 'num_val']\n",
959+
"\n",
960+
"\n",
961+
"# Optionally filter out any incomplete datasets\n",
962+
"#results_df = results_df[results_df['task'] != 'Devnagari-Script']\n",
963+
"#results_df = results_df[results_df['type'] == 'regression']\n",
964+
"print(f'Final count: {results_df.shape[0]}')"
964965
]
965966
},
966967
{

0 commit comments

Comments
 (0)