Skip to content

Commit 2676f7f

Browse files
committed
Bugfix topK accuracy
1 parent 3cff46d commit 2676f7f

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

keras_wrapper/extra/evaluation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
249249
accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred)
250250
acc_top_n = {}
251251
for topn in top_n_accuracies:
252-
acc_top_n[topn] = __top_k_accuracy(y_gt, y_pred, topn)
252+
acc_top_n[topn] = __top_k_accuracy(y_gt, pred_list, topn)
253253
# accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights, )
254254

255255
# The following two lines should both provide the same measure (balanced accuracy)
@@ -287,7 +287,7 @@ def __top_k_accuracy(truths, preds, k):
287287
:param k:
288288
:return:
289289
"""
290-
best_k = np.argsort(preds, axis=1)[:, -k:]
290+
best_k = np.argsort(preds, axis=1)[:, -k:][:, ::-1]
291291
ts = np.argmax(truths, axis=1)
292292
successes = 0
293293
for i in range(ts.shape[0]):

0 commit comments

Comments
 (0)