Skip to content

Commit f93864c

Browse files
committed
Minor bug fixes
1 parent 893f99b commit f93864c

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

keras_wrapper/extra/callbacks.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def __init__(self,
8989
each_n_epochs=1,
9090
max_eval_samples=None,
9191
extra_vars=None,
92-
normalize=True,
92+
normalize=False,
9393
normalization_type=None,
9494
output_types=None,
9595
is_text=False,
@@ -261,7 +261,7 @@ def __init__(self,
261261

262262
else:
263263
# Convert min_pred_multilabel to list
264-
if isinstance(self.min_pred_multilabel, list):
264+
if not isinstance(self.min_pred_multilabel, list):
265265
self.min_pred_multilabel = [self.min_pred_multilabel for _ in self.gt_pos]
266266

267267
super(EvalPerformance, self).__init__()

keras_wrapper/extra/evaluation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
253253
accuracy = sklearn_metrics.accuracy_score(y_gt, y_pred)
254254
accuracy_balanced = sklearn_metrics.accuracy_score(y_gt, y_pred, sample_weight=sample_weights)
255255
# Compute Precision, Recall and F1 score
256-
avrg = extra_vars.get('average_mode', None)
256+
avrg = extra_vars.get('average_mode', 'macro')
257257
precision, recall, f1, _ = sklearn_metrics.precision_recall_fscore_support(y_gt, y_pred, average=avrg)
258258
# Compute Confusion Matrix
259259
cf = sklearn_metrics.confusion_matrix(np.argmax(y_gt, -1), np.argmax(y_pred, -1))
@@ -271,7 +271,7 @@ def multiclass_metrics(pred_list, verbose, extra_vars, split):
271271
# Compute top 5 fp classes
272272
top5_fps = np.argpartition(cf * neg_identity, -5)[:, -5:][:, ::-1]
273273
# Compute top 5 accuracy
274-
arg_top5_pred = np.argpartition(y_pred, -5)[:, -5:]
274+
arg_top5_pred = np.argpartition(pred_list, -5)[:, -5:]
275275
arg_gt = np.argmax(y_gt, -1)
276276
top5_acc = np.mean(np.max(arg_top5_pred == np.repeat(np.expand_dims(arg_gt, -1), 5, -1), -1))
277277

0 commit comments

Comments
 (0)