From efe20abb446a0fcafbdc3ed57c2a6d7c08fbe6e6 Mon Sep 17 00:00:00 2001 From: Bjarten Date: Tue, 15 Oct 2024 12:53:19 +0900 Subject: [PATCH 1/2] fix: add check for nan in early stopping --- pytorchtools.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pytorchtools.py b/pytorchtools.py index 40a59ac..1ee059e 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -32,6 +32,10 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra self.trace_func = trace_func def __call__(self, val_loss, model): + # Skip NaN losses: report through trace_func and return before best_val_loss is read or updated + if np.isnan(val_loss): + self.trace_func("Validation loss is NaN. Ignoring this epoch.") + return if self.best_val_loss is None: self.best_val_loss = val_loss From f4f1b802375f8537ba5ff4923391901be318ba97 Mon Sep 17 00:00:00 2001 From: Bjarten Date: Tue, 15 Oct 2024 12:53:48 +0900 Subject: [PATCH 2/2] ci: update tests for new early stopping check --- tests/test_early_stopping.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py index cb6fe5e..e072bf4 100644 --- a/tests/test_early_stopping.py +++ b/tests/test_early_stopping.py @@ -330,4 +330,31 @@ def test_delta_functionality(mock_model, temp_checkpoint_path): # Assert no new checkpoints were saved after early stopping was triggered assert mock_save_checkpoint.call_count == 2, "No additional checkpoints should be saved after early stopping was triggered" +def test_validation_loss_nan(mock_model, temp_checkpoint_path): + """ + Test that EarlyStopping ignores epochs where validation loss is NaN. + + Feeds the losses 1.0, 0.95, NaN and 0.9 to an EarlyStopping(patience=3) whose save_checkpoint is + patched, then expects three checkpoint calls, early_stop unset, and counter == 0 at the end. 
+ """ + # Patch EarlyStopping.save_checkpoint so its calls can be counted below + with patch.object(EarlyStopping, 'save_checkpoint') as mock_save_checkpoint: + # Stopper under test: patience=3, verbose off, checkpoint path taken from the temp_checkpoint_path argument + early_stopping = EarlyStopping(patience=3, verbose=False, path=temp_checkpoint_path) + + # One call per epoch; the third loss is the NaN under test + losses = [1.0, 0.95, float('nan'), 0.9] + for loss in losses: + early_stopping(loss, mock_model) + + # Assert that save_checkpoint was called three times: + # - Initial call (loss=1.0) + # - Improvement (loss=0.95) + # - Improvement (loss=0.9) + assert mock_save_checkpoint.call_count == 3, "Checkpoints should be saved on initial and each significant improvement, ignoring NaN" + + # Assert that early stopping is not triggered + assert not early_stopping.early_stop, "Early stop should not be triggered when validations improve" + # Counter must be 0 after the final epoch. NOTE(review): if the 0.9 improvement also resets the counter, this check alone cannot show that the NaN epoch left it untouched; confirm against __call__ + assert early_stopping.counter == 0, "Counter should remain 0 since NaN loss was ignored" \ No newline at end of file