@@ -530,6 +530,7 @@ def __init__(self,
530
530
patience = 0 ,
531
531
check_split = 'val' ,
532
532
metric_check = 'acc' ,
533
+ want_to_minimize = False ,
533
534
eval_on_epochs = True ,
534
535
each_n_epochs = 1 ,
535
536
start_eval_on_epoch = 0 ,
@@ -549,6 +550,7 @@ def __init__(self,
549
550
self .eval_on_epochs = eval_on_epochs
550
551
self .start_eval_on_epoch = start_eval_on_epoch
551
552
self .each_n_epochs = each_n_epochs
553
+ self .want_to_minimize = want_to_minimize
552
554
553
555
self .verbose = verbose
554
556
self .cum_update = 0
@@ -588,14 +590,17 @@ def evaluate(self, epoch, counter_name='epoch'):
588
590
warnings .warn ('The chosen metric' + str (self .metric_check ) + ' does not exist;'
589
591
' this reducer works only with a valid metric.' )
590
592
return
591
-
593
+ if self .want_to_minimize :
594
+ current_score = - current_score
592
595
# Check if the best score has been outperformed in the current epoch
593
596
if current_score > self .best_score :
594
597
self .best_epoch = epoch
595
598
self .best_score = current_score
596
599
self .wait = 0
597
600
if self .verbose > 0 :
598
- logging .info ('---current best %s %s: %.3f' % (self .check_split , self .metric_check , current_score ))
601
+ logging .info ('---current best %s %s: %.3f' % (self .check_split , self .metric_check ,
602
+ current_score if not self .want_to_minimize
603
+ else - current_score ))
599
604
600
605
# Stop training if performance has not improved for self.patience epochs
601
606
elif self .patience > 0 :
@@ -604,7 +609,8 @@ def evaluate(self, epoch, counter_name='epoch'):
604
609
if self .wait >= self .patience :
605
610
if self .verbose > 0 :
606
611
logging .info ("---%s %d: early stopping. Best %s found at %s %d: %f" % (
607
- str (counter_name ), epoch , self .metric_check , str (counter_name ), self .best_epoch , self .best_score ))
612
+ str (counter_name ), epoch , self .metric_check , str (counter_name ), self .best_epoch ,
613
+ self .best_score if not self .want_to_minimize else - self .best_score ))
608
614
self .model .stop_training = True
609
615
exit (1 )
610
616
0 commit comments