Skip to content

fix: add check for NaN validation loss in EarlyStopping #28

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pytorchtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra
self.trace_func = trace_func

def __call__(self, val_loss, model):
# Check if validation loss is nan
if np.isnan(val_loss):
self.trace_func("Validation loss is NaN. Ignoring this epoch.")
return

if self.best_val_loss is None:
self.best_val_loss = val_loss
Expand Down
27 changes: 27 additions & 0 deletions tests/test_early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,31 @@ def test_delta_functionality(mock_model, temp_checkpoint_path):
# Assert no new checkpoints were saved after early stopping was triggered
assert mock_save_checkpoint.call_count == 2, "No additional checkpoints should be saved after early stopping was triggered"

def test_validation_loss_nan(mock_model, temp_checkpoint_path):
    """
    Test that EarlyStopping ignores epochs where the validation loss is NaN.

    A NaN loss must not save a checkpoint, must not increment the patience
    counter, and must not trigger early stopping.

    The loss sequence deliberately contains three consecutive NaNs with
    patience=3: if NaN epochs were (incorrectly) counted as non-improving
    epochs, the patience counter would reach its limit and early stopping
    would fire.  A sequence with a single NaN (e.g. [1.0, 0.95, nan, 0.9])
    cannot distinguish the fixed behavior from the buggy one, because the
    save count (3) and final counter (0) come out identical either way.
    """
    # Patch the save_checkpoint method used inside EarlyStopping so we can
    # count checkpoint saves without touching the filesystem.
    with patch.object(EarlyStopping, 'save_checkpoint') as mock_save_checkpoint:
        early_stopping = EarlyStopping(patience=3, verbose=False, path=temp_checkpoint_path)

        # Three consecutive NaNs: enough to exhaust patience=3 if NaN were
        # treated as an ordinary "no improvement" epoch.
        nan = float('nan')
        losses = [1.0, 0.95, nan, nan, nan, 0.9]
        for loss in losses:
            early_stopping(loss, mock_model)

        # Checkpoints: initial (1.0) plus improvements (0.95, 0.9) only;
        # the NaN epochs must be skipped entirely.
        assert mock_save_checkpoint.call_count == 3, (
            "Checkpoints should be saved on the initial loss and each "
            "improvement, ignoring NaN epochs")

        # The three NaN epochs must not have advanced the patience counter,
        # so early stopping must not have been triggered.
        assert not early_stopping.early_stop, (
            "Early stop should not be triggered when NaN epochs are ignored")

        # The final improvement (0.9) leaves the counter at 0, and the NaN
        # epochs never incremented it in the first place.
        assert early_stopping.counter == 0, (
            "Counter should remain 0 since NaN losses were ignored and "
            "0.9 is an improvement")