From db1591bdf3074c0bb5b0290695d5ff82bb149146 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Tue, 13 Aug 2019 23:35:42 +0200 Subject: [PATCH 1/2] Also supports early stopping for metrics --- pytorchtools.py | 67 ++++++++++++++++++++++++++++++++++++------------- 1 file changed, 49 insertions(+), 18 deletions(-) diff --git a/pytorchtools.py b/pytorchtools.py index 9da369d..74cd4aa 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -1,9 +1,11 @@ import numpy as np import torch + class EarlyStopping: """Early stops the training if validation loss doesn't improve after a given patience.""" - def __init__(self, patience=7, verbose=False, delta=0): + + def __init__(self, patience=7, verbose=False, delta=0, mode='min'): """ Args: patience (int): How long to wait after last time validation loss improved. @@ -12,35 +14,64 @@ def __init__(self, patience=7, verbose=False, delta=0): Default: False delta (float): Minimum change in the monitored quantity to qualify as an improvement. Default: 0 + mode (str): Procedure for determining the best score. """ + self.patience = patience self.verbose = verbose self.counter = 0 - self.best_score = None self.early_stop = False - self.val_loss_min = np.Inf - self.delta = delta + self.mode = mode + + if self.mode == 'min': + self.criterion = np.less + self.delta *= delta + self.best_score = np.Inf + + self.vocab = {'score': 'loss', 'comportement': 'decreased'} - def __call__(self, val_loss, model): + elif self.mode == 'max': + self.criterion = np.greater + self.delta = delta + self.best_score = np.NINF - score = -val_loss + self.vocab = {'score': 'metric', 'comportement': 'increased'} - if self.best_score is None: + else: + raise ValueError( + "mode only takes as value in input 'min' or 'max'") + + def __call__(self, score, model): + """Determines if the score is the best and saves the model if so. + Also manages early stopping. + + Arguments: + score (float): Value of the metric or loss. + model: Pytorch model + """ + if np.isinf(self.best_score): self.best_score = score - self.save_checkpoint(val_loss, model) - elif score < self.best_score - delta: + self.save_checkpoint(score, model) + + elif self.criterion(score, self.best_score + delta): + + self.best_score = score + self.save_checkpoint(score, model) + self.counter = 0 + else: self.counter += 1 - print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + print( + f'EarlyStopping counter: {self.counter} out of {self.patience}' + ) if self.counter >= self.patience: self.early_stop = True - else: - self.best_score = score - self.save_checkpoint(val_loss, model) - self.counter = 0 - def save_checkpoint(self, val_loss, model): - '''Saves model when validation loss decrease.''' + def save_checkpoint(self, score, model): + '''Saves the model when the score satisfies the criterion.''' if self.verbose: - print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') + score_name = self.vocab['score'] + comportement = self.vocab['comportement'] + print( + f'Validation {score_name} {comportement} ({self.best_score:.6f} --> {score:.6f}). Saving model ...' + ) torch.save(model.state_dict(), 'checkpoint.pt') - self.val_loss_min = val_loss From 63d59203ffe29fe111296c705c4eb0958922eaf7 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Wed, 14 Aug 2019 13:44:58 +0200 Subject: [PATCH 2/2] Fix --- pytorchtools.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorchtools.py b/pytorchtools.py index 74cd4aa..3e1e432 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -25,7 +25,7 @@ def __init__(self, patience=7, verbose=False, delta=0, mode='min'): if self.mode == 'min': self.criterion = np.less - self.delta *= delta + self.delta = - delta self.best_score = np.Inf self.vocab = {'score': 'loss', 'comportement': 'decreased'} @@ -53,7 +53,7 @@ def __call__(self, score, model): self.best_score = score self.save_checkpoint(score, model) - elif self.criterion(score, self.best_score + delta): + elif self.criterion(score, self.best_score + self.delta): self.best_score = score self.save_checkpoint(score, model)