Commit a4a41f4

Fixed balanced accuracy and added top-3 and top-5 accuracy calculation.

1 parent 9c5a28e

File tree

1 file changed: +30 -1 lines changed

keras_wrapper/extra/evaluation.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -242,24 +242,53 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
242242

243243
# Compute accuracy
244244
accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred)
245-
accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
245+
accuracy_3 = __top_k_accuracy(y_gt, y_pred, 3)
246+
accuracy_5 = __top_k_accuracy(y_gt, y_pred, 5)
247+
#accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights, )
248+
249+
# The following two lines should both provide the same measure (balanced accuracy)
250+
#_, accuracy_balanced, _, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='macro')
251+
accuracy_balanced = sklearn_metrics.balanced_accuracy_score(y_gt, y_pred)
252+
246253
# Compute Precision, Recall and F1 score
247254
precision, recall, f1, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average='micro')
248255

249256
if verbose > 0:
250257
logging.info('Accuracy: %f' % accuracy)
258+
logging.info('Accuracy top-3: %f' % accuracy_3)
259+
logging.info('Accuracy top-5: %f' % accuracy_5)
251260
logging.info('Balanced Accuracy: %f' % accuracy_balanced)
252261
logging.info('Precision: %f' % precision)
253262
logging.info('Recall: %f' % recall)
254263
logging.info('F1 score: %f' % f1)
255264

256265
return {'accuracy': accuracy,
266+
'accuracy_top_3': accuracy_3,
267+
'accuracy_top_5': accuracy_5,
257268
'accuracy_balanced': accuracy_balanced,
258269
'precision': precision,
259270
'recall': recall,
260271
'f1': f1}
261272

262273

274+
def __top_k_accuracy(truths, preds, k):
275+
"""
276+
Both preds and truths are same shape m by n (m is number of predictions and n is number of classes)
277+
278+
:param preds:
279+
:param truths:
280+
:param k:
281+
:return:
282+
"""
283+
best_k = np.argsort(preds, axis=1)[:, -k:]
284+
ts = np.argmax(truths, axis=1)
285+
successes = 0
286+
for i in range(ts.shape[0]):
287+
if ts[i] in best_k[i,:]:
288+
successes += 1
289+
return float(successes)/ts.shape[0]
290+
291+
263292
def semantic_segmentation_accuracy(pred_list, verbose, extra_vars, split):
264293
"""
265294
Semantic Segmentation Accuracy metric
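
For clarity, here is a small, self-contained sketch of what the new helper computes. The toy arrays below are made up for illustration, and the vectorized lines mirror the loop inside `__top_k_accuracy`:

```python
import numpy as np

# Toy batch: 4 samples, 5 classes. Rows of y_gt are one-hot ground truth;
# rows of y_pred are class scores (only their ranking matters).
y_gt = np.array([[0, 0, 1, 0, 0],
                 [1, 0, 0, 0, 0],
                 [0, 0, 0, 0, 1],
                 [0, 1, 0, 0, 0]])
y_pred = np.array([[0.1, 0.2, 0.5, 0.1, 0.1],   # true class 2 ranked 1st  -> top-3 hit
                   [0.2, 0.3, 0.1, 0.3, 0.1],   # true class 0 ranked 3rd  -> top-3 hit
                   [0.3, 0.3, 0.2, 0.1, 0.1],   # true class 4 ranked last -> top-3 miss
                   [0.1, 0.4, 0.2, 0.2, 0.1]])  # true class 1 ranked 1st  -> top-3 hit

# Vectorized equivalent of __top_k_accuracy(y_gt, y_pred, 3):
best_3 = np.argsort(y_pred, axis=1)[:, -3:]   # indices of the 3 highest scores per row
hits = np.any(best_3 == np.argmax(y_gt, axis=1)[:, None], axis=1)
print(hits.mean())  # 0.75
```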

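On the balanced-accuracy fix itself: the commented-out `precision_recall_fscore_support` line in the diff records that balanced accuracy is the macro-average of per-class recall. A quick sketch with made-up 1-D label vectors (assuming scikit-learn >= 0.20, where `balanced_accuracy_score` was introduced) showing that the two calls agree:

```python
import numpy as np
from sklearn import metrics as sklearn_metrics

# Made-up labels for an imbalanced 3-class problem.
y_true = np.array([0, 0, 0, 0, 1, 1, 2, 2, 2, 2])
y_hat  = np.array([0, 0, 1, 0, 1, 2, 2, 2, 0, 2])

# balanced_accuracy_score is the macro-average of per-class recall, i.e. the
# second value returned by precision_recall_fscore_support(average='macro').
bal_acc = sklearn_metrics.balanced_accuracy_score(y_true, y_hat)
_, macro_recall, _, _ = sklearn_metrics.precision_recall_fscore_support(y_true, y_hat, average='macro')
print(bal_acc, macro_recall)  # both ~0.6667 (per-class recalls 0.75, 0.5, 0.75 averaged)
```

Unlike plain accuracy, this measure weights every class equally, which is why the commit swaps out the earlier sample-weighted `accuracy_score` call.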