diff --git a/pytorchtools.py b/pytorchtools.py index 9644e4b..42a4ec1 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -26,6 +26,7 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra self.delta = delta self.path = path self.trace_func = trace_func + def __call__(self, val_loss, model): score = -val_loss @@ -47,5 +48,10 @@ def save_checkpoint(self, val_loss, model): '''Saves model when validation loss decrease.''' if self.verbose: self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...') - torch.save(model.state_dict(), self.path) + + if isinstance(model, torch.nn.DataParallel): + torch.save(model.module.state_dict(), self.path) + else: + torch.save(model.state_dict(), self.path) + self.val_loss_min = val_loss