A PyTorch-based framework for few-shot and zero-shot image classification, designed for research and practical applications. This repository provides tools for training, evaluating, and experimenting with models that can generalize to new classes with limited labeled data.
- Few-Shot and Zero-Shot Learning: Train models to recognize new classes with few or no labeled examples.
- Flexible Data Loading: Custom dataloaders for datasets like Flickr and NUS-WIDE.
- Word Embedding Integration: Uses word2vec embeddings as semantic class representations.
- Modular Model Design: Easily extend or modify the backbone and classifier.
- Custom Loss Functions: Includes semantic ranking loss for multi-label classification.
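The exact form of the semantic ranking loss lives in `utils/loss.py` and is not described here. As a rough illustration of the general idea of a ranking loss for multi-label data, the sketch below implements a generic pairwise hinge ranking loss: every relevant label should outscore every irrelevant label by at least a margin. The function name `pairwise_ranking_loss` and the `margin` default are hypothetical, and this is not necessarily what the repository computes.

```python
def pairwise_ranking_loss(scores, labels, margin=1.0):
    """Generic pairwise hinge ranking loss for one image (hypothetical helper).

    scores: list of floats, one compatibility score per class.
    labels: list of 0/1 ints, 1 where the class is relevant to the image.
    Returns the hinge loss averaged over all (relevant, irrelevant) pairs.
    """
    if len(scores) != len(labels):
        raise ValueError("scores and labels must have the same length")
    pos = [s for s, y in zip(scores, labels) if y == 1]
    neg = [s for s, y in zip(scores, labels) if y == 0]
    if not pos or not neg:
        return 0.0  # nothing to rank against
    total = 0.0
    for sp in pos:
        for sn in neg:
            # Penalize pairs where the relevant label does not lead by `margin`.
            total += max(0.0, margin - sp + sn)
    return total / (len(pos) * len(neg))


# Class 0 is relevant; class 2 is scored too close to it, class 1 is not.
print(pairwise_ranking_loss([2.0, 0.5, 1.8], [1, 0, 0]))  # 0.4
```

In a PyTorch training loop the same computation would be written with batched tensor operations so that it is differentiable; the plain-Python version only shows the arithmetic.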
├── dataloaders/ # Data loading utilities
│ └── dataloader.py
├── models/ # Model architectures
│ └── model.py
├── utils/ # Utility functions and custom losses
│ ├── loss.py
│ └── util.py
├── preprocess.py # Data preprocessing scripts
├── train.py # Training script
├── LICENSE
└── README.md
- Python 3.7+
- PyTorch >= 1.7
- torchvision
- numpy
- OpenCV
- Pillow (PIL)
Install dependencies:
pip install torch torchvision numpy opencv-python pillow
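As an optional sanity check after installing, a small script like the one below can report which of the listed packages cannot be imported. It is not part of the repository; the name `missing_requirements` is hypothetical, and it only checks that each module can be found, not that PyTorch meets the >= 1.7 requirement. Note that two import names differ from their pip names (`opencv-python` imports as `cv2`, `pillow` as `PIL`).

```python
import importlib.util
import sys

# pip package name -> module name used in `import` statements
REQUIRED_MODULES = {
    "torch": "torch",
    "torchvision": "torchvision",
    "numpy": "numpy",
    "opencv-python": "cv2",
    "pillow": "PIL",
}


def missing_requirements(modules=REQUIRED_MODULES):
    """Return the pip names of packages whose top-level module cannot be found."""
    return [pkg for pkg, mod in modules.items()
            if importlib.util.find_spec(mod) is None]


if __name__ == "__main__":
    if sys.version_info < (3, 7):
        sys.exit("Python 3.7+ is required")
    missing = missing_requirements()
    if missing:
        print("Missing packages:", " ".join(missing))
    else:
        print("All required packages are importable.")
```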
- Download the NUS-WIDE or Flickr dataset.
- Place the dataset in your desired directory and update the paths in `preprocess.py` and `dataloaders/dataloader.py`.
- Run preprocessing:
python preprocess.py
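The output format of `preprocess.py` is not documented here. For multi-label datasets such as NUS-WIDE, a common preprocessing step is to turn each image's tag list into a multi-hot label vector and to separate seen (training) classes from unseen (zero-shot) classes. The sketch below shows that idea under those assumptions; the helper names, vocabulary, and class split are all hypothetical and are not taken from the repository.

```python
def multi_hot(tags, vocabulary):
    """Encode an image's tag list as a 0/1 vector over `vocabulary`."""
    index = {name: i for i, name in enumerate(vocabulary)}
    vec = [0] * len(vocabulary)
    for tag in tags:
        if tag in index:  # tags outside the vocabulary are ignored
            vec[index[tag]] = 1
    return vec


def split_labels(vec, vocabulary, unseen):
    """Split a multi-hot vector into its seen-class and unseen-class parts."""
    seen_part = [v for v, name in zip(vec, vocabulary) if name not in unseen]
    unseen_part = [v for v, name in zip(vec, vocabulary) if name in unseen]
    return seen_part, unseen_part


# Hypothetical 4-class vocabulary with two classes held out for zero-shot testing.
vocab = ["sky", "water", "dog", "cat"]
labels = multi_hot(["dog", "sky", "tree"], vocab)  # [1, 0, 1, 0]
print(split_labels(labels, vocab, unseen={"cat", "dog"}))  # ([1, 0], [1, 0])
```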
Edit `train.py` to set your training configuration, then run:
python train.py --epochs 50 --saved_path ./checkpoints
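The command above passes `--epochs` and `--saved_path` on the command line. The sketch below shows how flags like these are typically parsed with the standard `argparse` module. Only the two flag names come from the command; the defaults, help strings, and the `build_parser` function are illustrative assumptions, and `train.py` may define more options or different defaults.

```python
import argparse


def build_parser():
    # Hypothetical parser: the flag names match the README command,
    # but the defaults are placeholders, not the values used by train.py.
    parser = argparse.ArgumentParser(description="Train a few-/zero-shot classifier")
    parser.add_argument("--epochs", type=int, default=50,
                        help="number of training epochs")
    parser.add_argument("--saved_path", type=str, default="./checkpoints",
                        help="directory where checkpoints are written")
    return parser


# Equivalent to: python train.py --epochs 50 --saved_path ./checkpoints
args = build_parser().parse_args(["--epochs", "50", "--saved_path", "./checkpoints"])
print(args.epochs, args.saved_path)  # 50 ./checkpoints
```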
The main model is based on ResNet-101, with a custom classifier that projects image features into a semantic space using word2vec embeddings.
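The key property of this design is that any class with a word vector can be scored, including classes never seen during training. The framework-agnostic sketch below illustrates the idea, assuming a single linear projection from image-feature space to word2vec space and cosine similarity as the compatibility score; the repository's classifier may use a different projection or scoring function. In the real model the feature would be a ResNet-101 pooled feature (2048-d) and the class vectors word2vec embeddings (commonly 300-d); the tiny dimensions, weights, and class names here are made up for readability.

```python
import math


def project(feature, weight):
    """Linear map into the semantic space; `weight` is a list of rows (embed_dim x feat_dim)."""
    return [sum(w * f for w, f in zip(row, feature)) for row in weight]


def cosine(a, b):
    """Cosine similarity, defined as 0.0 if either vector is all zeros."""
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(x * x for x in b))
    if na == 0 or nb == 0:
        return 0.0
    return sum(x * y for x, y in zip(a, b)) / (na * nb)


def rank_classes(feature, weight, class_embeddings):
    """Rank classes (seen or unseen) by similarity to the projected image feature."""
    z = project(feature, weight)
    scores = {name: cosine(z, emb) for name, emb in class_embeddings.items()}
    return sorted(scores, key=scores.get, reverse=True)


# Toy example: 3-d "image feature", 2-d "word vectors".
W = [[1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0]]
classes = {"dog": [1.0, 2.0], "cat": [2.0, -1.0], "car": [0.0, 1.0]}
print(rank_classes([1.0, 0.0, 2.0], W, classes))  # ['dog', 'car', 'cat']
```

Because unseen classes enter only through `class_embeddings`, adding a new class at test time means adding its word vector; the projection weights do not change.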
- To use a different backbone, modify `models/model.py`.
- To add new loss functions, edit `utils/loss.py`.
If you use this codebase in your research, please cite this repository.
This project is licensed under the MIT License - see the LICENSE file for details.
Pull requests are welcome! For major changes, please open an issue first to discuss what you would like to change.
Maintainer: Md Hasan