Skip to content

Commit c148812

Browse files
authored
Merge pull request #86 from lvapeab/master
Option for metric minimization objective in EarlyStop callback
2 parents 82d1fc2 + f50d85c commit c148812

File tree

2 files changed

+11
-3
lines changed

2 files changed

+11
-3
lines changed

keras_wrapper/cnn_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,7 @@ def __train(self, ds, params, state=dict()):
795795
callback_early_stop = EarlyStopping(self,
796796
patience=params['patience'],
797797
metric_check=params['metric_check'],
798+
want_to_minimize=True if 'TER' in params['metric_check'] else False,
798799
eval_on_epochs=params['eval_on_epochs'],
799800
each_n_epochs=params['each_n_epochs'],
800801
start_eval_on_epoch=params['start_eval_on_epoch'])
@@ -878,6 +879,7 @@ def __train_from_samples(self, x, y, params, class_weight=None, sample_weight=No
878879
callback_early_stop = EarlyStopping(self,
879880
patience=params['patience'],
880881
metric_check=params['metric_check'],
882+
want_to_minimize=True if 'TER' in params['metric_check'] else False,
881883
eval_on_epochs=params['eval_on_epochs'],
882884
each_n_epochs=params['each_n_epochs'],
883885
start_eval_on_epoch=params['start_eval_on_epoch'])

keras_wrapper/extra/callbacks.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -530,6 +530,7 @@ def __init__(self,
530530
patience=0,
531531
check_split='val',
532532
metric_check='acc',
533+
want_to_minimize=False,
533534
eval_on_epochs=True,
534535
each_n_epochs=1,
535536
start_eval_on_epoch=0,
@@ -549,6 +550,7 @@ def __init__(self,
549550
self.eval_on_epochs = eval_on_epochs
550551
self.start_eval_on_epoch = start_eval_on_epoch
551552
self.each_n_epochs = each_n_epochs
553+
self.want_to_minimize = want_to_minimize
552554

553555
self.verbose = verbose
554556
self.cum_update = 0
@@ -588,14 +590,17 @@ def evaluate(self, epoch, counter_name='epoch'):
588590
warnings.warn('The chosen metric' + str(self.metric_check) + ' does not exist;'
589591
' this reducer works only with a valid metric.')
590592
return
591-
593+
if self.want_to_minimize:
594+
current_score = -current_score
592595
# Check if the best score has been outperformed in the current epoch
593596
if current_score > self.best_score:
594597
self.best_epoch = epoch
595598
self.best_score = current_score
596599
self.wait = 0
597600
if self.verbose > 0:
598-
logging.info('---current best %s %s: %.3f' % (self.check_split, self.metric_check, current_score))
601+
logging.info('---current best %s %s: %.3f' % (self.check_split, self.metric_check,
602+
current_score if not self.want_to_minimize
603+
else -current_score))
599604

600605
# Stop training if performance has not improved for self.patience epochs
601606
elif self.patience > 0:
@@ -604,7 +609,8 @@ def evaluate(self, epoch, counter_name='epoch'):
604609
if self.wait >= self.patience:
605610
if self.verbose > 0:
606611
logging.info("---%s %d: early stopping. Best %s found at %s %d: %f" % (
607-
str(counter_name), epoch, self.metric_check, str(counter_name), self.best_epoch, self.best_score))
612+
str(counter_name), epoch, self.metric_check, str(counter_name), self.best_epoch,
613+
self.best_score if not self.want_to_minimize else -self.best_score))
608614
self.model.stop_training = True
609615
exit(1)
610616

0 commit comments

Comments
 (0)