chore!: rename package, restructure files, and add pip integration #29

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged · 13 commits · Oct 16, 2024
43 changes: 43 additions & 0 deletions .github/workflows/publish.yml
@@ -0,0 +1,43 @@
```yaml
name: Publish Python Package

on:
  push:
    tags:
      - 'v*.*.*'  # Automatically publish when a new version tag is pushed

jobs:
  publish:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v3

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'

      - name: Cache pip
        uses: actions/cache@v3
        with:
          path: ~/.cache/pip
          key: ${{ runner.os }}-pip-${{ hashFiles('**/requirements.txt') }}
          restore-keys: |
            ${{ runner.os }}-pip-

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install build twine

      - name: Build the package
        run: |
          python -m build

      - name: Publish to PyPI
        env:
          TWINE_USERNAME: __token__
          TWINE_PASSWORD: ${{ secrets.PYPI_TOKEN }}
        run: |
          python -m twine upload --repository pypi dist/*
```
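The `v*.*.*` tag filter can be sanity-checked locally before pushing a tag. GitHub's filter pattern syntax is similar to, but not identical to, shell glob patterns, so the sketch below is only an approximation of the matching behavior:

```shell
#!/bin/sh
# Rough local approximation of the workflow's tag filter.
# Note: GitHub's filter pattern syntax differs slightly from shell
# case globs; this is an illustration, not the exact matcher.
matches_release_tag() {
  case "$1" in
    v*.*.*) echo "publish" ;;
    *)      echo "skip" ;;
  esac
}

matches_release_tag "v1.2.3"   # a version tag would trigger the workflow
matches_release_tag "main"     # other refs would not
```

Pushing a matching tag (e.g. `git tag v0.1.0 && git push origin v0.1.0`) is what triggers the publish job.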
46 changes: 25 additions & 21 deletions .github/workflows/release.yml
@@ -23,28 +23,32 @@ jobs:
```yaml
      pull-requests: write

    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          fetch-depth: 0

      - name: Set up Python
        uses: actions/setup-python@v4
        with:
          python-version: '3.12'

      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install python-semantic-release

      - name: Configure Git
        run: |
          git config user.name "github-actions[bot]"
          git config user.email "github-actions[bot]@users.noreply.github.com"

      - name: Run semantic-release
        env:
          GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
        run: |
          if [ "${{ github.event.inputs.dry-run }}" = "true" ]; then
            semantic-release publish --dry-run
          else
            semantic-release publish
          fi
```
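The explicit `if`/`then` form in the run step sidesteps a subtlety: `workflow_dispatch` inputs arrive in the expression context as strings, so comparing one against a bare boolean (as the replaced `== true` expression did) never matches. A minimal shell sketch of the corrected check:

```shell
#!/bin/sh
# workflow_dispatch inputs are strings, so compare against the
# string "true" rather than a boolean literal.
dry_run="true"

if [ "$dry_run" = "true" ]; then
  args="--dry-run"
else
  args=""
fi

echo "semantic-release publish $args"
```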
184 changes: 91 additions & 93 deletions MNIST_Early_Stopping_example.ipynb

Large diffs are not rendered by default.

52 changes: 48 additions & 4 deletions README.md
# Early Stopping for PyTorch

Early stopping is a form of regularization used to avoid overfitting on the training dataset. It keeps track of the validation loss; if the loss stops decreasing for several epochs in a row, training stops. The `EarlyStopping` class in `early_stopping_pytorch/early_stopping.py` creates an object that tracks the validation loss while training a [PyTorch](https://pytorch.org/) model and saves a checkpoint each time the validation loss decreases. The `patience` argument sets how many epochs to wait after the last improvement in validation loss before breaking the training loop. A simple example of using the `EarlyStopping` class is in the [MNIST_Early_Stopping_example](MNIST_Early_Stopping_example.ipynb) notebook.

Below is a plot from the example notebook showing the last checkpoint saved by the `EarlyStopping` object, right before the model started to overfit, with `patience` set to 20.

![Loss plot](loss_plot.png?raw=true)
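The bookkeeping described above can be sketched in a few lines. This is a simplified illustration of the idea only, not the package's actual implementation:

```python
# Simplified sketch of early-stopping bookkeeping: track the best
# validation loss, reset the counter on improvement, stop once the
# counter reaches `patience`. Illustration only, not the package code.
def run_with_early_stopping(val_losses, patience):
    """Return the 1-based epoch at which training stops, or None."""
    best_loss = float("inf")
    counter = 0
    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss < best_loss:
            best_loss = val_loss   # improvement: "save a checkpoint", reset
            counter = 0
        else:
            counter += 1           # another epoch without improvement
            if counter >= patience:
                return epoch       # patience exhausted: stop training
    return None                    # never triggered

# The loss improves through epoch 2, then stalls; with patience=3,
# training stops three epochs after the last improvement.
print(run_with_early_stopping([0.9, 0.7, 0.8, 0.8, 0.8], patience=3))  # prints 5
```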

## Usage

### 1. Clone the Repository
```bash
git clone https://github.com/your_username/early-stopping-pytorch.git
cd early-stopping-pytorch
```

### 2. Set Up the Virtual Environment
Run the setup script to create a virtual environment and install all necessary dependencies.
```bash
./setup_dev_env.sh
```

### 3. Activate the Virtual Environment
Activate the virtual environment:
```bash
source dev-venv/bin/activate
```

### 4. Install the Package in Editable Mode
Install the package locally in editable mode so you can use it immediately:
```bash
pip install -e .
```

### 5. Use the Package
You can now import and use the package in your Python code:
```python
from early_stopping_pytorch import EarlyStopping
```

---

### Summary of Commands

1. Clone the repository:
`git clone https://github.com/your_username/early-stopping-pytorch.git`

2. Set up the environment:
`./setup_dev_env.sh`

3. Activate the environment:
`source dev-venv/bin/activate`

4. Install the package in editable mode:
`pip install -e .`

5. Optional: Build the package for distribution:
`./build.sh`

## References
The `EarlyStopping` class in `early_stopping_pytorch/early_stopping.py` is inspired by the [ignite EarlyStopping class](https://github.com/pytorch/ignite/blob/master/ignite/handlers/early_stopping.py).
3 changes: 3 additions & 0 deletions early_stopping_pytorch/__init__.py
@@ -0,0 +1,3 @@
```python
from .early_stopping import EarlyStopping

__version__ = "0.1.0"
```
File renamed without changes.
23 changes: 21 additions & 2 deletions pyproject.toml
@@ -1,7 +1,26 @@
```toml
[build-system]
requires = ["setuptools>=61.0", "wheel"]
build-backend = "setuptools.build_meta"

[project]
name = "early-stopping-pytorch"
version = "0.1.0"
description = "A PyTorch utility package for Early Stopping"
readme = "README.md"
authors = [
    { name = "Bjarte Sunde", email = "BjarteSunde@outlook.com" }
]
license = { text = "MIT" }  # Update if you use a different license
dependencies = [
    "numpy>=1.21",
    "torch>=1.9.0"
]

[tool.semantic_release]
version_variable = [
    "early_stopping_pytorch/__init__.py:__version__",
    "pyproject.toml:project.version"
]
branch = "main"
upload_to_pypi = false
build_command = "pip install build && python -m build"
```
4 changes: 3 additions & 1 deletion requirements-dev.txt
@@ -1,4 +1,6 @@
```
# requirements-dev.txt
-r requirements.txt
pytest
pytest-mock
build
notebook
```
31 changes: 31 additions & 0 deletions setup_dev_env.sh
@@ -0,0 +1,31 @@
```bash
#!/bin/bash

# Exit immediately if a command exits with a non-zero status.
set -e

# Create a virtual environment if it doesn't already exist
if [ ! -d "venv" ]; then
    echo "Creating a virtual environment in venv..."
    python3 -m venv venv
    echo "Virtual environment created."
else
    echo "Virtual environment already exists."
fi

# Activate the virtual environment
source venv/bin/activate

# Upgrade pip and install dependencies
echo "Upgrading pip..."
pip install --upgrade pip

echo "Installing development dependencies..."
pip install -r requirements-dev.txt

echo "Installing runtime dependencies..."
pip install -r requirements.txt

echo "Development environment is set up and ready to go!"

# To activate the virtual environment, run the following command
echo "source venv/bin/activate"
```
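The venv guard at the top of the script is a common idempotency pattern: create the resource only if it is missing, so the script is safe to re-run. A minimal sketch of the same check against an ordinary directory (the `demo_dir` name is a placeholder for illustration):

```shell
#!/bin/sh
# "Create only if missing" pattern, as used for venv in setup_dev_env.sh.
# demo_dir is a hypothetical placeholder name.
ensure_dir() {
  if [ ! -d "$1" ]; then
    mkdir "$1"
    echo "created"
  else
    echo "already exists"
  fi
}

ensure_dir demo_dir   # first call creates the directory
ensure_dir demo_dir   # second call is a no-op
```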
23 changes: 10 additions & 13 deletions tests/test_early_stopping.py
@@ -1,11 +1,10 @@
```python
# tests/test_early_stopping.py

import pytest
from unittest.mock import Mock, patch
import torch
import numpy as np

from early_stopping_pytorch import EarlyStopping

# Fixtures to mock model and temporary checkpoint path
```

@@ -68,7 +67,7 @@ def test_initial_call_saves_checkpoint(mock_model, temp_checkpoint_path):
```python
        temp_checkpoint_path (str): Temporary file path for saving the checkpoint.
    """
    # Patch the torch.save method used inside EarlyStopping to prevent actual file I/O
    with patch('early_stopping_pytorch.early_stopping.torch.save') as mock_save:
        # Initialize EarlyStopping with specified parameters
        early_stopping = EarlyStopping(patience=5, verbose=False, path=temp_checkpoint_path)
```

@@ -81,7 +80,7 @@ def test_initial_call_saves_checkpoint(mock_model, temp_checkpoint_path):
```python
        # Assert that torch.save was called once with correct arguments
        mock_save.assert_called_once_with(mock_model.state_dict(), temp_checkpoint_path)
        # Assert that best_val_loss was set correctly
        assert early_stopping.best_val_loss == initial_val_loss, "Best score should be set to initial_val_loss"
        # Assert that early_stop is not triggered
        assert not early_stopping.early_stop, "Early stop should not be triggered on initial call"
```

@@ -98,7 +97,7 @@ def test_validation_loss_improves(mock_model, temp_checkpoint_path):
```python
        temp_checkpoint_path (str): Temporary file path for saving the checkpoint.
    """
    # Patch the torch.save method used inside EarlyStopping
    with patch('early_stopping_pytorch.early_stopping.torch.save') as mock_save:
        # Initialize EarlyStopping with specified parameters
        early_stopping = EarlyStopping(patience=5, verbose=False, path=temp_checkpoint_path)
```

@@ -143,11 +142,11 @@ def test_validation_loss_no_improvement_within_delta(mock_model, temp_checkpoint_path):
```python
        # After processing initial_losses:
        # - Checkpoints should be saved on initial call and on each improvement
        # - Total save_checkpoint calls: 3
        # - Patience counter should have incremented to 1 (from Epoch 6)
        # - Early stopping should not have been triggered yet

        # Assert that save_checkpoint was called three times: initial call and two improvements
        assert mock_save_checkpoint.call_count == 3, "Checkpoints should be saved on the initial call and two improvements"

        # Assert that the patience counter has incremented to 1 (only Epoch 6 incremented it)
```

@@ -175,7 +174,6 @@ def test_validation_loss_no_improvement_within_delta(mock_model, temp_checkpoint_path):
```python
        # Assert that early stopping was triggered after patience was exceeded
        assert early_stopping.early_stop is True, "Early stop should be triggered after patience is exceeded"

def test_early_stopping_triggered(mock_model, temp_checkpoint_path):
    """
    Test that early stopping is triggered when the patience is exceeded without sufficient improvement.
```

@@ -230,7 +228,7 @@ def test_verbose_output(mock_model, temp_checkpoint_path, capsys):
```python
        capsys: Pytest fixture to capture output to stdout and stderr.
    """
    # Patch the torch.save method used inside EarlyStopping to prevent actual file I/O
    with patch('early_stopping_pytorch.early_stopping.torch.save'):
        # Initialize EarlyStopping with verbose enabled
        early_stopping = EarlyStopping(patience=2, verbose=True, path=temp_checkpoint_path)
```

@@ -247,7 +245,6 @@ def test_verbose_output(mock_model, temp_checkpoint_path, capsys):
```python
        assert 'EarlyStopping counter: 1 out of 2' in captured.out, "Should print first counter increment"
        assert 'EarlyStopping counter: 2 out of 2' in captured.out, "Should print second counter increment"

def test_no_early_stop_when_validation_improves_within_patience(mock_model, temp_checkpoint_path):
    """
    Test that early stopping is not triggered when validation loss continues to improve within patience.
```

@@ -260,7 +257,7 @@ def test_no_early_stop_when_validation_improves_within_patience(mock_model, temp_checkpoint_path):
```python
        temp_checkpoint_path (str): Temporary file path for saving the checkpoint.
    """
    # Patch the torch.save method used inside EarlyStopping
    with patch('early_stopping_pytorch.early_stopping.torch.save') as mock_save:
        # Initialize EarlyStopping with specified parameters
        early_stopping = EarlyStopping(patience=3, verbose=False, path=temp_checkpoint_path)
```

@@ -357,4 +354,4 @@ def test_validation_loss_nan(mock_model, temp_checkpoint_path):
```python
        assert not early_stopping.early_stop, "Early stop should not be triggered when validations improve"

        # Assert that the patience counter was not incremented for NaN loss
        assert early_stopping.counter == 0, "Counter should remain 0 since NaN loss was ignored"
```
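A note on the patch targets changed throughout this file: `unittest.mock.patch` replaces a name in the namespace where it is looked up, which is why the tests patch `early_stopping_pytorch.early_stopping.torch.save` (the module that calls `torch.save`) rather than the old `pytorchtools.torch.save`. A small standard-library illustration of the principle:

```python
# mock.patch replaces a name for the duration of the context and
# restores it afterwards. Illustrated with json.dumps from the stdlib.
import json
from unittest.mock import patch

with patch("json.dumps", return_value="stubbed"):
    assert json.dumps({"a": 1}) == "stubbed"   # patched inside the block

assert json.dumps({"a": 1}) == '{"a": 1}'      # original restored afterwards
```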