πŸ–ΌοΈ CIFAR-10 Image Classification with Vision Transformer (ViT)

This project tackles the challenging problem of small-scale image classification using a Vision Transformer (ViT). Unlike traditional convolutional neural networks, the implementation leverages the transformer architecture, originally designed for natural language processing, and fine-tunes it to achieve strong performance on visual recognition tasks.

📋 Table of Contents

  1. About the Dataset
  2. Project Structure
  3. Model Architecture
  4. Key Features
  5. Training Pipeline Overview
  6. Requirements
  7. Results
  8. Interactive Inference
  9. License
  10. Contributing

📊 About the Dataset

This system is built for CIFAR-10 Classification:

  • 32x32 RGB images with 10 distinct classes: airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck
  • Multi-class classification: Classes labeled 0-9 for comprehensive object recognition
  • 50,000 training images and 10,000 test images for robust evaluation
  • Enhanced data augmentation including random flips, rotations, color jittering, and random erasing
  • ImageNet normalization for optimal transformer performance with pre-trained weights

🗂 Project Structure

CIFAR10-ViT-Classification/
├── .gitignore
├── requirements.txt
├── LICENSE
├── README.md                       # Project documentation (this file)
├── data/
│   └── cifar-10/                   # Auto-downloaded CIFAR-10 dataset
├── best_vit_model.pth              # Available in Releases section with tag 'deit-small'
└── cifar10_vit_classifier.py       # Complete implementation

🧠 Model Architecture

Model: DeiT-Small architecture (22M parameters, 4.8M trainable) with layer freezing, dropout regularization, label smoothing, and weight decay, achieving 97.50% test accuracy with 2.07% generalization gap across 10 object classes.

CIFAR-10 Image (32x32x3)
        ↓
Resize to 224x224 (ViT Input Size)
        ↓
ViT Patch Embedding (16x16 patches → 384D)
        ↓
Position Embeddings + [CLS] Token
        ↓
First 4 Layers: FROZEN (Pre-trained Features)
        ↓
8 Trainable Transformer Encoder Layers
        ↓
Multi-Head Self-Attention (384D hidden)
        ↓
Layer Normalization + Feed Forward
        ↓
[CLS] Token → Classification Head with Dropout
        ↓
Dropout(0.3) → Linear (384 → 10) → Softmax
        ↓
Output: [Class 0-9 Probabilities]

Architecture Specifications:

  • Base Model: Facebook DeiT-Small-Patch16-224 (Data-efficient Image Transformer)
  • Total Parameters: 22M (4.8M trainable after layer freezing)
  • Patch Size: 16×16 pixels with 384-dimensional embeddings
  • Frozen Layers: First 4 transformer blocks + patch embeddings (83% parameter reduction)
  • Classification Head: Dropout (30%) + Linear layer for regularization
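
The setup above can be sketched in a few lines with the Hugging Face transformers library (listed in requirements.txt). This is a minimal illustration, not the repository's actual code in cifar10_vit_classifier.py: the module paths follow the standard ViTForImageClassification layout (which the non-distilled DeiT checkpoints are compatible with), and ignore_mismatched_sizes swaps in a fresh 10-class head.

```python
import torch.nn as nn
from transformers import ViTForImageClassification

# Load DeiT-Small and replace the 1000-class ImageNet head with a 10-class one
model = ViTForImageClassification.from_pretrained(
    "facebook/deit-small-patch16-224",
    num_labels=10,
    ignore_mismatched_sizes=True,
)

# Freeze the patch/position embeddings and the first 4 of 12 encoder blocks
for p in model.vit.embeddings.parameters():
    p.requires_grad = False
for block in model.vit.encoder.layer[:4]:
    for p in block.parameters():
        p.requires_grad = False

# Regularized classification head: Dropout(0.3) -> Linear(384 -> 10)
model.classifier = nn.Sequential(nn.Dropout(0.3), nn.Linear(384, 10))

# Sanity check against the parameter counts quoted above
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Trainable parameters: {trainable / 1e6:.1f}M")
```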

✨ Key Features

Anti-Overfitting Techniques:

  • Layer Freezing Strategy: First 4 transformer layers frozen to preserve pre-trained features
  • Dropout Regularization: 30% dropout in classification head
  • Label Smoothing: 10% smoothing to prevent overconfident predictions
  • Weight Decay: 0.05 L2 regularization coefficient
  • Enhanced Data Augmentation: Random horizontal flip, rotation, color jitter, random erasing
  • Early Stopping: Patience-based stopping with 0.001 minimum delta
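
Of these, early stopping is the one piece that is plain Python rather than a library setting. Below is a minimal sketch of a patience-based stopper with the parameters quoted above (patience of 3 epochs, 0.001 minimum delta); the class and method names are illustrative, not the repository's actual API.

```python
class EarlyStopping:
    """Patience-based stopper: halt when the monitored metric stops improving."""

    def __init__(self, patience=3, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best = None
        self.counter = 0

    def step(self, metric):
        """Return True when training should stop."""
        if self.best is None or metric > self.best + self.min_delta:
            self.best = metric          # meaningful improvement: record it, reset patience
            self.counter = 0
            return False
        self.counter += 1               # no meaningful improvement this epoch
        return self.counter >= self.patience
```

Called once per epoch with the validation accuracy, stopper.step(val_acc) returning True signals the training loop to break.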

Training Optimizations:

  • Mixed Precision Training: FP16 automatic mixed precision for efficiency
  • Gradient Accumulation: 2 steps (effective batch size: 128)
  • Learning Rate Scheduling: Linear warmup (15%) with decay
  • AdamW Optimizer: Transformer-optimized with lr=3e-5, weight_decay=0.05
  • Reproducible Training: Fixed random seeds (111) for consistent results
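
Reproducibility hinges on seeding every random-number source in play. A typical seeding block for the fixed seed 111 looks like the following; the exact set of calls in the repository may vary.

```python
import random
import numpy as np
import torch

SEED = 111
random.seed(SEED)                   # Python's built-in RNG (e.g., shuffling helpers)
np.random.seed(SEED)                # NumPy RNG
torch.manual_seed(SEED)             # PyTorch CPU RNG
torch.cuda.manual_seed_all(SEED)    # all GPU RNGs; no-op on CPU-only machines
```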

Evaluation & Inference:

  • Comprehensive Metrics: Precision, Recall, F1-score (macro & weighted)
  • Confidence Analysis: Distribution and calibration assessment
  • Confusion Matrix: Detailed error pattern analysis
  • Interactive Inference: Upload custom images, analyze test samples
  • Visual Predictions: Sample prediction visualization with confidence scores

πŸ” Training Pipeline Overview

1. 📂 Data Preprocessing

  • Enhanced Augmentation: Random horizontal flip, rotation (10°), color jitter, random erasing
  • Normalization: ImageNet statistics for pre-trained model compatibility
  • Efficient Loading: 64 batch size with multi-worker DataLoaders
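
Put together, the training-side preprocessing might look like this sketch. The resize target, rotation angle, batch size, and ImageNet statistics come from the description above; the jitter strengths and erasing probability are illustrative assumptions.

```python
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import CIFAR10

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

train_tf = transforms.Compose([
    transforms.Resize((224, 224)),              # upscale 32x32 CIFAR images to the ViT input size
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),              # +/- 10 degrees, as described above
    transforms.ColorJitter(0.2, 0.2, 0.2),      # brightness/contrast/saturation strengths assumed
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    transforms.RandomErasing(p=0.25),           # operates on tensors, so it must follow ToTensor
])

train_set = CIFAR10("data/cifar-10", train=True, download=True, transform=train_tf)
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, num_workers=4)
```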

2. πŸ—οΈ Model Setup with Regularization

  • DeiT-Small: Facebook's efficient transformer with 22M parameters
  • Layer Freezing: First 4 layers frozen (4.8M trainable parameters)
  • Classification Head: Dropout + Linear layer with 10 outputs

3. ⚙️ Training Configuration

  • Loss Function: Label smoothing cross-entropy (0.1 smoothing)
  • Optimizer: AdamW with weight decay (0.05) and optimized learning rate (3e-5)
  • Scheduler: Linear warmup (15%) followed by linear decay
  • Mixed Precision: Automatic gradient scaling for FP16 training
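
This configuration maps onto standard PyTorch and transformers utilities, as in the sketch below. The model and train_loader come from the previous steps; num_epochs follows the 8 epochs reported under Results, and the step accounting assumes the 2-step gradient accumulation described in the next section.

```python
import torch
from torch import nn
from transformers import get_linear_schedule_with_warmup

criterion = nn.CrossEntropyLoss(label_smoothing=0.1)        # 10% label smoothing

# Only the unfrozen parameters are handed to the optimizer
optimizer = torch.optim.AdamW(
    (p for p in model.parameters() if p.requires_grad),
    lr=3e-5,
    weight_decay=0.05,
)

num_epochs = 8
accum_steps = 2                                             # gradient accumulation (next section)
total_steps = num_epochs * (len(train_loader) // accum_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.15 * total_steps),               # 15% linear warmup
    num_training_steps=total_steps,                         # then linear decay to zero
)
```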

4. πŸ‹οΈ Training Loop with Monitoring

  • Gradient Accumulation: 2 steps for memory-efficient training
  • Real-time Monitoring: Loss, learning rate, and generalization gap tracking
  • Early Stopping: Prevents overfitting with 3 epochs patience
  • Model Checkpointing: Saves best performing model automatically
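
A condensed version of such a loop, combining FP16 autocasting with 2-step gradient accumulation, is sketched below; checkpointing, validation, and gap tracking are omitted for brevity, and torch.cuda.amp is the long-standing spelling of PyTorch's mixed-precision API.

```python
import torch
from torch.cuda.amp import GradScaler, autocast

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

scaler = GradScaler()
accum_steps = 2                                             # effective batch size: 64 x 2 = 128

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    for step, (images, targets) in enumerate(train_loader):
        images, targets = images.to(device), targets.to(device)
        with autocast():                                    # FP16 forward pass and loss
            logits = model(images).logits                   # transformers models return an output object
            loss = criterion(logits, targets) / accum_steps # average over accumulated micro-batches
        scaler.scale(loss).backward()                       # scaled backward avoids FP16 underflow
        if (step + 1) % accum_steps == 0:
            scaler.step(optimizer)                          # unscales gradients, then steps
            scaler.update()
            scheduler.step()                                # one LR step per optimizer update
            optimizer.zero_grad()
```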

5. 📊 Comprehensive Evaluation

  • Multi-metric Analysis: Accuracy, Precision, Recall, F1-score
  • Per-class Performance: Detailed breakdown for all 10 classes
  • Confidence Calibration: Model reliability assessment
  • Confusion Analysis: Most common misclassification patterns
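
With scikit-learn, most of these numbers fall out of two calls. A sketch, assuming a test_loader built with the deterministic (augmentation-free) transform and the device/model from the training step:

```python
import numpy as np
import torch
from sklearn.metrics import classification_report, confusion_matrix

CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

model.eval()
all_preds, all_targets = [], []
with torch.no_grad():
    for images, targets in test_loader:
        logits = model(images.to(device)).logits
        all_preds.append(logits.argmax(dim=1).cpu().numpy())
        all_targets.append(targets.numpy())

preds = np.concatenate(all_preds)
targets = np.concatenate(all_targets)
print(classification_report(targets, preds, target_names=CLASSES, digits=4))
print(confusion_matrix(targets, preds))     # rows: true class, columns: predicted class
```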

βš™οΈ Requirements

torch>=2.0.0
torchvision>=0.15.0
transformers>=4.30.0
numpy>=1.21.0
matplotlib>=3.5.0
scikit-learn>=1.1.0
Pillow>=9.0.0

Install requirements: pip install -r requirements.txt

📈 Results

Overall Performance:

  • Test Accuracy: 97.50% (excellent classification performance)
  • Generalization Gap: 2.07% (outstanding overfitting control)
  • Training Epochs: 8 epochs with stable convergence
  • Macro F1-Score: 97.2% (balanced across all classes)
  • Weighted F1-Score: 97.5% (support-weighted performance)

Per-Class Performance Analysis:

Excellent Performance (>97% accuracy):

  • Frog: 98.90% | Automobile: 98.50% | Ship: 98.50% | Airplane: 98.40%
  • Horse: 98.40% | Deer: 97.70% | Truck: 97.40% | Bird: 97.30%

Strong Performance (94-97% accuracy):

  • Cat: 95.10% | Dog: 94.80%

Confidence Analysis:

  • Mean Confidence: 0.943 (well-calibrated predictions)
  • High Confidence (>0.9): 95%+ accuracy (excellent calibration)
  • Inference Speed: 0.06ms per image (real-time capable)

Common Confusion Patterns:

  1. Dog ↔ Cat: 69 total misclassifications (expected similarity)
  2. Truck ↔ Automobile: 36 cases (vehicle category overlap)
  3. Ship β†’ Airplane: 11 cases (shape similarity in small images)

Model Efficiency:

  • Parameter Efficiency: 83% reduction through layer freezing
  • Memory Optimization: Gradient accumulation enables large effective batch sizes
  • Training Speed: Mixed precision provides significant acceleration
  • Model Size: Compact 22M parameters with excellent performance

πŸ” Interactive Inference

The implementation includes a comprehensive inference system:

Features:

  • Upload Custom Images: Process and classify your own images
  • Sample Analysis: Analyze random or specific test samples
  • Visual Predictions: Grid display with confidence scores and correctness indicators
  • Probability Distribution: Top-5 predictions with confidence bars
  • Detailed Metrics: Per-sample confidence analysis and prediction gaps

Interactive Options:

  1. Random Sample Analysis: Visualize model predictions on test samples
  2. Custom Image Upload: Upload and classify your own images
  3. Index-based Analysis: Examine specific test samples by index
  4. Comprehensive Visualization: 12-sample grid with detailed prediction information
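
A single-image helper in the spirit of option 2 might look like the sketch below. classify_image is an illustrative name rather than the repository's API, and eval_tf is the deterministic counterpart of the training transforms (resize, tensor conversion, ImageNet normalization).

```python
import torch
from PIL import Image
from torchvision import transforms

CLASSES = ["airplane", "automobile", "bird", "cat", "deer",
           "dog", "frog", "horse", "ship", "truck"]

eval_tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

def classify_image(path, model, device, top_k=5):
    """Classify one user-supplied image; return the top-k (class, confidence) pairs."""
    image = Image.open(path).convert("RGB")
    batch = eval_tf(image).unsqueeze(0).to(device)      # add a batch dimension
    model.eval()
    with torch.no_grad():
        probs = torch.softmax(model(batch).logits, dim=1).squeeze(0)
    conf, idx = probs.topk(top_k)
    return [(CLASSES[i], round(c.item(), 4)) for i, c in zip(idx.tolist(), conf)]
```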

📄 License

This project is licensed under the MIT License.
See the LICENSE file for details.

🤝 Contributing

💡 Opportunities for Contribution:

  • Advanced Regularization: Experiment with CutMix, MixUp, or AugMax techniques
  • Model Variants: Test ViT-Tiny, Swin Transformer, or ConvNeXt architectures
  • Ensemble Methods: Combine multiple transformer models for improved accuracy
  • MLOps Integration: Add Weights & Biases tracking, Docker containerization, ONNX export
  • Hyperparameter Optimization: Implement Optuna or Ray Tune for automated tuning

🔧 How to Contribute:

  1. Fork the repository
  2. Create a feature branch: git checkout -b feature/new-enhancement
  3. Implement changes with comprehensive testing and performance benchmarks
  4. Submit a pull request with a detailed description and accuracy improvements

⭐ If this project helps you build better image classification systems with Vision Transformers, consider giving it a star!
