SNodeViT: Stochastic Neural ODE Vision Transformer

A toy implementation of a Vision Transformer that formulates attention dynamics as stochastic differential equations (SDEs) with fixed-rank KSVD decomposition.

๐Ÿ—๏ธ Architecture Overview

Standard Vision Transformer (ViT)

The standard ViT architecture consists of:

  • Patch Embedding: Images are divided into fixed-size patches (e.g., 16×16) and linearly projected
  • Positional Embeddings: Learnable position encodings added to patch embeddings
  • Transformer Blocks: Multi-head self-attention + MLP layers
  • Classification Head: Final linear layer for classification

Mathematical Formulation:

Attention(Q,K,V) = softmax(QK^T/√d_k)V

where Q, K, V are queries, keys, and values derived from input embeddings.
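For reference, this is what the standard computation looks like in PyTorch (a minimal sketch with illustrative shapes, not code from this repo):

import torch
import torch.nn.functional as F

def scaled_dot_product_attention(q, k, v):
    # Illustrative sketch; q, k, v: [batch, heads, N, d_k]
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / d_k ** 0.5  # [batch, heads, N, N]
    weights = F.softmax(scores, dim=-1)            # row-wise attention weights
    return weights @ v                             # [batch, heads, N, d_k]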

Fixed-Rank KSVD Attention

Instead of computing full attention matrices, SNodeViT uses a low-rank approximation:

Mathematical Formulation:

A_t = ∑_{i=1}^{r} λ_i(t) ψ_i(t) ⊗ φ_i(t)

Where:

  • r is the fixed rank (not scalable with sequence length)
  • ψ_i(t) and φ_i(t) are time-dependent basis functions
  • λ_i(t) are time-dependent scaling factors
  • ⊗ denotes the outer product

Complexity: O(N² × r) where N is sequence length and r is fixed rank.
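A minimal sketch of this rank-r construction (hypothetical tensor names; the repo's StochasticPrimalAttention adds time-dependent weights and stochastic noise on top of this idea):

import torch

def fixed_rank_attention(lam, psi, phi, v):
    # Illustrative sketch. lam: [batch, r] scaling factors,
    # psi, phi: [batch, N, r] basis functions, v: [batch, N, d] values.
    # Materializing A costs O(N^2 * r), matching the complexity above.
    A = torch.einsum('br,bnr,bmr->bnm', lam, psi, phi)  # sum of r outer products
    return A @ v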

Basic SDE Integration

The attention dynamics are modeled as a stochastic differential equation:

Mathematical Formulation:

dX_t = f(X_t, θ(t)) dt + G(X_t, σ(t)) dW_t

Where:

  • X_t is the state at time t
  • f(X_t, θ(t)) is the drift term (deterministic evolution)
  • G(X_t, σ(t)) dW_t is the diffusion term (stochastic noise)
  • θ(t) and σ(t) are time-dependent parameters

Integration Method: Euler-Maruyama scheme

X_{t+dt} = X_t + f(X_t, θ(t)) dt + G(X_t, σ(t)) √dt ε,  where ε ~ N(0,1)
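A single step of this scheme could look as follows (a sketch; drift and diffusion stand in for the repo's attention/MLP and noise modules):

import torch

def euler_maruyama_step(x, t, dt, drift, diffusion):
    # Illustrative sketch. x: [batch, N, d] token states;
    # drift(x, t) plays the role of f, diffusion(x, t) the role of G above.
    noise = torch.randn_like(x)  # dW_t realized as sqrt(dt) * N(0, 1)
    return x + drift(x, t) * dt + diffusion(x, t) * dt ** 0.5 * noise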

๐Ÿ“ Project Structure

snodevit/
├── core/                    # Core model components
│   ├── attention.py         # Fixed-rank KSVD attention
│   ├── layers.py            # Continuous blocks and SDE integration
│   └── model.py             # Main SNodeViT architecture
├── training/                # Training utilities
│   ├── trainer.py           # Training loop with uncertainty
│   └── loss.py              # Loss functions
├── data/                    # Data loading and preprocessing
│   └── datasets.py          # Dataset managers
├── evaluation/              # Evaluation metrics
│   └── metrics.py           # Uncertainty and calibration metrics
├── configs/                 # Configuration files
│   ├── base.yaml            # Base configuration
│   ├── tiny.yaml            # Tiny model variant
│   ├── small.yaml           # Small model variant
│   ├── base_model.yaml      # Base model variant
│   └── large.yaml           # Large model variant
├── scripts/                 # Training and evaluation scripts
│   ├── train.py             # Main training script
│   └── evaluate.py          # Evaluation script
├── performance_charts/      # Generated performance charts
│   ├── training_performance.png
│   └── performance_summary.md
├── requirements.txt         # Python dependencies
├── setup.py                 # Package setup
├── LICENSE                  # MIT License
└── README.md                # This file

🚀 Quick Start

Installation

git clone <your-repo>
cd snodevit
pip install -r requirements.txt

Basic Usage

import torch
from core import create_snodevit

# Create tiny model
model = create_snodevit(
    variant='tiny',
    num_classes=10,
    img_size=224
)

# Forward pass
x = torch.randn(1, 3, 224, 224)
output = model(x)  # [1, 10]

Training

# Train tiny model on CIFAR-10
python scripts/train.py \
    model.variant=tiny \
    data.name=cifar10 \
    training.batch_size=128 \
    training.learning_rate=2e-3

📊 Model Variants

| Variant | Embed Dim | Depth | Heads | Basis Functions | Params |
|---------|-----------|-------|-------|-----------------|--------|
| Tiny    | 96        | 6     | 3     | 8               | ~1.2M  |
| Small   | 192       | 8     | 4     | 16              | ~4.8M  |
| Base    | 384       | 12    | 6     | 24              | ~19M   |
| Large   | 768       | 16    | 12    | 32              | ~76M   |

🔧 Key Components

1. StochasticPrimalAttention

  • Fixed-rank KSVD: Uses predefined number of basis functions
  • Time-dependent weights: Neural networks that generate time-varying parameters
  • Stochastic noise: Adds controlled randomness for uncertainty quantification

2. ContinuousBlock

  • Neural SDE: Combines attention and MLP in continuous-time formulation
  • Euler integration: Simple numerical integration scheme
  • Memory efficient: Gradient checkpointing support

3. TimeNet

  • Sinusoidal embedding: Converts time to high-dimensional representation
  • MLP processing: Generates time-dependent parameters
  • Stable initialization: Small weights for numerical stability
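A minimal sketch of such a component (illustrative sizes and names, not the repo's exact TimeNet):

import math
import torch
import torch.nn as nn

class SinusoidalTimeNet(nn.Module):
    # Illustrative sketch: sinusoidal time embedding followed by an MLP.
    def __init__(self, embed_dim=64, out_dim=128):
        super().__init__()
        self.embed_dim = embed_dim
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, out_dim), nn.SiLU(), nn.Linear(out_dim, out_dim)
        )
        nn.init.normal_(self.mlp[-1].weight, std=1e-3)  # small weights for stability

    def forward(self, t):
        # t: scalar tensor in [0, 1]; sin/cos features at geometric frequencies
        half = self.embed_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half) / half)
        emb = torch.cat([torch.sin(t * freqs), torch.cos(t * freqs)], dim=-1)
        return self.mlp(emb)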

📈 Performance Results

Training Progress (Tiny Model on CIFAR-10)

  • Epoch 0: loss 2.05, training accuracy 22.8%
  • Epoch 99: loss 0.46, training accuracy 84.0%
  • Final validation accuracy: 79.4%

Training Performance Charts

Charts for the main training performance, the comprehensive training analysis, and the realistic training analysis are generated into performance_charts/ (see training_performance.png and performance_summary.md).

Final Results (100 Epochs)

  • Training Accuracy: 84.0% (+61.2 percentage points improvement)
  • Validation Accuracy: 79.4%
  • Expected Calibration Error (ECE): 0.0279 (excellent uncertainty calibration)
  • Efficient: Gradient checkpointing enabled

โš ๏ธ Limitations & Trade-offs

1. Fixed Rank Issue

  • Problem: Rank r doesn't scale with sequence length N
  • Impact: Not truly scalable for larger models
  • Trade-off: Memory efficiency vs. expressiveness

2. Standard Patch Embeddings

  • Reality: Uses same patch embedding as standard ViT
  • Innovation: Only in attention mechanism and SDE integration
  • Trade-off: Simplicity vs. novelty

3. Basic SDE Integration

  • Method: Simple Euler scheme
  • Stability: Basic drift clipping and noise control
  • Trade-off: Simplicity vs. accuracy

🧪 Experiments

Uncertainty Quantification

# Get predictions with uncertainty
mean_pred, uncertainty = model.forward_with_uncertainty(x, num_samples=5)
print(f"Prediction: {mean_pred.argmax()}")
print(f"Uncertainty: {uncertainty.mean():.3f}")

Attention Visualization

# Extract attention maps
attention_maps = model.get_attention_maps(x)
print(f"Number of attention layers: {len(attention_maps)}")

📚 Mathematical Background

1. Low-Rank Approximation

The attention matrix A is approximated as:

A ≈ UΣV^T

where U and V are orthogonal matrices and Σ is diagonal with r non-zero singular values.
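For instance, the best rank-r approximation in the Frobenius norm is obtained by truncating the SVD (a self-contained PyTorch example):

import torch

A = torch.randn(196, 196)  # e.g. an attention matrix over 196 patches
U, S, Vh = torch.linalg.svd(A)
r = 8
A_r = U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]  # best rank-8 approximation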

2. Stochastic Differential Equations

SDEs model systems with both deterministic and random evolution:

dX_t = μ(X_t, t) dt + σ(X_t, t) dW_t

3. Euler-Maruyama Scheme

Numerical approximation for SDEs:

X_{n+1} = X_n + μ(X_n, t_n) Δt + σ(X_n, t_n) √Δt Z_n

where Z_n ~ N(0,1).
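As a concrete example, simulating an Ornstein-Uhlenbeck process dX_t = -θ X_t dt + σ dW_t with this scheme (a self-contained toy, unrelated to the repo's modules):

import torch

theta, sigma, dt, steps = 1.0, 0.5, 0.01, 1000
x = torch.tensor(2.0)
for _ in range(steps):
    x = x + (-theta * x) * dt + sigma * dt ** 0.5 * torch.randn(())
# x now fluctuates around 0 with stationary std sigma / sqrt(2 * theta) ≈ 0.35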

🎯 Use Cases

1. Research & Education

  • Understanding: How SDEs can model neural dynamics
  • Experimentation: Testing uncertainty quantification methods
  • Learning: Deep learning with continuous-time formulations

2. Small-Scale Applications

  • CIFAR-10/100: Image classification tasks
  • Prototyping: Quick experiments with novel architectures
  • Benchmarking: Comparing against standard ViT

3. Uncertainty-Aware Systems

  • Confidence estimation: Model prediction reliability
  • Risk assessment: Identifying uncertain predictions
  • Adaptive computation: Early stopping for easy examples

🔮 Future Improvements

1. Adaptive Rank Scaling

import math

def compute_adaptive_rank(self, N, D, H):
    """Scale the KSVD rank with sequence length N, embed dim D, and head count H."""
    optimal_rank = int(math.sqrt(N * D) * 0.1)
    return min(optimal_rank, N, D, H)
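For example, with N = 196 patches and D = 384, this heuristic gives int(√(196 × 384) × 0.1) = 27 before the final min-capping.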

2. Advanced SDE Solvers

  • Runge-Kutta methods: Higher-order integration schemes
  • Adaptive time stepping: Dynamic step size selection
  • Stochastic solvers: More sophisticated noise handling

3. Continuous Patch Embeddings

  • Time-varying kernels: Patches that evolve over time
  • Adaptive patch sizes: Dynamic receptive fields
  • Continuous convolutions: Smooth spatial transformations

📖 References

  1. Vision Transformer: Dosovitskiy et al., "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (ICLR 2021)
  2. Neural ODEs: Chen et al., "Neural Ordinary Differential Equations" (NeurIPS 2018)
  3. SDEs: Øksendal, "Stochastic Differential Equations: An Introduction with Applications" (Springer)
  4. Low-Rank Attention: Shen et al., "Efficient Attention: Attention with Linear Complexities" (WACV 2021)

๐Ÿค Contributing

This is a toy project for educational purposes. Feel free to:

  • Experiment with different architectures
  • Implement the suggested improvements
  • Share your findings and insights
  • Use as a starting point for research

📄 License

MIT License - feel free to use for research and education.


Note: This implementation focuses on clarity and educational value rather than production performance. The fixed-rank limitation and basic SDE integration make it suitable for understanding the concepts rather than achieving state-of-the-art results.

🚀 Development

Prerequisites

  • Python 3.8+
  • PyTorch 2.0+
  • CUDA (optional, for GPU training)

Local Development Setup

# Clone the repository
git clone <your-github-repo-url>
cd snodevit

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

# Install in development mode
pip install -e .

Running Experiments

# Train tiny model on CIFAR-10
python scripts/train.py \
    model.variant=tiny \
    data.name=cifar10 \
    training.batch_size=128 \
    training.learning_rate=2e-3

# Evaluate trained model
python scripts/evaluate.py \
    --model-path checkpoints/best.pth \
    --config configs/tiny.yaml

Regenerating Performance Charts

# If you want to regenerate performance charts from new training logs
python generate_performance_charts.py

Contribution Workflow

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

Issues and Discussions

  • Bug Reports: Use the Issues tab for bug reports
  • Feature Requests: Open an issue for new features
  • Questions: Use Discussions for questions and help
  • Improvements: Submit PRs for code improvements

Happy Learning! 🎓
