@@ -242,24 +242,53 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
     # Compute accuracy
     accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred)
-    accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
+    accuracy_3 = __top_k_accuracy(y_gt, y_pred, 3)
+    accuracy_5 = __top_k_accuracy(y_gt, y_pred, 5)
+    # accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
+
+    # The following two lines should both provide the same measure (balanced accuracy)
+    # _, accuracy_balanced, _, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='macro')
+    accuracy_balanced = sklearn_metrics.balanced_accuracy_score(y_gt, y_pred)
+
     # Compute Precision, Recall and F1 score
     precision, recall, f1, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='micro')

     if verbose > 0:
         logging.info('Accuracy: %f' % accuracy)
+        logging.info('Accuracy top-3: %f' % accuracy_3)
+        logging.info('Accuracy top-5: %f' % accuracy_5)
         logging.info('Balanced Accuracy: %f' % accuracy_balanced)
         logging.info('Precision: %f' % precision)
         logging.info('Recall: %f' % recall)
         logging.info('F1 score: %f' % f1)

     return {'accuracy': accuracy,
+            'accuracy_top_3': accuracy_3,
+            'accuracy_top_5': accuracy_5,
             'accuracy_balanced': accuracy_balanced,
             'precision': precision,
             'recall': recall,
             'f1': f1}


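# Editor's sketch (illustrative, not part of this commit): balanced accuracy is
# defined as the macro-average of per-class recall, and balanced_accuracy_score
# requires scikit-learn >= 0.20. The equivalence claimed in the comment above
# can be sanity-checked on made-up toy labels, reusing the module's existing
# np and sklearn_metrics imports:
#
#     y_gt, y_pred = np.array([0, 0, 0, 1, 1, 2]), np.array([0, 0, 1, 1, 0, 2])
#     bal = sklearn_metrics.balanced_accuracy_score(y_gt, y_pred)
#     _, rec, _, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='macro')
#     assert np.isclose(bal, rec)  # both equal (2/3 + 1/2 + 1/1) / 3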
+def __top_k_accuracy(truths, preds, k):
+    """
+    Top-k accuracy. Both truths and preds have the same shape (m, n), where m
+    is the number of predictions and n is the number of classes.
+
+    :param truths: one-hot ground-truth matrix of shape (m, n)
+    :param preds: prediction-score matrix of shape (m, n)
+    :param k: number of highest-scoring classes to consider per sample
+    :return: fraction of samples whose true class is among their top k predictions
+    """
+    # Indices of the k highest-scoring classes for each sample
+    best_k = np.argsort(preds, axis=1)[:, -k:]
+    # True class index for each sample
+    ts = np.argmax(truths, axis=1)
+    successes = 0
+    for i in range(ts.shape[0]):
+        if ts[i] in best_k[i, :]:
+            successes += 1
+    return float(successes) / ts.shape[0]
+
+
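# Editor's sketch (illustrative, not part of this commit): the loop above can
# be vectorized with plain numpy, and scikit-learn >= 0.24 ships an equivalent
# built-in; truths and preds are the same arrays as in __top_k_accuracy:
#
#     best_k = np.argsort(preds, axis=1)[:, -k:]
#     acc = np.mean(np.any(best_k == np.argmax(truths, axis=1)[:, None], axis=1))
#
#     from sklearn.metrics import top_k_accuracy_score
#     acc = top_k_accuracy_score(np.argmax(truths, axis=1), preds, k=k)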
 def semantic_segmentation_accuracy(pred_list, verbose, extra_vars, split):
     """
     Semantic Segmentation Accuracy metric