Skip to content

feat(ci): add GitHub Action for Python tests, fix EarlyStopping logic, and add unit tests #27

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/release.yml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install python-semantic-release==9.11.0 # Locking to a specific version
pip install python-semantic-release

- name: Configure Git
run: |
Expand Down
41 changes: 41 additions & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
name: Python Application Tests

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
test:
runs-on: ubuntu-latest

strategy:
matrix:
python-version: ['3.9', '3.10', '3.11', '3.12']

steps:
# Step 1: Checkout the repository
- name: Checkout repository
uses: actions/checkout@v4

# Step 2: Set up Python
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
cache: 'pip'

# Step 3: Install dependencies
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt

# Step 4: Install pytest (if not in requirements.txt)
- name: Install pytest
run: pip install pytest

# Step 5: Run tests
- name: Run tests
run: pytest
197 changes: 104 additions & 93 deletions MNIST_Early_Stopping_example.ipynb

Large diffs are not rendered by default.

29 changes: 16 additions & 13 deletions pytorchtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@

import numpy as np
import torch


class EarlyStopping:
"""Early stops the training if validation loss doesn't improve after a given patience."""
def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print):
"""
Args:
patience (int): How long to wait after last time validation loss improved.
Default: 7
verbose (bool): If True, prints a message for each validation loss improvement.
verbose (bool): If True, prints a message for each validation loss improvement.
Default: False
delta (float): Minimum change in the monitored quantity to qualify as an improvement.
Default: 0
Expand All @@ -22,31 +24,32 @@ def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', tra
self.patience = patience
self.verbose = verbose
self.counter = 0
self.best_score = None
self.best_val_loss = None
self.early_stop = False
self.val_loss_min = np.Inf
self.val_loss_min = np.inf
self.delta = delta
self.path = path
self.trace_func = trace_func
def __call__(self, val_loss, model):

score = -val_loss
def __call__(self, val_loss, model):

if self.best_score is None:
self.best_score = score
if self.best_val_loss is None:
self.best_val_loss = val_loss
self.save_checkpoint(val_loss, model)
elif score < self.best_score + self.delta:
elif val_loss < self.best_val_loss - self.delta:
# Significant improvement detected
self.best_val_loss = val_loss
self.save_checkpoint(val_loss, model)
self.counter = 0 # Reset counter since improvement occurred
else:
# No significant improvement
self.counter += 1
self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
if self.counter >= self.patience:
self.early_stop = True
else:
self.best_score = score
self.save_checkpoint(val_loss, model)
self.counter = 0

def save_checkpoint(self, val_loss, model):
'''Saves model when validation loss decrease.'''
'''Saves model when validation loss decreases.'''
if self.verbose:
self.trace_func(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
torch.save(model.state_dict(), self.path)
Expand Down
4 changes: 4 additions & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# requirements-dev.txt
-r requirements.txt
pytest
pytest-mock
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# requirements.txt
matplotlib
numpy
torchvision
Empty file added tests/__init__.py
Empty file.
Loading