diff --git a/pytorchtools.py b/pytorchtools.py index 9da369d..3e1e432 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -1,9 +1,11 @@ import numpy as np import torch + class EarlyStopping: """Early stops the training if validation loss doesn't improve after a given patience.""" - def __init__(self, patience=7, verbose=False, delta=0): + + def __init__(self, patience=7, verbose=False, delta=0, mode='min'): """ Args: patience (int): How long to wait after last time validation loss improved. @@ -12,35 +14,64 @@ def __init__(self, patience=7, verbose=False, delta=0): Default: False delta (float): Minimum change in the monitored quantity to qualify as an improvement. Default: 0 + mode (str): Procedure for determining the best score. """ + self.patience = patience self.verbose = verbose self.counter = 0 - self.best_score = None self.early_stop = False - self.val_loss_min = np.Inf - self.delta = delta + self.mode = mode + + if self.mode == 'min': + self.criterion = np.less + self.delta = - delta + self.best_score = np.Inf + + self.vocab = {'score': 'loss', 'comportement': 'decreased'} - def __call__(self, val_loss, model): + elif self.mode == 'max': + self.criterion = np.greater + self.delta = delta + self.best_score = np.NINF - score = -val_loss + self.vocab = {'score': 'metric', 'comportement': 'increased'} - if self.best_score is None: + else: + raise ValueError( + "mode only takes as value in input 'min' or 'max'") + + def __call__(self, score, model): + """Determines if the score is the best and saves the model if so. + Also manages early stopping. + + Arguments: + score (float): Value of the metric or loss. + model: Pytorch model + """ + if np.isinf(self.best_score): self.best_score = score - self.save_checkpoint(val_loss, model) - elif score < self.best_score - delta: + self.save_checkpoint(score, model) + + elif self.criterion(score, self.best_score + self.delta): + + self.best_score = score + self.save_checkpoint(score, model) + self.counter = 0 + else: self.counter += 1 - print(f'EarlyStopping counter: {self.counter} out of {self.patience}') + print( + f'EarlyStopping counter: {self.counter} out of {self.patience}' + ) if self.counter >= self.patience: self.early_stop = True - else: - self.best_score = score - self.save_checkpoint(val_loss, model) - self.counter = 0 - def save_checkpoint(self, val_loss, model): - '''Saves model when validation loss decrease.''' + def save_checkpoint(self, score, model): + '''Saves the model when the score satisfies the criterion.''' if self.verbose: - print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') + score_name = self.vocab['score'] + comportement = self.vocab['comportement'] + print( + f'Validation {score_name} {comportement} ({self.best_score:.6f} --> {score:.6f}). Saving model ...' + ) torch.save(model.state_dict(), 'checkpoint.pt') - self.val_loss_min = val_loss