A toy implementation of a Vision Transformer that formulates attention dynamics as stochastic differential equations (SDEs) with fixed-rank KSVD decomposition.
The standard ViT architecture consists of:
- Patch Embedding: Images are divided into fixed-size patches (e.g., 16×16) and linearly projected
- Positional Embeddings: Learnable position encodings added to patch embeddings
- Transformer Blocks: Multi-head self-attention + MLP layers
- Classification Head: Final linear layer for classification
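Patch embedding is typically implemented as a convolution whose kernel size and stride both equal the patch size, so each patch becomes one token. A generic sketch (not this repository's module):

```python
import torch.nn as nn

class PatchEmbed(nn.Module):
    """Split an image into non-overlapping patches and linearly project each one."""
    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=96):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                     # [B, embed_dim, H/P, W/P]
        return x.flatten(2).transpose(1, 2)  # [B, num_patches, embed_dim]
```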
Mathematical Formulation (self-attention):
Attention(Q, K, V) = softmax(QK^T / √d_k) V
where Q, K, V are queries, keys, and values derived from input embeddings.
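For reference, scaled dot-product attention fits in a few lines of PyTorch (a standalone sketch, not the repository's implementation):

```python
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V):
    """softmax(QK^T / sqrt(d_k)) V for tensors of shape [..., N, d_k]."""
    d_k = Q.size(-1)
    scores = Q @ K.transpose(-2, -1) / d_k ** 0.5   # [..., N, N]
    return F.softmax(scores, dim=-1) @ V

Q = K = V = torch.randn(1, 197, 64)                 # 197 tokens, d_k = 64
out = scaled_dot_product_attention(Q, K, V)         # [1, 197, 64]
```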
Instead of computing full attention matrices, SNodeViT uses a low-rank approximation:
Mathematical Formulation:
A_t = Σ_{i=1}^{r} λ_i(t) φ_i(t) ⊗ ψ_i(t)
Where:
- r is the fixed rank (not scalable with sequence length)
- φ_i(t) and ψ_i(t) are time-dependent basis functions
- λ_i(t) are time-dependent scaling factors
- ⊗ denotes the outer product
Complexity: O(N² × r), where N is the sequence length and r is the fixed rank.
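Conceptually, assembling the rank-r attention matrix from its basis functions is a sum of outer products; a toy sketch (tensor names are illustrative, not taken from core/attention.py):

```python
import torch

def low_rank_attention(lam, phi, psi):
    """A_t = sum_{i=1}^{r} lam_i * phi_i ⊗ psi_i.

    lam: [r] scaling factors; phi, psi: [r, N] basis functions. Returns A_t: [N, N].
    """
    return torch.einsum('i,in,im->nm', lam, phi, psi)  # sum of r outer products

lam, phi, psi = torch.rand(8), torch.randn(8, 197), torch.randn(8, 197)
A_t = low_rank_attention(lam, phi, psi)  # [197, 197], rank at most 8
```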
The attention dynamics are modeled as a stochastic differential equation:
Mathematical Formulation:
dX_t = f(X_t, θ(t)) dt + G(X_t, σ(t)) dW_t
Where:
- X_t is the state at time t
- f(X_t, θ(t)) is the drift term (deterministic evolution)
- G(X_t, σ(t)) dW_t is the diffusion term (stochastic noise)
- θ(t) and σ(t) are time-dependent parameters
Integration Method: Euler-Maruyama scheme
X_{t+dt} = X_t + f(X_t, θ(t)) dt + G(X_t, σ(t)) √dt · N(0, 1)
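A single integration step follows directly from this update rule (a minimal sketch; drift_fn and diffusion_fn stand in for the model's attention/MLP dynamics):

```python
import torch

def euler_maruyama_step(x, t, dt, drift_fn, diffusion_fn):
    """X_{t+dt} = X_t + f(X_t, t) dt + G(X_t, t) sqrt(dt) * N(0, 1)."""
    noise = torch.randn_like(x)
    return x + drift_fn(x, t) * dt + diffusion_fn(x, t) * (dt ** 0.5) * noise
```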
```
snodevit/
├── core/                      # Core model components
│   ├── attention.py           # Fixed-rank KSVD attention
│   ├── layers.py              # Continuous blocks and SDE integration
│   └── model.py               # Main SNodeViT architecture
├── training/                  # Training utilities
│   ├── trainer.py             # Training loop with uncertainty
│   └── loss.py                # Loss functions
├── data/                      # Data loading and preprocessing
│   └── datasets.py            # Dataset managers
├── evaluation/                # Evaluation metrics
│   └── metrics.py             # Uncertainty and calibration metrics
├── configs/                   # Configuration files
│   ├── base.yaml              # Base configuration
│   ├── tiny.yaml              # Tiny model variant
│   ├── small.yaml             # Small model variant
│   ├── base_model.yaml        # Base model variant
│   └── large.yaml             # Large model variant
├── scripts/                   # Training and evaluation scripts
│   ├── train.py               # Main training script
│   └── evaluate.py            # Evaluation script
├── performance_charts/        # Generated performance charts
│   ├── training_performance.png
│   └── performance_summary.md
├── requirements.txt           # Python dependencies
├── setup.py                   # Package setup
├── LICENSE                    # MIT License
└── README.md                  # This file
```
```bash
git clone <your-repo>
cd snodevit
pip install -r requirements.txt
```

```python
import torch
from core import create_snodevit

# Create tiny model
model = create_snodevit(
variant='tiny',
num_classes=10,
img_size=224
)
# Forward pass
x = torch.randn(1, 3, 224, 224)
output = model(x)  # [1, 10]
```

```bash
# Train tiny model on CIFAR-10
python scripts/train.py \
model.variant=tiny \
data.name=cifar10 \
training.batch_size=128 \
    training.learning_rate=2e-3
```

| Variant | Embed Dim | Depth | Heads | Basis | Params |
|---|---|---|---|---|---|
| Tiny | 96 | 6 | 3 | 8 | ~1.2M |
| Small | 192 | 8 | 4 | 16 | ~4.8M |
| Base | 384 | 12 | 6 | 24 | ~19M |
| Large | 768 | 16 | 12 | 32 | ~76M |
- Fixed-rank KSVD: Uses predefined number of basis functions
- Time-dependent weights: Neural networks that generate time-varying parameters
- Stochastic noise: Adds controlled randomness for uncertainty quantification
- Neural SDE: Combines attention and MLP in continuous-time formulation
- Euler integration: Simple numerical integration scheme
- Memory efficient: Gradient checkpointing support
- Sinusoidal embedding: Converts time to high-dimensional representation
- MLP processing: Generates time-dependent parameters
- Stable initialization: Small weights for numerical stability
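A common way to realize this pattern is a sinusoidal embedding of the scalar time followed by a small MLP; the sketch below is illustrative (module and parameter names are assumptions, not the repository's exact code):

```python
import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    """Sinusoidal time embedding -> MLP producing time-dependent parameters."""
    def __init__(self, embed_dim=64, param_dim=32):
        super().__init__()
        self.embed_dim = embed_dim
        self.mlp = nn.Sequential(nn.Linear(embed_dim, embed_dim), nn.SiLU(),
                                 nn.Linear(embed_dim, param_dim))
        # Small final-layer weights keep the generated parameters near zero early in training
        nn.init.normal_(self.mlp[-1].weight, std=1e-3)
        nn.init.zeros_(self.mlp[-1].bias)

    def forward(self, t):
        half = self.embed_dim // 2
        freqs = torch.exp(-math.log(10000.0) * torch.arange(half, dtype=torch.float32) / half)
        angles = t.reshape(-1, 1) * freqs                      # [B, half]
        emb = torch.cat([angles.sin(), angles.cos()], dim=-1)  # [B, embed_dim]
        return self.mlp(emb)                                   # [B, param_dim]
```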
- Epoch 0: Loss: 2.05, Accuracy: 22.8%
- Epoch 99: Loss: 0.46, Accuracy: 84.0%
- Final: 79.4% validation accuracy
- Training Accuracy: 84.0% (+61.2 percentage points improvement)
- Validation Accuracy: 79.4%
- Expected Calibration Error (ECE): 0.0279 (excellent uncertainty calibration; see the ECE sketch below)
- Efficient: Gradient checkpointing enabled
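For reference, ECE bins predictions by confidence and averages the gap between confidence and accuracy within each bin; a minimal implementation (not necessarily identical to evaluation/metrics.py):

```python
import torch

def expected_calibration_error(probs, labels, n_bins=15):
    """probs: [N, C] softmax outputs, labels: [N] class indices."""
    confidences, predictions = probs.max(dim=-1)
    accuracies = predictions.eq(labels).float()
    edges = torch.linspace(0, 1, n_bins + 1)
    ece = torch.tensor(0.0)
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            gap = (accuracies[in_bin].mean() - confidences[in_bin].mean()).abs()
            ece = ece + in_bin.float().mean() * gap  # weight gap by bin occupancy
    return ece.item()
```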
- Problem: Rank r doesn't scale with sequence length N
- Impact: Not truly scalable for larger models
- Trade-off: Memory efficiency vs. expressiveness
- Reality: Uses same patch embedding as standard ViT
- Innovation: Only in attention mechanism and SDE integration
- Trade-off: Simplicity vs. novelty
- Method: Simple Euler scheme
- Stability: Basic drift clipping and noise control
- Trade-off: Simplicity vs. accuracy
```python
# Get predictions with uncertainty
mean_pred, uncertainty = model.forward_with_uncertainty(x, num_samples=5)
print(f"Prediction: {mean_pred.argmax()}")
print(f"Uncertainty: {uncertainty.mean():.3f}")

# Extract attention maps
attention_maps = model.get_attention_maps(x)
print(f"Number of attention layers: {len(attention_maps)}")
```

The attention matrix A is approximated as:
A ≈ UΣV^T
where U and V are orthogonal matrices and Σ is diagonal with r non-zero singular values.
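The same idea can be demonstrated directly with a truncated SVD in PyTorch (a toy illustration, separate from the model code):

```python
import torch

def rank_r_approximation(A, r):
    """Keep only the top-r singular triplets: A ≈ U_r Σ_r V_r^T."""
    U, S, Vh = torch.linalg.svd(A, full_matrices=False)
    return U[:, :r] @ torch.diag(S[:r]) @ Vh[:r, :]

A = torch.randn(16, 16)
A_r = rank_r_approximation(A, r=4)
print(torch.linalg.matrix_rank(A_r))  # tensor(4)
```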
SDEs model systems with both deterministic and random evolution:
dX_t = μ(X_t, t) dt + σ(X_t, t) dW_t
The Euler-Maruyama scheme gives the numerical approximation:
X_{n+1} = X_n + μ(X_n, t_n) Δt + σ(X_n, t_n) √Δt Z_n
where Z_n ~ N(0,1).
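For intuition, here is the same scheme applied to a simple Ornstein-Uhlenbeck process (purely illustrative and unrelated to the model code):

```python
import torch

def simulate_ou(x0=1.0, theta=1.0, sigma=0.3, T=1.0, n_steps=100):
    """Euler-Maruyama simulation of dX_t = -theta * X_t dt + sigma dW_t."""
    dt = T / n_steps
    x = torch.tensor(x0)
    path = [x.item()]
    for _ in range(n_steps):
        x = x + (-theta * x) * dt + sigma * (dt ** 0.5) * torch.randn(())
        path.append(x.item())
    return path

print(simulate_ou()[-1])  # one noisy sample of X_T, drawn toward 0 by the drift
```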
- Understanding: How SDEs can model neural dynamics
- Experimentation: Testing uncertainty quantification methods
- Learning: Deep learning with continuous-time formulations
- CIFAR-10/100: Image classification tasks
- Prototyping: Quick experiments with novel architectures
- Benchmarking: Comparing against standard ViT
- Confidence estimation: Model prediction reliability
- Risk assessment: Identifying uncertain predictions
- Adaptive computation: Early stopping for easy examples
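In practice these uses reduce to thresholding the uncertainty returned by forward_with_uncertainty; continuing the quick-start snippet (the 0.5 cut-off is arbitrary and should be tuned on a validation set):

```python
mean_pred, uncertainty = model.forward_with_uncertainty(x, num_samples=10)
if uncertainty.mean() > 0.5:  # arbitrary cut-off
    print("Prediction flagged for review (high uncertainty)")
else:
    print(f"Confident prediction: class {mean_pred.argmax().item()}")
```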
A suggested improvement is to scale the rank with the input dimensions instead of fixing it:

```python
import math

def compute_adaptive_rank(self, N, D, H):
    """Scale the rank with the input dimensions instead of fixing it."""
    optimal_rank = int(math.sqrt(N * D) * 0.1)  # grows with sequence length N and embed dim D
    return min(optimal_rank, min(N, D, H))      # never exceed any input dimension
```

- Runge-Kutta methods: Higher-order integration schemes
- Adaptive time stepping: Dynamic step size selection
- Stochastic solvers: More sophisticated noise handling
- Time-varying kernels: Patches that evolve over time
- Adaptive patch sizes: Dynamic receptive fields
- Continuous convolutions: Smooth spatial transformations
- Vision Transformer: "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"
- Neural ODEs: "Neural Ordinary Differential Equations"
- SDEs: "Stochastic Differential Equations: An Introduction with Applications"
- Low-Rank Attention: "Efficient Attention: Attention with Linear Complexities"
This is a toy project for educational purposes. Feel free to:
- Experiment with different architectures
- Implement the suggested improvements
- Share your findings and insights
- Use as a starting point for research
MIT License - feel free to use for research and education.
Note: This implementation focuses on clarity and educational value rather than production performance. The fixed-rank limitation and basic SDE integration make it suitable for understanding the concepts rather than achieving state-of-the-art results.
- Python 3.8+
- PyTorch 2.0+
- CUDA (optional, for GPU training)
```bash
# Clone the repository
git clone <your-github-repo-url>
cd snodevit

# Create virtual environment
python -m venv venv
source venv/bin/activate  # On Windows: venv\Scripts\activate

# Install dependencies
pip install -r requirements.txt

# Install in development mode
pip install -e .
```

```bash
# Train tiny model on CIFAR-10
python scripts/train.py \
model.variant=tiny \
data.name=cifar10 \
training.batch_size=128 \
training.learning_rate=2e-3
# Evaluate trained model
python scripts/evaluate.py \
--model-path checkpoints/best.pth \
    --config configs/tiny.yaml

# If you want to regenerate performance charts from new training logs
python generate_performance_charts.py
```

- Fork the repository
- Create a feature branch (`git checkout -b feature/amazing-feature`)
- Commit your changes (`git commit -m 'Add amazing feature'`)
- Push to the branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
- Bug Reports: Use the Issues tab for bug reports
- Feature Requests: Open an issue for new features
- Questions: Use Discussions for questions and help
- Improvements: Submit PRs for code improvements
Happy Learning!

