From f2837dd08b93f76d472b46b69bef4ab8b11cfc0d Mon Sep 17 00:00:00 2001 From: Bjarten Date: Mon, 14 Oct 2024 17:06:01 +0900 Subject: [PATCH 1/7] feat: use best val loss instead of best score --- pytorchtools.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/pytorchtools.py b/pytorchtools.py index 4a8db9c..770b1d4 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -3,6 +3,7 @@ import numpy as np import torch + class EarlyStopping: """Early stops the training if validation loss doesn't improve after a given patience.""" def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): @@ -10,7 +11,7 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra Args: patience (int): How long to wait after last time validation loss improved. Default: 7 - verbose (bool): If True, prints a message for each validation loss improvement. + verbose (bool): If True, prints a message for each validation loss improvement. Default: False delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0 @@ -22,26 +23,25 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra self.patience = patience self.verbose = verbose self.counter = 0 - self.best_score = None + self.best_val_loss = None self.early_stop = False - self.val_loss_min = np.Inf + self.val_loss_min = np.inf self.delta = delta self.path = path self.trace_func = trace_func - def __call__(self, val_loss, model): - score = -val_loss + def __call__(self, val_loss, model): - if self.best_score is None: - self.best_score = score + if self.best_val_loss is None: + self.best_val_loss = val_loss self.save_checkpoint(val_loss, model) - elif score < self.best_score + self.delta: + elif val_loss > self.best_val_loss - self.delta: self.counter += 1 self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') if self.counter >= self.patience: self.early_stop = True else: - self.best_score = score + self.best_val_loss = val_loss self.save_checkpoint(val_loss, model) self.counter = 0 From 2e6e56a4fb79eaec9af3e7fbe9f21f787803c8a6 Mon Sep 17 00:00:00 2001 From: Bjarten Date: Mon, 14 Oct 2024 18:22:03 +0900 Subject: [PATCH 2/7] ci: add unit tests --- tests/__init__.py | 0 tests/test_early_stopping.py | 333 +++++++++++++++++++++++++++++++++++ 2 files changed, 333 insertions(+) create mode 100644 tests/__init__.py create mode 100644 tests/test_early_stopping.py diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py new file mode 100644 index 0000000..eb75d3d --- /dev/null +++ b/tests/test_early_stopping.py @@ -0,0 +1,333 @@ +# test_early_stopping.py + +import pytest +from unittest.mock import Mock, patch +import torch +import numpy as np + +from pytorchtools import EarlyStopping + +# Fixtures to mock model and temporary checkpoint path + +@pytest.fixture +def mock_model(): + """ + Fixture to create a mock PyTorch model. 
+ + Returns: + Mock: A mocked PyTorch model with a predefined state_dict. + """ + model = Mock(spec=torch.nn.Module) + model.state_dict.return_value = {'param': 'value'} + return model + +@pytest.fixture +def temp_checkpoint_path(tmp_path): + """ + Fixture to create a temporary checkpoint file path. + + Args: + tmp_path (Path): Pytest's built-in fixture providing a temporary directory. + + Returns: + str: String representation of the temporary checkpoint file path. + """ + checkpoint_file = tmp_path / "checkpoint.pt" + return str(checkpoint_file) + +# Tests + +def test_initialization(): + """ + Test the initialization of the EarlyStopping class to ensure all attributes are set correctly. + + This test verifies that upon instantiation, the EarlyStopping object has its parameters and internal + counters initialized as expected. + """ + # Initialize EarlyStopping with specific parameters + early_stopping = EarlyStopping(patience=5, verbose=True, delta=0.01) + + # Assert that all attributes are correctly initialized + assert early_stopping.patience == 5, "Patience should be set to 5" + assert early_stopping.verbose is True, "Verbose should be True" + assert early_stopping.delta == 0.01, "Delta should be set to 0.01" + assert early_stopping.counter == 0, "Counter should be initialized to 0" + assert early_stopping.best_val_loss is None, "Best score should be None initially" + assert early_stopping.early_stop is False, "Early stop flag should be False initially" + assert early_stopping.val_loss_min == np.inf, "Initial val_loss_min should be infinity" + +def test_initial_call_saves_checkpoint(mock_model, temp_checkpoint_path): + """ + Test that the initial call to EarlyStopping saves the model checkpoint and updates best_val_loss. + + This test ensures that when EarlyStopping is called for the first time with a validation loss, + it correctly saves the model's state_dict and updates the best_val_loss accordingly. + + Args: + mock_model (Mock): A mocked PyTorch model. 
+ temp_checkpoint_path (str): Temporary file path for saving the checkpoint. + """ + # Patch the torch.save method used inside EarlyStopping to prevent actual file I/O + with patch('pytorchtools.torch.save') as mock_save: + # Initialize EarlyStopping with specified parameters + early_stopping = EarlyStopping(patience=5, verbose=False, path=temp_checkpoint_path) + + # Simulate the first validation loss + initial_val_loss = 1.0 + early_stopping(initial_val_loss, mock_model) + + # Assert that model.state_dict() was called once + mock_model.state_dict.assert_called_once_with() + # Assert that torch.save was called once with correct arguments + mock_save.assert_called_once_with(mock_model.state_dict(), temp_checkpoint_path) + # Assert that best_val_loss was set correctly + assert early_stopping.best_val_loss == initial_val_loss, "Best score should be set to negative initial_val_loss" + # Assert that early_stop is not triggered + assert not early_stopping.early_stop, "Early stop should not be triggered on initial call" + +def test_validation_loss_improves(mock_model, temp_checkpoint_path): + """ + Test that validation loss improvements trigger checkpoint saving and reset the patience counter. + + This test simulates a sequence of validation losses that consistently improve. It verifies that + EarlyStopping saves the model checkpoint each time an improvement is detected and that the + patience counter remains at zero, indicating no need to stop early. + + Args: + mock_model (Mock): A mocked PyTorch model. + temp_checkpoint_path (str): Temporary file path for saving the checkpoint. 
+ """ + # Patch the torch.save method used inside EarlyStopping + with patch('pytorchtools.torch.save') as mock_save: + # Initialize EarlyStopping with specified parameters + early_stopping = EarlyStopping(patience=5, verbose=False, path=temp_checkpoint_path) + + # Simulate a series of improving validation losses + losses = [1.0, 0.9, 0.8, 0.85, 0.75] + for loss in losses: + early_stopping(loss, mock_model) + + # Assert that torch.save was called once for the initial loss and three times for improvements + assert mock_save.call_count == 4, "Checkpoints should be saved on initial and 3 improvements" + # Assert that the patience counter remains at 0 after each improvement + assert early_stopping.counter == 0, "Counter should be reset to 0 after improvements" + # Assert that early stopping was not triggered + assert not early_stopping.early_stop, "Early stop should not be triggered when losses improve" + +def test_validation_loss_no_improvement_within_delta(mock_model, temp_checkpoint_path): + """ + Test that the patience counter increments when validation loss does not improve beyond delta. + + This test simulates a scenario where validation losses do not improve sufficiently (i.e., the + improvement is less than the specified delta). It verifies that the patience counter increments + correctly and that early stopping is triggered once the patience threshold is exceeded. + + Args: + mock_model (Mock): A mocked PyTorch model. + temp_checkpoint_path (str): Temporary file path for saving the checkpoint. 
+ """ + # Patch the save_checkpoint method to monitor its calls directly + with patch.object(EarlyStopping, 'save_checkpoint') as mock_save_checkpoint: + # Initialize EarlyStopping with specified parameters + early_stopping = EarlyStopping( + patience=3, # Number of epochs to wait for improvement + verbose=False, # Disable verbose output + delta=0.01, # Minimum improvement to qualify as an improvement + path=temp_checkpoint_path # Path to save the checkpoint + ) + + # First sequence of validation losses with sufficient improvements + initial_losses = [1.0, 0.98, 0.97, 0.97, 0.95, 0.95] + for loss in initial_losses: + early_stopping(loss, mock_model) + + # After processing initial_losses: + # - Checkpoints should be saved on initial call and on each improvement larger than delta (Epochs 1, 2, 5) + # - Total save_checkpoint calls: 3 + # - Patience counter should have incremented to 1 (from Epoch 6) + # - Early stopping should not have been triggered yet + + # Assert that save_checkpoint was called three times: initial call and two improvements + assert mock_save_checkpoint.call_count == 3, "Checkpoints should be saved on initial and three improvements" + + # Assert that the patience counter has incremented to 1 (only Epoch 6 incremented it) + assert early_stopping.counter == 1, "Counter should be incremented to 1" + + # Assert that early stopping has not been triggered yet + assert early_stopping.early_stop is False, "Early stop should not be triggered yet" + + # Second sequence of validation losses that do not improve beyond delta + worsening_losses = [1.0, 1.1] + for loss in worsening_losses: + early_stopping(loss, mock_model) + + # After processing worsening_losses: + # - No checkpoints should be saved since losses are worsening + # - Patience counter should increment by 2 (from 1 to 3) + # - Early stopping should be triggered as patience is exceeded + + # Assert that save_checkpoint was still called three times (no new saves) + assert mock_save_checkpoint.call_count == 3, "No
additional checkpoints should be saved as losses worsen" + + # Assert that the patience counter has incremented to 3 + assert early_stopping.counter == 3, "Counter should be incremented to 3" + + # Assert that early stopping was triggered after patience was exceeded + assert early_stopping.early_stop is True, "Early stop should be triggered after patience is exceeded" + + +def test_early_stopping_triggered(mock_model, temp_checkpoint_path): + """ + Test that early stopping is triggered when the patience is exceeded without sufficient improvement. + + This test verifies that when the validation loss does not improve for a number of consecutive epochs + equal to the patience parameter, EarlyStopping sets the early_stop flag to True, indicating that + training should cease. + + Args: + mock_model (Mock): A mocked PyTorch model. + temp_checkpoint_path (str): Temporary file path for saving the checkpoint. + """ + # Patch the save_checkpoint method used inside EarlyStopping to monitor its calls directly + with patch.object(EarlyStopping, 'save_checkpoint') as mock_save_checkpoint: + # Initialize EarlyStopping with specified parameters + early_stopping = EarlyStopping( + patience=2, # Number of epochs to wait for improvement + verbose=True, # Enable verbose output to observe messages + path=temp_checkpoint_path # Path to save the checkpoint + ) + + # Simulate validation losses with no improvement after two patience steps + losses = [1.0, 0.9, 0.85, 0.85, 0.85] + for loss in losses: + early_stopping(loss, mock_model) + + # After processing losses: + # - Checkpoints should be saved on initial call and on each improvement (Epochs 1, 2, 3) + # - Total save_checkpoint calls: 3 + # - Patience counter should have incremented to 2 (from Epochs 4 and 5) + # - Early stopping should be triggered + + # Assert that save_checkpoint was called three times: initial call and two improvements + assert mock_save_checkpoint.call_count == 3, "Checkpoints should be saved on initial and two 
improvements" + + # Assert that the patience counter has incremented to 2 (from Epochs 4 and 5) + assert early_stopping.counter == 2, "Counter should be incremented to 2" + + # Assert that early stopping was triggered after patience was exceeded + assert early_stopping.early_stop is True, "Early stop should be triggered after patience is exceeded" + +def test_verbose_output(mock_model, temp_checkpoint_path, capsys): + """ + Test that verbose outputs are printed correctly when verbose is enabled. + + This test checks that when the verbose parameter is set to True, EarlyStopping prints + informative messages about validation loss improvements and patience counter increments. + + Args: + mock_model (Mock): A mocked PyTorch model. + temp_checkpoint_path (str): Temporary file path for saving the checkpoint. + capsys: Pytest fixture to capture output to stdout and stderr. + """ + # Patch the torch.save method used inside EarlyStopping to prevent actual file I/O + with patch('pytorchtools.torch.save'): + # Initialize EarlyStopping with verbose enabled + early_stopping = EarlyStopping(patience=2, verbose=True, path=temp_checkpoint_path) + + # Simulate validation losses with and without improvement + losses = [1.0, 0.95, 0.95, 0.95] + for loss in losses: + early_stopping(loss, mock_model) + + # Capture the output printed by the trace_func + captured = capsys.readouterr() + # Check that verbose messages for validation loss decrease are printed + assert 'Validation loss decreased' in captured.out, "Should print validation loss decrease message" + # Check that verbose messages for counter increments are printed + assert 'EarlyStopping counter: 1 out of 2' in captured.out, "Should print first counter increment" + assert 'EarlyStopping counter: 2 out of 2' in captured.out, "Should print second counter increment" + + +def test_no_early_stop_when_validation_improves_within_patience(mock_model, temp_checkpoint_path): + """ + Test that early stopping is not triggered when validation 
loss continues to improve within patience. + + This test ensures that as long as validation loss keeps improving (even intermittently), the + patience counter does not reach the threshold, and early stopping is not triggered. + + Args: + mock_model (Mock): A mocked PyTorch model. + temp_checkpoint_path (str): Temporary file path for saving the checkpoint. + """ + # Patch the torch.save method used inside EarlyStopping + with patch('pytorchtools.torch.save') as mock_save: + # Initialize EarlyStopping with specified parameters + early_stopping = EarlyStopping(patience=3, verbose=False, path=temp_checkpoint_path) + + # Simulate validation losses with intermittent improvements + losses = [1.0, 0.95, 0.90, 0.85, 0.80, 0.85, 0.80, 0.75] + for loss in losses: + early_stopping(loss, mock_model) + + # Assert that torch.save was called on initial and five improvements + assert mock_save.call_count == 6, "Checkpoints should be saved on initial and five improvements" + # Assert that the patience counter is reset to 0 after each improvement + assert early_stopping.counter == 0, "Counter should be reset to 0 after each improvement" + # Assert that early stopping was not triggered + assert not early_stopping.early_stop, "Early stop should not be triggered when validations continue to improve within patience" + +def test_delta_functionality(mock_model, temp_checkpoint_path): + """ + Test that the delta parameter correctly influences the checkpoint saving behavior. + + This test verifies that when the improvement in validation loss exceeds the delta threshold, + EarlyStopping saves the model checkpoint and resets the patience counter. However, once early stopping is + triggered, the early_stop flag remains True even after a significant improvement. + + Args: + mock_model (Mock): A mocked PyTorch model. + temp_checkpoint_path (str): Temporary file path for saving the checkpoint.
+ """ + # Patch the save_checkpoint method of EarlyStopping to monitor its calls directly + with patch.object(EarlyStopping, 'save_checkpoint') as mock_save_checkpoint: + # Ensure that the mocked model's state_dict returns a valid state + mock_model.state_dict.return_value = {'layer1.weight': torch.tensor([1, 2, 3])} + + # Initialize EarlyStopping with delta set to 0.1 + early_stopping = EarlyStopping( + patience=3, # Number of epochs to wait for improvement + verbose=False, # Disable verbose output + delta=0.1, # Minimum improvement to qualify as an improvement + path=temp_checkpoint_path # Path to save the checkpoint + ) + + # Simulate validation losses with varying improvements + losses = [1.0, 0.95, 0.91, 0.9, 0.79] + for loss in losses: + early_stopping(loss, mock_model) + + # Expected behavior (improvement is measured against the best loss so far, 1.0): + # Initial loss: 1.0 (save) + # 0.95: improvement of 0.05 (less than delta=0.1) -> no save, counter=1 + # 0.91: improvement of 0.09 (less than delta=0.1) -> no save, counter=2 + # 0.9: improvement of 0.10 (not larger than delta=0.1) -> no save, counter=3, early_stop set + # 0.79: improvement of 0.21 (larger than delta=0.1) -> save, counter reset to 0, but early_stop still True + + # Assert that save_checkpoint was called on initial and final significant improvement + assert mock_save_checkpoint.call_count == 2, "Checkpoints should be saved on initial and final significant improvements (delta=0.1)" + + # Assert that the patience counter was reset after the final improvement + assert early_stopping.counter == 0, "Counter should be reset after significant improvements" + + # Simulate further validation losses that do not improve beyond delta + worsening_losses = [0.8, 1.1] + for loss in worsening_losses: + early_stopping(loss, mock_model) + + # Since early stop was already triggered earlier, it should remain True + assert early_stopping.early_stop, "Early stop should remain True after it was triggered" + + # Assert no new checkpoints were saved after early stopping was triggered + assert
mock_save_checkpoint.call_count == 2, "No additional checkpoints should be saved after early stopping was triggered" + + From eb947c61281a57d530d049f1b656dd81354ab096 Mon Sep 17 00:00:00 2001 From: Bjarten Date: Tue, 15 Oct 2024 11:16:37 +0900 Subject: [PATCH 3/7] fix: handle best_val_loss correctly --- pytorchtools.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/pytorchtools.py b/pytorchtools.py index 770b1d4..39e6d2f 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -35,18 +35,20 @@ def __call__(self, val_loss, model): if self.best_val_loss is None: self.best_val_loss = val_loss self.save_checkpoint(val_loss, model) - elif val_loss > self.best_val_loss - self.delta: + elif val_loss < self.best_val_loss - self.delta: + # Significant improvement detected + self.best_val_loss = val_loss + self.save_checkpoint(val_loss, model) + self.counter = 0 # Reset counter since improvement occurred + else: + # No significant improvement self.counter += 1 self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}') if self.counter >= self.patience: self.early_stop = True - else: - self.best_val_loss = val_loss - self.save_checkpoint(val_loss, model) - self.counter = 0 def save_checkpoint(self, val_loss, model): - '''Saves model when validation loss decrease.''' + '''Saves model when validation loss decreases.''' if self.verbose: self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}).
Saving model ...') torch.save(model.state_dict(), self.path) From 8ff0acde48c4d24764ce1ba3b3eaf523408e16cb Mon Sep 17 00:00:00 2001 From: Bjarten Date: Tue, 15 Oct 2024 11:18:29 +0900 Subject: [PATCH 4/7] ci: add tests workflow --- .github/workflows/release.yml | 2 +- .github/workflows/tests.yml | 41 +++++++++++++++++++++++++++++++++++ requirements-dev.txt | 6 +++++ requirements.txt | 1 + tests/test_early_stopping.py | 4 ++-- 5 files changed, 51 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/tests.yml create mode 100644 requirements-dev.txt diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 9a9c45d..738adb3 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -36,7 +36,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - pip install python-semantic-release==9.11.0 # Locking to a specific version + pip install python-semantic-release - name: Configure Git run: | diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 0000000..cb4ad4a --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,41 @@ +name: Python Application Tests + +on: + push: + branches: [ main ] + pull_request: + branches: [ main ] + +jobs: + test: + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ['3.9', '3.10', '3.11', '3.12'] + + steps: + # Step 1: Checkout the repository + - name: Checkout repository + uses: actions/checkout@v4 + + # Step 2: Set up Python + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: 'pip' + + # Step 3: Install dependencies + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements.txt + + # Step 4: Install pytest (if not in requirements.txt) + - name: Install pytest + run: pip install pytest + + # Step 5: Run tests + - name: Run tests + run: pytest diff --git 
a/requirements-dev.txt b/requirements-dev.txt new file mode 100644 index 0000000..0ea841c --- /dev/null +++ b/requirements-dev.txt @@ -0,0 +1,6 @@ +# requirements-dev.txt +-r requirements.txt +pytest +pytest-mock +coverage +flake8 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 54da486..1848a4e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,3 +1,4 @@ +# requirements.txt matplotlib numpy torchvision diff --git a/tests/test_early_stopping.py b/tests/test_early_stopping.py index eb75d3d..cb6fe5e 100644 --- a/tests/test_early_stopping.py +++ b/tests/test_early_stopping.py @@ -220,10 +220,10 @@ def test_early_stopping_triggered(mock_model, temp_checkpoint_path): def test_verbose_output(mock_model, temp_checkpoint_path, capsys): """ Test that verbose outputs are printed correctly when verbose is enabled. - + This test checks that when the verbose parameter is set to True, EarlyStopping prints informative messages about validation loss improvements and patience counter increments. - + Args: mock_model (Mock): A mocked PyTorch model. temp_checkpoint_path (str): Temporary file path for saving the checkpoint. 
From 2703ca5819753ae790f44ea61262b85d02e82ae5 Mon Sep 17 00:00:00 2001 From: Bjarten Date: Tue, 15 Oct 2024 11:25:55 +0900 Subject: [PATCH 5/7] chore: update requirements-dev.txt --- requirements-dev.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 0ea841c..147b29d 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,6 +1,4 @@ # requirements-dev.txt -r requirements.txt pytest -pytest-mock -coverage -flake8 \ No newline at end of file +pytest-mock \ No newline at end of file From 4be4f6e8606cc905d1a4a6a8d99f215cfe4e19e7 Mon Sep 17 00:00:00 2001 From: Bjarten Date: Tue, 15 Oct 2024 11:26:30 +0900 Subject: [PATCH 6/7] refactor: add double newline after imports --- pytorchtools.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorchtools.py b/pytorchtools.py index 39e6d2f..40a59ac 100644 --- a/pytorchtools.py +++ b/pytorchtools.py @@ -4,6 +4,7 @@ import numpy as np import torch + class EarlyStopping: """Early stops the training if validation loss doesn't improve after a given patience.""" def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print): From 348122e032f121c720356e0958df6e4f816c1d30 Mon Sep 17 00:00:00 2001 From: Bjarten Date: Tue, 15 Oct 2024 12:18:11 +0900 Subject: [PATCH 7/7] chore: update notebook --- MNIST_Early_Stopping_example.ipynb | 197 +++++++++++++++-------------- 1 file changed, 104 insertions(+), 93 deletions(-) diff --git a/MNIST_Early_Stopping_example.ipynb b/MNIST_Early_Stopping_example.ipynb index 7785f52..cf6b430 100644 --- a/MNIST_Early_Stopping_example.ipynb +++ b/MNIST_Early_Stopping_example.ipynb @@ -108,7 +108,7 @@ " (fc1): Linear(in_features=784, out_features=128, bias=True)\n", " (fc2): Linear(in_features=128, out_features=128, bias=True)\n", " (fc3): Linear(in_features=128, out_features=10, bias=True)\n", - " (dropout): Dropout(p=0.5)\n", + " (dropout): Dropout(p=0.5, inplace=False)\n", ")\n" ] } @@ -266,7
+266,7 @@ " break\n", " \n", " # load the last checkpoint with the best model\n", - " model.load_state_dict(torch.load('checkpoint.pt'))\n", + " model.load_state_dict(torch.load('checkpoint.pt', weights_only=True))\n", "\n", " return model, avg_train_losses, avg_valid_losses" ] @@ -280,100 +280,113 @@ "name": "stdout", "output_type": "stream", "text": [ - "[ 1/100] train_loss: 0.84499 valid_loss: 0.29977\n", - "[ 2/100] train_loss: 0.36182 valid_loss: 0.22742\n", - "Validation loss decreased (inf --> 0.227419). Saving model ...\n", - "[ 3/100] train_loss: 0.29205 valid_loss: 0.19163\n", - "Validation loss decreased (0.227419 --> 0.191628). Saving model ...\n", - "[ 4/100] train_loss: 0.25390 valid_loss: 0.16771\n", - "Validation loss decreased (0.191628 --> 0.167714). Saving model ...\n", - "[ 5/100] train_loss: 0.22671 valid_loss: 0.15422\n", - "Validation loss decreased (0.167714 --> 0.154222). Saving model ...\n", - "[ 6/100] train_loss: 0.20862 valid_loss: 0.14546\n", - "Validation loss decreased (0.154222 --> 0.145459). Saving model ...\n", - "[ 7/100] train_loss: 0.19482 valid_loss: 0.13821\n", - "Validation loss decreased (0.145459 --> 0.138206). Saving model ...\n", - "[ 8/100] train_loss: 0.18431 valid_loss: 0.13398\n", - "Validation loss decreased (0.138206 --> 0.133979). Saving model ...\n", - "[ 9/100] train_loss: 0.17554 valid_loss: 0.12953\n", - "Validation loss decreased (0.133979 --> 0.129535). Saving model ...\n", - "[ 10/100] train_loss: 0.16785 valid_loss: 0.12202\n", - "Validation loss decreased (0.129535 --> 0.122023). Saving model ...\n", - "[ 11/100] train_loss: 0.16202 valid_loss: 0.12249\n", + "[ 1/100] train_loss: 0.83783 valid_loss: 0.31137\n", + "Validation loss decreased (inf --> 0.311369). Saving model ...\n", + "[ 2/100] train_loss: 0.37005 valid_loss: 0.23503\n", + "Validation loss decreased (0.311369 --> 0.235031). 
Saving model ...\n", + "[ 3/100] train_loss: 0.29645 valid_loss: 0.19488\n", + "Validation loss decreased (0.235031 --> 0.194881). Saving model ...\n", + "[ 4/100] train_loss: 0.25427 valid_loss: 0.17373\n", + "Validation loss decreased (0.194881 --> 0.173730). Saving model ...\n", + "[ 5/100] train_loss: 0.23387 valid_loss: 0.15799\n", + "Validation loss decreased (0.173730 --> 0.157987). Saving model ...\n", + "[ 6/100] train_loss: 0.21357 valid_loss: 0.14634\n", + "Validation loss decreased (0.157987 --> 0.146339). Saving model ...\n", + "[ 7/100] train_loss: 0.20209 valid_loss: 0.13834\n", + "Validation loss decreased (0.146339 --> 0.138335). Saving model ...\n", + "[ 8/100] train_loss: 0.18811 valid_loss: 0.13291\n", + "Validation loss decreased (0.138335 --> 0.132905). Saving model ...\n", + "[ 9/100] train_loss: 0.17865 valid_loss: 0.12881\n", + "Validation loss decreased (0.132905 --> 0.128807). Saving model ...\n", + "[ 10/100] train_loss: 0.17079 valid_loss: 0.12113\n", + "Validation loss decreased (0.128807 --> 0.121127). Saving model ...\n", + "[ 11/100] train_loss: 0.16392 valid_loss: 0.11816\n", + "Validation loss decreased (0.121127 --> 0.118159). Saving model ...\n", + "[ 12/100] train_loss: 0.16054 valid_loss: 0.11620\n", + "Validation loss decreased (0.118159 --> 0.116205). Saving model ...\n", + "[ 13/100] train_loss: 0.15220 valid_loss: 0.11247\n", + "Validation loss decreased (0.116205 --> 0.112469). Saving model ...\n", + "[ 14/100] train_loss: 0.14709 valid_loss: 0.11378\n", "EarlyStopping counter: 1 out of 20\n", - "[ 12/100] train_loss: 0.15300 valid_loss: 0.11852\n", - "Validation loss decreased (0.122023 --> 0.118516). Saving model ...\n", - "[ 13/100] train_loss: 0.14965 valid_loss: 0.11560\n", - "Validation loss decreased (0.118516 --> 0.115598). Saving model ...\n", - "[ 14/100] train_loss: 0.14680 valid_loss: 0.11387\n", - "Validation loss decreased (0.115598 --> 0.113867). 
Saving model ...\n", - "[ 15/100] train_loss: 0.13988 valid_loss: 0.11728\n", + "[ 15/100] train_loss: 0.14460 valid_loss: 0.11054\n", + "Validation loss decreased (0.112469 --> 0.110544). Saving model ...\n", + "[ 16/100] train_loss: 0.13830 valid_loss: 0.10726\n", + "Validation loss decreased (0.110544 --> 0.107258). Saving model ...\n", + "[ 17/100] train_loss: 0.13689 valid_loss: 0.10741\n", "EarlyStopping counter: 1 out of 20\n", - "[ 16/100] train_loss: 0.13641 valid_loss: 0.11269\n", - "Validation loss decreased (0.113867 --> 0.112686). Saving model ...\n", - "[ 17/100] train_loss: 0.12957 valid_loss: 0.11237\n", - "Validation loss decreased (0.112686 --> 0.112374). Saving model ...\n", - "[ 18/100] train_loss: 0.12862 valid_loss: 0.11198\n", - "Validation loss decreased (0.112374 --> 0.111975). Saving model ...\n", - "[ 19/100] train_loss: 0.12581 valid_loss: 0.10924\n", - "Validation loss decreased (0.111975 --> 0.109242). Saving model ...\n", - "[ 20/100] train_loss: 0.12171 valid_loss: 0.10836\n", - "Validation loss decreased (0.109242 --> 0.108363). Saving model ...\n", - "[ 21/100] train_loss: 0.12191 valid_loss: 0.10922\n", + "[ 18/100] train_loss: 0.13122 valid_loss: 0.10676\n", + "Validation loss decreased (0.107258 --> 0.106764). Saving model ...\n", + "[ 19/100] train_loss: 0.12802 valid_loss: 0.10354\n", + "Validation loss decreased (0.106764 --> 0.103535). Saving model ...\n", + "[ 20/100] train_loss: 0.12939 valid_loss: 0.10285\n", + "Validation loss decreased (0.103535 --> 0.102854). Saving model ...\n", + "[ 21/100] train_loss: 0.12022 valid_loss: 0.10037\n", + "Validation loss decreased (0.102854 --> 0.100375). Saving model ...\n", + "[ 22/100] train_loss: 0.12399 valid_loss: 0.09700\n", + "Validation loss decreased (0.100375 --> 0.097001). 
Saving model ...\n", + "[ 23/100] train_loss: 0.11651 valid_loss: 0.10024\n", "EarlyStopping counter: 1 out of 20\n", - "[ 22/100] train_loss: 0.11935 valid_loss: 0.10976\n", + "[ 24/100] train_loss: 0.11879 valid_loss: 0.09823\n", "EarlyStopping counter: 2 out of 20\n", - "[ 23/100] train_loss: 0.11901 valid_loss: 0.11053\n", + "[ 25/100] train_loss: 0.11619 valid_loss: 0.10081\n", "EarlyStopping counter: 3 out of 20\n", - "[ 24/100] train_loss: 0.11420 valid_loss: 0.10901\n", + "[ 26/100] train_loss: 0.11205 valid_loss: 0.10083\n", "EarlyStopping counter: 4 out of 20\n", - "[ 25/100] train_loss: 0.11089 valid_loss: 0.10837\n", + "[ 27/100] train_loss: 0.11331 valid_loss: 0.10078\n", "EarlyStopping counter: 5 out of 20\n", - "[ 26/100] train_loss: 0.11008 valid_loss: 0.10944\n", + "[ 28/100] train_loss: 0.10997 valid_loss: 0.10114\n", "EarlyStopping counter: 6 out of 20\n", - "[ 27/100] train_loss: 0.10801 valid_loss: 0.10665\n", - "Validation loss decreased (0.108363 --> 0.106647). Saving model ...\n", - "[ 28/100] train_loss: 0.10433 valid_loss: 0.10248\n", - "Validation loss decreased (0.106647 --> 0.102475). Saving model ...\n", - "[ 29/100] train_loss: 0.10323 valid_loss: 0.10621\n", + "[ 29/100] train_loss: 0.10804 valid_loss: 0.09626\n", + "Validation loss decreased (0.097001 --> 0.096262). 
Saving model ...\n", + "[ 30/100] train_loss: 0.10787 valid_loss: 0.10086\n", "EarlyStopping counter: 1 out of 20\n", - "[ 30/100] train_loss: 0.10484 valid_loss: 0.10775\n", + "[ 31/100] train_loss: 0.10434 valid_loss: 0.09898\n", "EarlyStopping counter: 2 out of 20\n", - "[ 31/100] train_loss: 0.09985 valid_loss: 0.10616\n", + "[ 32/100] train_loss: 0.10322 valid_loss: 0.09699\n", "EarlyStopping counter: 3 out of 20\n", - "[ 32/100] train_loss: 0.09898 valid_loss: 0.10479\n", + "[ 33/100] train_loss: 0.09873 valid_loss: 0.09826\n", "EarlyStopping counter: 4 out of 20\n", - "[ 33/100] train_loss: 0.10062 valid_loss: 0.10576\n", + "[ 34/100] train_loss: 0.10127 valid_loss: 0.09545\n", + "Validation loss decreased (0.096262 --> 0.095453). Saving model ...\n", + "[ 35/100] train_loss: 0.09834 valid_loss: 0.09980\n", + "EarlyStopping counter: 1 out of 20\n", + "[ 36/100] train_loss: 0.09554 valid_loss: 0.10243\n", + "EarlyStopping counter: 2 out of 20\n", + "[ 37/100] train_loss: 0.10178 valid_loss: 0.09873\n", + "EarlyStopping counter: 3 out of 20\n", + "[ 38/100] train_loss: 0.09442 valid_loss: 0.10022\n", + "EarlyStopping counter: 4 out of 20\n", + "[ 39/100] train_loss: 0.09483 valid_loss: 0.09984\n", "EarlyStopping counter: 5 out of 20\n", - "[ 34/100] train_loss: 0.09704 valid_loss: 0.10770\n", + "[ 40/100] train_loss: 0.09673 valid_loss: 0.10122\n", "EarlyStopping counter: 6 out of 20\n", - "[ 35/100] train_loss: 0.09850 valid_loss: 0.10542\n", + "[ 41/100] train_loss: 0.09409 valid_loss: 0.10010\n", "EarlyStopping counter: 7 out of 20\n", - "[ 36/100] train_loss: 0.09561 valid_loss: 0.10619\n", + "[ 42/100] train_loss: 0.08885 valid_loss: 0.09946\n", "EarlyStopping counter: 8 out of 20\n", - "[ 37/100] train_loss: 0.09381 valid_loss: 0.10745\n", + "[ 43/100] train_loss: 0.08923 valid_loss: 0.10062\n", "EarlyStopping counter: 9 out of 20\n", - "[ 38/100] train_loss: 0.09363 valid_loss: 0.10487\n", + "[ 44/100] train_loss: 0.08976 valid_loss: 0.10242\n", 
"EarlyStopping counter: 10 out of 20\n", - "[ 39/100] train_loss: 0.09263 valid_loss: 0.10763\n", + "[ 45/100] train_loss: 0.09064 valid_loss: 0.09567\n", "EarlyStopping counter: 11 out of 20\n", - "[ 40/100] train_loss: 0.09234 valid_loss: 0.10778\n", + "[ 46/100] train_loss: 0.08339 valid_loss: 0.10048\n", "EarlyStopping counter: 12 out of 20\n", - "[ 41/100] train_loss: 0.08485 valid_loss: 0.10319\n", + "[ 47/100] train_loss: 0.08805 valid_loss: 0.10472\n", "EarlyStopping counter: 13 out of 20\n", - "[ 42/100] train_loss: 0.09105 valid_loss: 0.10305\n", + "[ 48/100] train_loss: 0.08987 valid_loss: 0.10191\n", "EarlyStopping counter: 14 out of 20\n", - "[ 43/100] train_loss: 0.08963 valid_loss: 0.10952\n", + "[ 49/100] train_loss: 0.08727 valid_loss: 0.10311\n", "EarlyStopping counter: 15 out of 20\n", - "[ 44/100] train_loss: 0.08887 valid_loss: 0.10615\n", + "[ 50/100] train_loss: 0.08554 valid_loss: 0.10606\n", "EarlyStopping counter: 16 out of 20\n", - "[ 45/100] train_loss: 0.08704 valid_loss: 0.10870\n", + "[ 51/100] train_loss: 0.08553 valid_loss: 0.10852\n", "EarlyStopping counter: 17 out of 20\n", - "[ 46/100] train_loss: 0.08477 valid_loss: 0.10877\n", + "[ 52/100] train_loss: 0.08444 valid_loss: 0.10410\n", "EarlyStopping counter: 18 out of 20\n", - "[ 47/100] train_loss: 0.08397 valid_loss: 0.10682\n", + "[ 53/100] train_loss: 0.08313 valid_loss: 0.10338\n", "EarlyStopping counter: 19 out of 20\n", - "[ 48/100] train_loss: 0.08630 valid_loss: 0.10565\n", + "[ 54/100] train_loss: 0.08114 valid_loss: 0.10507\n", "EarlyStopping counter: 20 out of 20\n", "Early stopping\n" ] @@ -406,14 +419,12 @@ "outputs": [ { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "
" + "
" ] }, - "metadata": { - "needs_background": "light" - }, + "metadata": {}, "output_type": "display_data" } ], @@ -454,20 +465,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "Test Loss: 0.085796\n", + "Test Loss: 0.092734\n", "\n", - "Test Accuracy of 0: 98% (968/979)\n", - "Test Accuracy of 1: 98% (1121/1133)\n", - "Test Accuracy of 2: 97% (1008/1030)\n", - "Test Accuracy of 3: 97% (986/1008)\n", - "Test Accuracy of 4: 97% (955/980)\n", - "Test Accuracy of 5: 97% (867/890)\n", + "Test Accuracy of 0: 98% (969/979)\n", + "Test Accuracy of 1: 98% (1117/1133)\n", + "Test Accuracy of 2: 97% (1004/1030)\n", + "Test Accuracy of 3: 97% (987/1008)\n", + "Test Accuracy of 4: 97% (952/980)\n", + "Test Accuracy of 5: 94% (843/890)\n", "Test Accuracy of 6: 97% (932/956)\n", - "Test Accuracy of 7: 97% (997/1027)\n", - "Test Accuracy of 8: 96% (939/973)\n", - "Test Accuracy of 9: 95% (967/1008)\n", + "Test Accuracy of 7: 97% (1002/1027)\n", + "Test Accuracy of 8: 97% (945/973)\n", + "Test Accuracy of 9: 95% (965/1008)\n", "\n", - "Test Accuracy (Overall): 97% (9740/9984)\n" + "Test Accuracy (Overall): 97% (9716/9984)\n" ] } ], @@ -529,9 +540,9 @@ "outputs": [ { "data": { - "image/png": "\n", + "image/png": "", "text/plain": [ - "
" + "
" ] }, "metadata": {}, @@ -541,7 +552,7 @@ "source": [ "# obtain one batch of test images\n", "dataiter = iter(test_loader)\n", - "images, labels = dataiter.next()\n", + "images, labels = next(dataiter) # Use the built-in next() function\n", "\n", "# get sample outputs\n", "output = model(images)\n", @@ -553,10 +564,10 @@ "# plot the images in the batch, along with predicted and true labels\n", "fig = plt.figure(figsize=(25, 4))\n", "for idx in np.arange(20):\n", - " ax = fig.add_subplot(2, 20/2, idx+1, xticks=[], yticks=[])\n", + " ax = fig.add_subplot(2, 20 // 2, idx+1, xticks=[], yticks=[])\n", " ax.imshow(np.squeeze(images[idx]), cmap='gray')\n", " ax.set_title(\"{} ({})\".format(str(preds[idx].item()), str(labels[idx].item())),\n", - " color=(\"green\" if preds[idx]==labels[idx] else \"red\"))" + " color=(\"green\" if preds[idx]==labels[idx] else \"red\"))\n" ] }, { @@ -576,7 +587,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -590,9 +601,9 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.6" + "version": "3.12.6" } }, "nbformat": 4, - "nbformat_minor": 2 + "nbformat_minor": 4 }