A simple ViT implementation for training and evaluating a Vision Transformer (ViT) model. The model can be trained on a single consumer GPU in under an hour. The repository is structured as follows:
├── config.py <- ViT configuration
|
├── model.py <- ViT model, inference and evaluation ~200 lines
|
└── train.py <- training the model ~150 lines
microViT is based on the original ViT architecture, using only the encoder part. The task is digit classification on the MNIST dataset.
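The core idea of a ViT is to split each image into fixed-size patches and embed every patch as a token for the encoder. The sketch below illustrates that step for MNIST-sized images; the class name and hyperparameters (7x7 patches, 64-dim embeddings) are illustrative assumptions, not necessarily the values used in config.py.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split an image into patches and project each patch to an embedding.

    Illustrative sketch; the repo's actual module in model.py may differ.
    """
    def __init__(self, img_size=28, patch_size=7, in_channels=1, embed_dim=64):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A strided convolution is equivalent to flattening non-overlapping
        # patches and applying a shared linear projection to each one.
        self.proj = nn.Conv2d(in_channels, embed_dim,
                              kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                     # (B, embed_dim, H/ps, W/ps)
        return x.flatten(2).transpose(1, 2)  # (B, num_patches, embed_dim)

patches = PatchEmbedding()(torch.randn(2, 1, 28, 28))
print(patches.shape)  # torch.Size([2, 16, 64])
```

A 28x28 image with 7x7 patches yields 16 tokens, to which the encoder's class token and position embeddings are then added.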
During training you will see how the model converges: the loss will decrease and the accuracy will increase.
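The accuracy reported during training is the usual top-1 classification accuracy. A minimal sketch of how it can be computed from the model's logits (the `accuracy` helper here is illustrative, not the repo's exact code):

```python
import torch

def accuracy(logits, labels):
    # Fraction of examples whose argmax prediction matches the label.
    return (logits.argmax(dim=-1) == labels).float().mean().item()

logits = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2]])
labels = torch.tensor([0, 1, 1])
print(accuracy(logits, labels))  # 2 of 3 correct, ≈ 0.667
```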
Install the dependencies:

pip install torch torchvision transformers wandb

Dependencies:
- pytorch: defining the ViT model
- torchvision: loading the MNIST dataset
- transformers: cosine learning-rate scheduler
- wandb: logging the training process
The training can be started with the following command:
python train.py
This will start training with the default parameters. You can change the parameters in the config.py and train.py files. The training run is logged to wandb. At the end of training, the model is saved in a checkpoints folder.
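The saved checkpoint can then be reloaded for evaluation or inference. A sketch of the usual PyTorch state_dict save/load round trip, using a stand-in module (the filename model.pt is an assumption; check the checkpoints folder for the name train.py actually writes):

```python
import torch
from pathlib import Path

# Stand-in module; in the repo the model class lives in model.py.
model = torch.nn.Linear(4, 2)

ckpt_dir = Path("checkpoints")
ckpt_dir.mkdir(exist_ok=True)
ckpt_path = ckpt_dir / "model.pt"  # assumed filename

# Save only the weights (state_dict), the usual PyTorch convention.
torch.save(model.state_dict(), ckpt_path)

# Reload into a freshly constructed model and switch to eval mode.
restored = torch.nn.Linear(4, 2)
restored.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
restored.eval()

x = torch.randn(1, 4)
print(torch.allclose(model(x), restored(x)))  # True
```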
References:
- My blog post on the ViT model.