In this tutorial we implement the Vision Transformer (ViT) model from scratch, based on the paper by Dosovitskiy et al.: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). We load the ImageNet-pretrained weights and finetune the model on the [Food 101](https://huggingface.co/datasets/ethz/food101) dataset.
This tutorial was originally inspired by the [HuggingFace Image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification).
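
As the paper's title suggests, ViT treats an image as a sequence of "words": the input is split into fixed-size patches, each of which is flattened into one token of the transformer's input sequence. A minimal illustrative sketch of this patchification, using the 224x224 inputs and 16x16 patches employed below (the actual model typically implements this as a strided convolution):

```python
import jax.numpy as jnp

# Illustrative only: split a 224x224 RGB image into non-overlapping 16x16
# patches and flatten each patch into a token vector.
image = jnp.ones((224, 224, 3))
p = 16
h = w = 224 // p  # 14 patches per side
patches = image.reshape(h, p, w, p, 3).transpose(0, 2, 1, 3, 4).reshape(h * w, -1)
print(patches.shape)  # (196, 768): 196 tokens, each a flattened 16*16*3 patch
```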
+++
We will need to install the following Python packages:

- [TorchVision](https://pytorch.org/vision) will be used for image augmentations
- [grain](https://github.com/google/grain/) will be used for efficient data loading
- [tqdm](https://tqdm.github.io/) will be used for a progress bar to monitor the training progress
- [Matplotlib](https://matplotlib.org/stable/) will be used for visualization purposes
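
A typical install cell for these dependencies might look as follows (the exact package list is an assumption, combining the list above with the HuggingFace libraries imported later in this tutorial):

```{code-cell} ipython3
# Assumed install cell; adjust versions to your environment.
!pip install -U datasets grain torchvision tqdm matplotlib transformers flax
```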
In the model definition, inside `VisionTransformer.__call__`, the patch embeddings are passed through the encoder blocks, normalized by a final layer norm, and the first (class) token is extracted as the representation fed to the classification head:

```python
        # Encoder blocks
        x = self.encoder(embeddings)
        x = self.final_norm(x)

        # fetch the first token
        x = x[:, 0]
```
With the model classes defined, we can instantiate the model with its default configuration and run a quick shape check:

```{code-cell} ipython3
x = jnp.ones((4, 224, 224, 3))
model = VisionTransformer(num_classes=1000)
y = model(x)
print("Predictions shape: ", y.shape)
```
Let's now load the weights pretrained on the ImageNet dataset using HuggingFace Transformers. We load all weights and check that our implementation gives results consistent with the reference model.
```{code-cell} ipython3
from transformers import FlaxViTForImageClassification
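
# Assumed continuation -- the original cell is truncated at this point.
# Load the reference checkpoint; "google/vit-base-patch16-224" is an
# assumption that matches the 1000-class, 224x224, 16x16-patch setup above.
tf_model = FlaxViTForImageClassification.from_pretrained("google/vit-base-patch16-224")
```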
In the following sections we set up an image classification dataset and finetune this model.
In this tutorial we use the [Food 101](https://huggingface.co/datasets/ethz/food101) dataset, which consists of 101 food categories with 101,000 images. For each class, 250 manually reviewed test images are provided as well as 750 training images. On purpose, the training images were not cleaned, and thus still contain some amount of noise, mostly in the form of intense colors and sometimes wrong labels. All images were rescaled to have a maximum side length of 512 pixels.
We will download the data using [HuggingFace Datasets](https://huggingface.co/docs/datasets/) and select 20 classes to reduce the dataset size and the model training time. We will be using [TorchVision](https://pytorch.org/vision) to transform input images and [`grain`](https://github.com/google/grain/) for efficient data loading.
```{code-cell} ipython3
from datasets import load_dataset
184
330
185
-
# Select first 10 classes to reduce the dataset size and the training time.
186
-
train_size = 10 * 750
187
-
val_size = 10 * 250
331
+
# Select first 20 classes to reduce the dataset size and the training time.
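
As mentioned above, TorchVision transforms the input images and grain handles data loading; the actual pipeline is truncated here. A hedged sketch of the preprocessing step (the 224x224 resize matches the model input, and the 0.5 mean/std normalization is an assumption based on the pretrained ViT's image processor defaults):

```python
import numpy as np
from torchvision import transforms

# Hedged sketch: resize to the model's 224x224 input and normalize with
# assumed 0.5 mean/std constants.
img_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # PIL HWC uint8 -> CHW float in [0, 1]
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

def preprocess(example):
    # HF datasets yields PIL images; JAX expects HWC arrays.
    img = img_transform(example["image"])  # CHW torch.Tensor
    example["image"] = np.asarray(img).transpose(1, 2, 0)
    return example
```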
After finetuning, we can inspect the model's predictions on test samples, mapping predicted class indices back to label names:

```{code-cell} ipython3
inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

probas = nnx.softmax(preds, axis=1)
pred_labels = probas.argmax(axis=1)
```
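
The original cell continues with a loop over test samples (`for i in range(num_samples):`) that visualizes images together with their predicted and true labels; it is truncated here. A hedged sketch of such a loop, where `test_images`, `expected_labels`, and `num_samples` are assumed names not taken from the original:

```python
import matplotlib.pyplot as plt

# Hedged sketch: plot a small grid of test images with predicted vs. true labels.
num_samples = 8
fig, axes = plt.subplots(2, 4, figsize=(12, 6))
for i in range(num_samples):
    ax = axes[i // 4, i % 4]
    ax.imshow(test_images[i])
    ax.set_title(
        f"pred: {inv_labels_mapping[int(pred_labels[i])]}\n"
        f"true: {inv_labels_mapping[int(expected_labels[i])]}",
        fontsize=9,
    )
    ax.axis("off")
fig.tight_layout()
```
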
## Further reading
In this tutorial we implemented the Vision Transformer model from scratch and finetuned it on a subset of the Food 101 dataset. The finetuned model reaches 95% classification accuracy. To go further, you can explore:
- Model checkpointing and exporting using [Orbax](https://orbax.readthedocs.io/en/latest/).
- Optimizers and the learning rate scheduling using [Optax](https://optax.readthedocs.io/en/latest/).
- Freezing the model's parameters via trainable-parameter filtering (see the sketch after this list): [example 1](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html#flax.nnx.optimizer.Optimizer.update) and [example 2](https://github.com/google/flax/issues/4167#issuecomment-2324245208).
- Other Computer Vision tutorials in [jax-ai-stack](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html).
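
For the parameter-freezing item above, here is a minimal sketch with Flax NNX and Optax. It assumes the classification head lives under a path containing `"classifier"` and a Flax version whose `nnx.Optimizer` supports the `wrt` filter:

```python
import optax
from flax import nnx

# Hedged sketch: only parameters whose path contains "classifier" receive
# gradient updates, effectively freezing the pretrained backbone.
optimizer = nnx.Optimizer(
    model,
    optax.adamw(learning_rate=1e-4),
    wrt=nnx.All(nnx.Param, nnx.PathContains("classifier")),
)
```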