For a visual demonstration of DINO's emergent properties and self-supervised learning capabilities, see the original repo's video:
example.mp4
This video illustrates how DINO learns semantically meaningful features without supervision, enabling impressive clustering and attention visualization effects.
This repository provides a from-scratch, research-oriented implementation of DINO (Self-Distillation with No Labels) for Vision Transformers (ViT). The goal is to offer a transparent, modular, and extensible codebase for:
- Experimenting with self-supervised learning (SSL) beyond the constraints of the original Facebook DINO repo
- Integrating DINO with custom datasets, backbones, or loss functions
- Benchmarking and ablation studies
- Gaining a deeper understanding of DINO's mechanisms and design
DINO leverages self-distillation without labels, using a teacher-student paradigm:
- Backbone: Vision Transformer (ViT) via timm, with support for arbitrary patch sizes and input resolutions.
- DINO Head: A 5-layer MLP with GELU activations and LayerNorm, projecting the [CLS] token to a high-dimensional space (default: 1000).
- Multi-crop strategy: Each image yields multiple global (224x224) and local (96x96) crops. The teacher processes only global crops; the student processes both.
- Teacher-Student:
  - The student is trained by matching its output distribution to the teacher's for different views of the same image.
  - The teacher is updated as an exponential moving average (EMA) of the student.
- DINO Loss:
  - Cross-entropy between the softmaxed, temperature-scaled teacher and student outputs, with centering to prevent collapse.
  - Teacher outputs are sharpened (low temperature); student outputs are smoothed (higher temperature).
  - Centering is updated online to stabilize training.
- No labels are used at any point.
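The loss and centering described above can be sketched in a few lines. This is an illustrative sketch, not the repo's exact API: the function names and the center-update momentum here are assumptions.

```python
import torch
import torch.nn.functional as F

def dino_loss(student_out, teacher_out, center,
              student_temp=0.1, teacher_temp=0.04):
    """Cross-entropy between the sharpened teacher distribution and the
    smoothed student distribution, with centering to prevent collapse."""
    # Teacher: subtract the running center, sharpen with a low temperature,
    # and detach so no gradients flow into the teacher.
    t = F.softmax((teacher_out - center) / teacher_temp, dim=-1).detach()
    # Student: smooth with a higher temperature.
    s = F.log_softmax(student_out / student_temp, dim=-1)
    return -(t * s).sum(dim=-1).mean()

@torch.no_grad()
def update_center(center, teacher_out, momentum=0.9):
    """Online EMA update of the center over the teacher's batch outputs
    (the 0.9 momentum here is an illustrative choice)."""
    batch_center = teacher_out.mean(dim=0, keepdim=True)
    return center * momentum + batch_center * (1 - momentum)
```

In the full algorithm this loss is averaged over all teacher-view/student-view pairs (excluding identical views), with the teacher seeing only global crops.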
- Full control over data pipeline: Custom `CustomDataset` and collate for multi-crop, easily extensible to other datasets or crop strategies.
- Backbone-agnostic: Swap ViT for any timm-compatible model; patch size and input resolution are configurable per model.
- Explicit device and memory management: Designed for large-batch, multi-GPU training; supports gradient accumulation and efficient data loading.
- Loss and EMA logic are modular: Easy to adapt for other SSL paradigms (BYOL, MoCo, etc.) or to experiment with different centering/temperature schedules.
- Minimal external dependencies: Only PyTorch, timm, and tqdm; no reliance on Facebook's DINO codebase.
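As an example of how modular the EMA logic can be, a minimal teacher update might look like the following sketch (`ema_update` is an illustrative name; the default momentum matches the `ema_momentum` value used in this repo):

```python
import torch

@torch.no_grad()
def ema_update(student, teacher, momentum=0.996):
    """Update each teacher parameter as an exponential moving average of
    the corresponding student parameter: t = m * t + (1 - m) * s."""
    for p_s, p_t in zip(student.parameters(), teacher.parameters()):
        p_t.mul_(momentum).add_(p_s, alpha=1.0 - momentum)
```

Because the update touches only `parameters()`, the same helper works unchanged for any backbone, which is what makes swapping in BYOL- or MoCo-style targets straightforward.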
- Dataset: Replace STL10 in `train.py` with any torchvision or custom dataset. The `CustomDataset` class expects a dataset returning PIL images.
- Backbone: Change `model_name` in `train.py` and `dino.py` to any timm model. Adjust `img_size` and `out_dim` as needed.
- Augmentation: Modify `get_global_transforms` and `get_local_transforms` in `dataloader.py` for custom multi-crop strategies.
- Hyperparameters: All key parameters (batch size, learning rate, temperatures, EMA momentum, etc.) are defined at the top of `train.py`.
- Scaling: Increase `batch_size` and `num_workers` to maximize GPU utilization. Use gradient accumulation for very large effective batch sizes.
- Checkpointing: Models and optimizer state are checkpointed every 5 epochs. Adjust frequency as needed.
- Monitoring: Training loss is tracked via tqdm. For more advanced logging, integrate with Weights & Biases or TensorBoard.
- Integration: The modular design allows for easy integration with other SSL methods or downstream tasks.
- Designed for high-throughput, large-batch training on modern GPUs (tested up to 46GB VRAM).
- Persistent DataLoader workers and pin_memory for efficient data transfer.
- Checkpoints include both student and teacher weights, optimizer, and loss.
- For distributed/multi-GPU, adapt the DataLoader and model wrapping as needed.
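A throughput-oriented loader reflecting the notes above might be configured as follows. This is a sketch: `make_loader` is an illustrative helper, and `dataset`/`collate_fn` stand in for the repo's `CustomDataset` and multi-crop collate.

```python
from torch.utils.data import DataLoader

def make_loader(dataset, collate_fn=None, batch_size=2048, num_workers=16):
    """Build a DataLoader tuned for large-batch GPU training."""
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=collate_fn,
        num_workers=num_workers,
        pin_memory=True,                      # faster host-to-GPU copies
        persistent_workers=num_workers > 0,   # keep workers alive across epochs
        drop_last=True,                       # stable batch size every step
    )
```

`persistent_workers` avoids re-forking worker processes at each epoch boundary, which matters most when epochs are short relative to worker startup cost.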
- `batch_size`: 2048 (default; scale as memory allows)
- `num_workers`: 16
- `num_epochs`: 10
- `learning_rate`: 0.0005 (AdamW)
- `weight_decay`: 0.04
- `teacher_temp`: 0.04
- `student_temp`: 0.1
- `out_dim`: 1000
- `img_size`: 224 (global crops), 96 (local crops)
- `ema_momentum`: 0.996 (can be tuned)
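With defaults this large, gradient accumulation is the standard way to reach an even bigger effective batch on limited VRAM. A minimal sketch (function and argument names are illustrative, not the repo's exact loop):

```python
def train_epoch(model, loader, optimizer, compute_loss, accum_steps=4):
    """One epoch with gradient accumulation: the effective batch size is
    the loader's batch size multiplied by accum_steps."""
    optimizer.zero_grad()
    for step, batch in enumerate(loader):
        # Divide by accum_steps so accumulated gradients average,
        # rather than sum, over micro-batches.
        loss = compute_loss(model, batch) / accum_steps
        loss.backward()
        if (step + 1) % accum_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```

Note that in DINO the EMA teacher update and the center update should happen once per optimizer step, not once per micro-batch.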
- Computational Resources: Training was performed on a single NVIDIA L40S GPU (48GB VRAM), taking ~15 minutes per epoch. This is significantly less compute than the original paper, which used 8 V100 16GB GPUs for multiple days.
- Pre-trained Backbone: To reduce computational requirements, this implementation uses a pre-trained ViT backbone instead of training from scratch like the original paper. While this affects the "true" self-supervised nature, it's a practical compromise for resource-constrained environments.
- Future Improvements:
- Scale to multi-GPU training for larger batch sizes and faster convergence
- Implement true from-scratch training of the ViT backbone
- Add support for more advanced augmentation strategies
- Integrate with modern training frameworks (DeepSpeed, FSDP)
- Experiment with different architectures beyond ViT
Through this implementation, several key insights were gained:
- Architecture Design: Deep understanding of the teacher-student framework and how EMA updates maintain stability
- Memory Management: Practical experience with large-model training, gradient accumulation, and efficient data loading
- Loss Dynamics: Insights into how temperature scaling and centering prevent mode collapse in self-supervised learning
- Resource Optimization: Learned to make practical trade-offs (like using pre-trained backbones) while preserving core algorithmic insights
- Distributed Training: Exposure to the requirements and challenges of scaling to multi-GPU/multi-node setups, including data parallelism, synchronization, and communication overheads. Realized the importance of frameworks like PyTorch DDP, DeepSpeed, and FSDP for efficient scaling.
- Mixed Precision Training: Understanding the benefits and caveats of using mixed precision (FP16/BFloat16) to accelerate training and reduce memory usage, and how to integrate tools like `torch.cuda.amp`.
- Reproducibility: Gained appreciation for controlling random seeds, environment variables, and deterministic settings to ensure experiment reproducibility, especially in distributed settings.
- Data Pipeline Bottlenecks: Learned to profile and optimize the data pipeline (disk I/O, augmentation, prefetching, pinning) to keep the GPU fully utilized.
- Scalability Mindset: Adopted a mindset of designing for scalability and robustness from the start, anticipating the needs of future, larger experiments.
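The mixed-precision point above can be illustrated with a minimal `torch.cuda.amp` training step. This is a sketch, not the repo's actual loop; `amp_step` and `compute_loss` are illustrative names, and it falls back to full precision when CUDA is unavailable.

```python
import torch
from torch.cuda.amp import autocast, GradScaler

def amp_step(model, batch, optimizer, scaler, compute_loss):
    """One mixed-precision step: autocast runs the forward pass in reduced
    precision where safe; GradScaler scales the loss so small FP16
    gradients don't underflow to zero."""
    optimizer.zero_grad()
    with autocast(enabled=torch.cuda.is_available()):
        loss = compute_loss(model, batch)
    scaler.scale(loss).backward()
    scaler.step(optimizer)   # unscales gradients, then calls optimizer.step()
    scaler.update()          # adapts the scale factor for the next step
    return loss.item()
```

BFloat16 (where hardware supports it) avoids the underflow issue thanks to its wider exponent range, so the `GradScaler` step can often be dropped in that case.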