@@ -253,7 +253,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
253
253
accuracy = sklearn_metrics .accuracy_score (y_gt , y_pred )
254
254
accuracy_balanced = sklearn_metrics .accuracy_score (y_gt , y_pred , sample_weight = sample_weights )
255
255
# Compute Precision, Recall and F1 score
256
- avrg = extra_vars .get ('average_mode' , None )
256
+ avrg = extra_vars .get ('average_mode' , 'macro' )
257
257
precision , recall , f1 , _ = sklearn_metrics .precision_recall_fscore_support (y_gt , y_pred , average = avrg )
258
258
# Compute Confusion Matrix
259
259
cf = sklearn_metrics .confusion_matrix (np .argmax (y_gt , - 1 ), np .argmax (y_pred , - 1 ))
@@ -271,7 +271,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
271
271
# Compute top 5 fp classes
272
272
top5_fps = np .argpartition (cf * neg_identity , - 5 )[:, - 5 :][:, ::- 1 ]
273
273
# Compute top 5 accuracy
274
- arg_top5_pred = np .argpartition (y_pred , - 5 )[:, - 5 :]
274
+ arg_top5_pred = np .argpartition (pred_list , - 5 )[:, - 5 :]
275
275
arg_gt = np .argmax (y_gt , - 1 )
276
276
top5_acc = np .mean (np .max (arg_top5_pred == np .repeat (np .expand_dims (arg_gt , - 1 ), 5 , - 1 ), - 1 ))
277
277
0 commit comments