"""PyTorch early-stopping helper.

Reconstructed from patch 2e54511 ("feat: add general checkpoint",
Jeffrey Ng, 2020-05-31), which extends EarlyStopping so it can save a
"general checkpoint" (model + optimizer state + epoch + loss) suitable
for resuming training, in addition to the original weights-only file.
"""
import numpy as np
import torch


class EarlyStopping:
    """Early stops the training if validation loss doesn't improve after a given patience."""

    def __init__(self, patience=7, verbose=False, delta=0, general_checkpoint=False):
        """
        Args:
            patience (int): How long to wait after last time validation loss improved.
                            Default: 7
            verbose (bool): If True, prints a message for each validation loss improvement.
                            Default: False
            delta (float): Minimum change in the monitored quantity to qualify as an improvement.
                            Default: 0
            general_checkpoint (bool): Saves additional information that can be used to resume
                            training (epoch, optimizer state, loss).
                            Default: False
        """
        self.patience = patience
        self.verbose = verbose
        self.counter = 0          # consecutive epochs without improvement
        self.best_score = None    # best (negated) validation loss seen so far
        self.early_stop = False   # set True once patience is exhausted
        # np.inf rather than the np.Inf alias (removed in NumPy 2.0); any
        # first validation loss counts as an improvement over infinity.
        self.val_loss_min = np.inf
        self.delta = delta
        self.general_checkpoint = general_checkpoint

    def __call__(self, val_loss, model, epoch=None, optimizer=None):
        """Record one validation result; checkpoint on improvement, otherwise
        count towards early stopping.

        Args:
            val_loss (float): current validation loss (lower is better).
            model (torch.nn.Module): model to checkpoint when the loss improves.
            epoch (int, optional): current epoch, stored in the general checkpoint.
            optimizer (torch.optim.Optimizer, optional): optimizer whose state is
                stored in the general checkpoint.
        """
        score = -val_loss  # negate so that "higher score" means "better"

        if self.best_score is None:
            # First call: always treat as an improvement and save.
            # BUGFIX vs. the original patch: forward epoch/optimizer here too,
            # so the very first checkpoint also carries resume information
            # when general_checkpoint is requested.
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch=epoch, optimizer=optimizer)
        elif score < self.best_score + self.delta:
            # Not enough improvement: one more strike towards stopping.
            self.counter += 1
            print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            # Improvement: save a checkpoint and reset the patience counter.
            self.best_score = score
            self.save_checkpoint(val_loss, model, epoch=epoch, optimizer=optimizer)
            self.counter = 0

    def save_checkpoint(self, val_loss, model, epoch=None, optimizer=None):
        '''Saves model when validation loss decrease.'''
        if self.verbose:
            # Repaired: this f-string was split across two lines by mail wrapping
            # in the patch text.
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).  Saving model ...')
        if self.general_checkpoint and epoch is not None and optimizer is not None:
            # General checkpoint: everything needed to resume training later.
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': val_loss
            }, 'checkpoint.tar')
        else:
            # Weights-only checkpoint (original behaviour).
            torch.save(model.state_dict(), 'checkpoint.pt')
        self.val_loss_min = val_loss