This project presents a computer vision solution to the challenging problem of small-scale image classification using Vision Transformers (ViT). Unlike traditional convolutional neural networks, the implementation leverages the transformer architecture, originally designed for natural language processing, to achieve strong performance on visual recognition tasks.
- About the Dataset
- Project Structure
- Model Architecture
- Key Features
- Training Pipeline Overview
- Requirements
- Results
- Interactive Inference
- License
- Contributing
This system is built for CIFAR-10 Classification:
- 32x32 RGB images with 10 distinct classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
- Multi-class classification: Classes labeled 0-9 for comprehensive object recognition
- 50,000 training images and 10,000 test images for robust evaluation
- Enhanced data augmentation including random flips, rotations, and color jittering (see the transform sketch below)
- ImageNet normalization for optimal transformer performance with pre-trained weights
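For reference, here is a minimal sketch of such a preprocessing pipeline with `torchvision`. The jitter strengths and erasing probability are illustrative assumptions, not values taken verbatim from `cifar10_vit_classifier.py`:

```python
from torchvision import transforms

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

train_transform = transforms.Compose([
    transforms.Resize(224),                 # upscale 32x32 CIFAR images to the ViT input size
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(0.2, 0.2, 0.2),  # brightness, contrast, saturation (assumed strengths)
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    transforms.RandomErasing(p=0.25),       # operates on tensors, so it comes after ToTensor
])

test_transform = transforms.Compose([
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
```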
CIFAR10-ViT-Classification/
├── .gitignore
├── requirements.txt
├── LICENSE
├── README.md                   # Project documentation (this file)
├── data/
│   └── cifar-10/               # Auto-downloaded CIFAR-10 dataset
├── best_vit_model.pth          # Available in Releases section with tag 'deit-small'
└── cifar10_vit_classifier.py   # Complete implementation
Model: DeiT-Small architecture (22M parameters, 4.8M trainable) with layer freezing, dropout regularization, label smoothing, and weight decay, achieving 97.50% test accuracy and a 2.07% generalization gap across 10 object classes.
CIFAR-10 Image (32x32x3)
        ↓
Resize to 224x224 (ViT Input Size)
        ↓
ViT Patch Embedding (16x16 patches → 384D)
        ↓
Position Embeddings + [CLS] Token
        ↓
First 4 Layers: FROZEN (Pre-trained Features)
        ↓
8 Trainable Transformer Encoder Layers
        ↓
Multi-Head Self-Attention (384D hidden)
        ↓
Layer Normalization + Feed Forward
        ↓
[CLS] Token → Classification Head with Dropout
        ↓
Dropout(0.3) → Linear (384 → 10) → Softmax
        ↓
Output: [Class 0-9 Probabilities]
- Base Model: Facebook DeiT-Small-Patch16-224 (Data-efficient Image Transformer)
- Total Parameters: 22M (4.8M trainable after layer freezing)
- Patch Size: 16×16 pixels with 384-dimensional embeddings
- Frozen Layers: First 4 transformer blocks + patch embeddings (83% parameter reduction)
- Classification Head: Dropout (30%) for regularization, followed by a linear layer (see the sketch below)
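A minimal sketch of how this setup can be reproduced with the Hugging Face `transformers` DeiT classes; this is an approximation of the described architecture, not the repository's verbatim code:

```python
import torch.nn as nn
from transformers import DeiTForImageClassification

# Load DeiT-Small pre-trained on ImageNet, replacing the 1000-class head.
model = DeiTForImageClassification.from_pretrained(
    "facebook/deit-small-patch16-224",
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Freeze patch embeddings and the first 4 of the 12 transformer blocks.
for param in model.deit.embeddings.parameters():
    param.requires_grad = False
for layer in model.deit.encoder.layer[:4]:
    for param in layer.parameters():
        param.requires_grad = False

# Classification head: Dropout(0.3) + Linear(384 -> 10).
model.classifier = nn.Sequential(
    nn.Dropout(0.3),
    nn.Linear(model.config.hidden_size, 10),
)
```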
- Layer Freezing Strategy: First 4 transformer layers frozen to preserve pre-trained features
- Dropout Regularization: 30% dropout in classification head
- Label Smoothing: 10% smoothing to prevent overconfident predictions
- Weight Decay: 0.05 L2 regularization coefficient
- Enhanced Data Augmentation: Random horizontal flip, rotation, color jitter, random erasing
- Early Stopping: Patience-based stopping with 0.001 minimum delta
- Mixed Precision Training: FP16 automatic mixed precision for efficiency
- Gradient Accumulation: 2 steps (effective batch size: 128)
- Learning Rate Scheduling: Linear warmup (15%) with decay
- AdamW Optimizer: Transformer-optimized with lr=3e-5, weight_decay=0.05 (configuration sketched after this list)
- Reproducible Training: Fixed random seeds (111) for consistent results
- Comprehensive Metrics: Precision, Recall, F1-score (macro & weighted)
- Confidence Analysis: Distribution and calibration assessment
- Confusion Matrix: Detailed error pattern analysis
- Interactive Inference: Upload custom images, analyze test samples
- Visual Predictions: Sample prediction visualization with confidence scores
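The loss, optimizer, and scheduler described above could be wired up roughly as follows. This is a sketch rather than the repository's exact code: `model` is the classifier from the previous sketch, and the step count assumes 8 epochs over 50,000 images at an effective batch size of 128:

```python
import torch
from transformers import get_linear_schedule_with_warmup

criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)  # 10% label smoothing

optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),  # trainable parameters only
    lr=3e-5,
    weight_decay=0.05,
)

# Assumption: 8 epochs x ceil(50,000 / 128) optimizer steps per epoch.
total_steps = 8 * 391
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.15 * total_steps),  # 15% linear warmup
    num_training_steps=total_steps,            # then linear decay
)
```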
- Enhanced Augmentation: Random horizontal flip, rotation (10°), color jitter, random erasing
- Normalization: ImageNet statistics for pre-trained model compatibility
- Efficient Loading: 64 batch size with multi-worker DataLoaders
- DeiT-Small: Facebook's efficient transformer with 22M parameters
- Layer Freezing: First 4 layers frozen (4.8M trainable parameters)
- Classification Head: Dropout + Linear layer with 10 outputs
- Loss Function: Label smoothing cross-entropy (0.1 smoothing)
- Optimizer: AdamW with weight decay (0.05) and optimized learning rate (3e-5)
- Scheduler: Linear warmup (15%) followed by linear decay
- Mixed Precision: Automatic gradient scaling for FP16 training
- Gradient Accumulation: 2 steps for memory-efficient training (see the training-step sketch after this list)
- Real-time Monitoring: Loss, learning rate, and generalization gap tracking
- Early Stopping: Prevents overfitting with 3 epochs patience
- Model Checkpointing: Saves best performing model automatically
- Multi-metric Analysis: Accuracy, Precision, Recall, F1-score
- Per-class Performance: Detailed breakdown for all 10 classes
- Confidence Calibration: Model reliability assessment
- Confusion Analysis: Most common misclassification patterns
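A hedged sketch of one training pass combining FP16 autocast with 2-step gradient accumulation, reusing `model`, `criterion`, `optimizer`, and `scheduler` from the sketches above; `train_loader` is assumed to be a standard CIFAR-10 `DataLoader` with batch size 64:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
scaler = torch.cuda.amp.GradScaler(enabled=device == "cuda")
accum_steps = 2  # 64 x 2 = 128 effective batch size
model.to(device).train()

for step, (images, labels) in enumerate(train_loader):
    images, labels = images.to(device), labels.to(device)
    with torch.cuda.amp.autocast(enabled=device == "cuda"):
        logits = model(images).logits                # HF models return an output object
        loss = criterion(logits, labels) / accum_steps
    scaler.scale(loss).backward()                    # accumulate scaled gradients
    if (step + 1) % accum_steps == 0:
        scaler.step(optimizer)                       # unscale gradients + optimizer step
        scaler.update()
        optimizer.zero_grad()
        scheduler.step()                             # one LR step per effective batch
```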
torch>=2.0.0
torchvision>=0.15.0
transformers>=4.30.0
numpy>=1.21.0
matplotlib>=3.5.0
scikit-learn>=1.1.0
Pillow>=9.0.0
Install requirements: `pip install -r requirements.txt`
- Test Accuracy: 97.50% (excellent classification performance)
- Generalization Gap: 2.07% (outstanding overfitting control)
- Training Epochs: 8 epochs with stable convergence
- Macro F1-Score: 97.2% (balanced across all classes)
- Weighted F1-Score: 97.5% (support-weighted performance)
Excellent Performance (>97% accuracy):
- Frog: 98.90% | Automobile: 98.50% | Ship: 98.50% | Airplane: 98.40%
- Horse: 98.40% | Deer: 97.70% | Truck: 97.40% | Bird: 97.30%
Strong Performance (94-97% accuracy):
- Cat: 95.10% | Dog: 94.80%
- Mean Confidence: 0.943 (well-calibrated predictions)
- High Confidence (>0.9): 95%+ accuracy (excellent calibration)
- Inference Speed: 0.06ms per image (real-time capable)
- Dog ↔ Cat: 69 total misclassifications (expected similarity)
- Truck ↔ Automobile: 36 cases (vehicle category overlap)
- Ship ↔ Airplane: 11 cases (shape similarity in small images)
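For illustration, one way such confusion pairs can be extracted with scikit-learn (already in the requirements); `y_true` and `y_pred` stand for the test-set labels and model predictions and are assumed to exist:

```python
import numpy as np
from sklearn.metrics import confusion_matrix

classes = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

cm = confusion_matrix(y_true, y_pred)
errors = cm.copy()
np.fill_diagonal(errors, 0)  # zero out correct predictions, keep only confusions

# Print the three most frequent confusion pairs (true class -> predicted class).
for flat_idx in np.argsort(errors, axis=None)[::-1][:3]:
    i, j = np.unravel_index(flat_idx, errors.shape)
    print(f"{classes[i]} -> {classes[j]}: {errors[i, j]} cases")
```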
- Parameter Efficiency: 83% reduction through layer freezing
- Memory Optimization: Gradient accumulation enables large effective batch sizes
- Training Speed: Mixed precision provides significant acceleration
- Model Size: Compact 22M parameters with excellent performance
The implementation includes a comprehensive inference system:
- Upload Custom Images: Process and classify your own images
- Sample Analysis: Analyze random or specific test samples
- Visual Predictions: Grid display with confidence scores and correctness indicators
- Probability Distribution: Top-5 predictions with confidence bars
- Detailed Metrics: Per-sample confidence analysis and prediction gaps
- Random Sample Analysis: Visualize model predictions on test samples
- Custom Image Upload: Upload and classify your own images (see the inference sketch below)
- Index-based Analysis: Examine specific test samples by index
- Comprehensive Visualization: 12-sample grid with detailed prediction information
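A minimal single-image inference sketch, reusing `model`, `test_transform`, and `classes` from the earlier sketches; the checkpoint filename matches the project tree, but the loading details and image path are assumptions:

```python
import torch
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
state = torch.load("best_vit_model.pth", map_location=device)
model.load_state_dict(state)
model.to(device).eval()

image = Image.open("my_image.png").convert("RGB")   # placeholder path
pixels = test_transform(image).unsqueeze(0).to(device)  # add batch dimension

with torch.no_grad():
    probs = torch.softmax(model(pixels).logits, dim=-1)[0]

# Top-5 predictions with confidence scores.
top5 = torch.topk(probs, k=5)
for p, i in zip(top5.values, top5.indices):
    print(f"{classes[i.item()]}: {p.item():.3f}")
```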
This project is licensed under the MIT License.
See the LICENSE file for details.
- Advanced Regularization: Experiment with CutMix, MixUp, or AugMax techniques
- Model Variants: Test ViT-Tiny, Swin Transformer, or ConvNeXt architectures
- Ensemble Methods: Combine multiple transformer models for improved accuracy
- MLOps Integration: Add Weights & Biases tracking, Docker containerization, ONNX export
- Hyperparameter Optimization: Implement Optuna or Ray Tune for automated tuning
- Fork the repository
- Create a feature branch
  `git checkout -b feature/new-enhancement`
- Implement changes with comprehensive testing and performance benchmarks
- Submit a pull request with detailed description and accuracy improvements
⭐ If this project helps you build better image classification systems with Vision Transformers, consider giving it a star!