This repository contains a customizable PyTorch training loop template that simplifies training, validation, and testing of models. It includes support for:
- ✅ Early stopping
- 📉 Learning rate scheduling
- 📊 Metric logging
- 🛑 Graceful `KeyboardInterrupt` handling during training: the results collected up to the current epoch are returned
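Interrupt handling like this is usually a `try`/`except KeyboardInterrupt` wrapped around the epoch loop. The following is a minimal sketch of that pattern, not the code in `trainer.py`. The names `run_epochs` and `train_one_epoch` are hypothetical, and the callback stands in for one epoch of real training.

```python
from typing import Callable, Dict, List


def run_epochs(
    train_one_epoch: Callable[[int], float],
    max_epochs: int,
) -> Dict[str, List[float]]:
    """Run up to max_epochs and return the history collected so far.

    If the user presses Ctrl+C (KeyboardInterrupt), the loop stops and the
    partial history is returned instead of being lost.
    """
    history: Dict[str, List[float]] = {"train_loss": []}
    try:
        for epoch in range(max_epochs):
            loss = train_one_epoch(epoch)  # hypothetical per-epoch callback
            history["train_loss"].append(loss)
    except KeyboardInterrupt:
        # Keep everything recorded for the epochs that completed.
        print(f"Interrupted after {len(history['train_loss'])} completed epoch(s).")
    return history


def fake_epoch(epoch: int) -> float:
    if epoch == 3:
        raise KeyboardInterrupt  # simulate Ctrl+C during the 4th epoch
    return 1.0 / (epoch + 1)


history = run_epochs(fake_epoch, max_epochs=10)
print(history)  # {'train_loss': [1.0, 0.5, 0.3333333333333333]}
```

Catching the interrupt outside the loop, rather than inside it, means a single Ctrl+C ends training cleanly while the function still returns normally.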
Instead of rewriting boilerplate code for every project, use this reusable trainer as a solid starting point and adapt it to your specific needs!
A working example using the Iris dataset, including a custom `Dataset` class and a simple feedforward neural network (FNN) model, is available in `main.py`.
- Copy the `trainer.py` file into your project.
- Add the packages listed in `requirements.txt` to your project environment.
- By default, your `DataLoader` should return batches as a dictionary with the following keys: `{'input': [...], 'target': [...]}`
- Modify the `_output_parse(self, output)` function to match the output requirements of your model.
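To illustrate the expected batch format, here is a sketch of a hypothetical `DictDataset` whose samples are `{'input': ..., 'target': ...}` dictionaries. PyTorch's default collate function stacks such samples key by key, so each batch comes out as a dictionary of tensors. The `output_parse` function is a hypothetical stand-in for `_output_parse`, assuming a classifier whose raw logits must become class indices before being passed to metrics such as `sklearn.metrics.accuracy_score`. What `_output_parse` should actually return depends on how `trainer.py` uses it.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class DictDataset(Dataset):
    """Hypothetical Dataset that yields {'input': ..., 'target': ...} samples."""

    def __init__(self, features: torch.Tensor, targets: torch.Tensor):
        self.features = features
        self.targets = targets

    def __len__(self) -> int:
        return len(self.features)

    def __getitem__(self, idx: int) -> dict:
        return {"input": self.features[idx], "target": self.targets[idx]}


def output_parse(output: torch.Tensor) -> torch.Tensor:
    """Hypothetical _output_parse for a classifier: turn logits of shape
    (batch, n_classes) into predicted class indices of shape (batch,)."""
    return output.argmax(dim=1)


# 10 samples with 4 features each (the Iris feature count) and 3 classes.
features = torch.randn(10, 4)
targets = torch.randint(0, 3, (10,))
loader = DataLoader(DictDataset(features, targets), batch_size=4)

batch = next(iter(loader))
# The default collate function stacks dict samples key by key.
print(batch["input"].shape)   # torch.Size([4, 4])
print(batch["target"].shape)  # torch.Size([4])
```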
```python
import sklearn.metrics

from trainer import Trainer

# model, train_loader, val_loader, optimizer and criterion are defined beforehand
trainer = Trainer(model=model, device='cpu')
results = trainer.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    max_epochs=100,
    early_stopping=True,
    early_stopping_monitor='accuracy',
    early_stopping_mode='max',
    metrics={
        'accuracy': sklearn.metrics.accuracy_score
    }
)
```
The `trainer.fit` function returns a dictionary containing training, validation, and test losses and metrics.
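The `metrics` and `early_stopping_*` arguments suggest a common pattern: every metric in the dictionary is called as `fn(y_true, y_pred)` (the signature of scikit-learn metrics), and the value named by `early_stopping_monitor` is tracked in the direction given by `early_stopping_mode`. The sketch below shows that pattern under those assumptions. `compute_metrics` and `EarlyStopping` are hypothetical names, and `patience` and `min_delta` are assumed parameters that do not appear in the example call. A plain `accuracy` function is used so the sketch runs without scikit-learn; `sklearn.metrics.accuracy_score` can take its place.

```python
from typing import Callable, Dict, Optional, Sequence


def compute_metrics(
    metrics: Dict[str, Callable[[Sequence, Sequence], float]],
    y_true: Sequence,
    y_pred: Sequence,
) -> Dict[str, float]:
    """Apply every metric as fn(y_true, y_pred), like sklearn metrics."""
    return {name: fn(y_true, y_pred) for name, fn in metrics.items()}


class EarlyStopping:
    """Stop when the monitored value has not improved for `patience` epochs."""

    def __init__(self, mode: str = "max", patience: int = 5, min_delta: float = 0.0):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.mode = mode
        self.patience = patience
        self.min_delta = min_delta
        self.best: Optional[float] = None
        self.bad_epochs = 0

    def _improved(self, value: float) -> bool:
        if self.best is None:
            return True
        if self.mode == "max":
            return value > self.best + self.min_delta
        return value < self.best - self.min_delta

    def step(self, value: float) -> bool:
        """Record this epoch's value; return True if training should stop."""
        if self._improved(value):
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


def accuracy(y_true: Sequence, y_pred: Sequence) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


print(compute_metrics({"accuracy": accuracy}, [0, 1, 2, 1], [0, 1, 1, 1]))
# {'accuracy': 0.75}

stopper = EarlyStopping(mode="max", patience=2)
val_accuracies = [0.60, 0.75, 0.74, 0.75, 0.70]
stop_epoch = None
for epoch, acc in enumerate(val_accuracies):
    if stopper.step(acc):
        stop_epoch = epoch
        break
print(stop_epoch, stopper.best)  # 3 0.75
```

With `mode="max"` a tie does not count as an improvement, so the repeated 0.75 at epoch 3 is the second non-improving epoch and triggers the stop.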
```python
test_results = trainer.test(test_loader, criterion)
```

The `trainer.test` function returns a dictionary containing test losses and metrics.
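A test pass of this kind typically puts the model in evaluation mode, disables gradient tracking, and averages the loss over the dictionary batches. The following is a minimal sketch of such a loop, assuming the `{'input': ..., 'target': ...}` batch format and a criterion that returns the batch mean. `evaluate` is a hypothetical name, not the method in `trainer.py`, and metric computation is left out for brevity.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module,
             device: str = "cpu") -> float:
    """Hypothetical test loop: average loss over dict batches.

    The model is assumed to already live on `device`.
    """
    model.eval()  # disable dropout, use running BatchNorm statistics
    total_loss, n_samples = 0.0, 0
    with torch.no_grad():  # no gradient tracking during evaluation
        for batch in loader:
            inputs = batch["input"].to(device)
            targets = batch["target"].to(device)
            loss = criterion(model(inputs), targets)
            # criterion returns the batch mean, so weight it by batch size
            total_loss += loss.item() * inputs.size(0)
            n_samples += inputs.size(0)
    return total_loss / n_samples


torch.manual_seed(0)
samples = [{"input": torch.randn(4), "target": torch.tensor(i % 3)} for i in range(10)]
model = nn.Linear(4, 3)
test_loss = evaluate(model, DataLoader(samples, batch_size=4), nn.CrossEntropyLoss())
print(f"test loss: {test_loss:.4f}")
```

Weighting each batch mean by its size makes the result equal the mean over the whole test set, even when the last batch is smaller.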
```python
y_pred = trainer.predict(test_loader)
```

The `trainer.predict` function accepts either a `DataLoader` or a single input, and returns the predictions.
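One way to accept both kinds of input is to check whether the argument is a `DataLoader` and, if so, concatenate the per-batch outputs. The sketch below assumes dictionary batches with an `'input'` key and a single input that is a tensor already shaped as the model expects. `predict` here is a hypothetical function, not the implementation in `trainer.py`.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader


def predict(model: nn.Module, data, device: str = "cpu") -> torch.Tensor:
    """Hypothetical predict: accept a DataLoader of dict batches or one input tensor.

    The model is assumed to already live on `device`.
    """
    model.eval()
    with torch.no_grad():
        if isinstance(data, DataLoader):
            # Run every batch and concatenate the outputs along dim 0.
            outputs = [model(batch["input"].to(device)) for batch in data]
            return torch.cat(outputs, dim=0)
        # Single input: pass it through the model directly.
        return model(data.to(device))


torch.manual_seed(0)
model = nn.Linear(4, 3)
x = torch.randn(10, 4)
loader = DataLoader([{"input": row} for row in x], batch_size=4)

batched = predict(model, loader)
single = predict(model, x[0])
print(batched.shape)  # torch.Size([10, 3])
```

Because the loader is not shuffled, row `i` of `batched` corresponds to sample `i`, so `single` matches `batched[0]`.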
This project is licensed under the MIT License. See `LICENSE` for more information.