
Commit 4ec7510

Updated ViT tutorial to fine-tune the model on Food101 dataset (#116)
1 parent 0f09306 commit 4ec7510

File tree

2 files changed: +585 -954 lines changed


docs/JAX_Vision_transformer.ipynb

Lines changed: 411 additions & 928 deletions
Large diffs are not rendered by default.

docs/JAX_Vision_transformer.md

Lines changed: 174 additions & 26 deletions
@@ -14,8 +14,7 @@ kernelspec:
 
 # Vision Transformer with JAX & FLAX
 
-
-In this tutorial we implement from scratch Vision Transformer (ViT) model based on the paper by Dosovitskiy et al: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). We will train this model on [Food 101](https://huggingface.co/datasets/ethz/food101) dataset.
+In this tutorial we implement from scratch the Vision Transformer (ViT) model based on the paper by Dosovitskiy et al: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). We load the ImageNet pretrained weights and finetune this model on [Food 101](https://huggingface.co/datasets/ethz/food101) dataset.
 
 This tutorial is originally inspired by [HuggingFace Image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification).
 
 +++
@@ -27,9 +26,10 @@ We will need to install the following Python packages:
 - [TorchVision](https://pytorch.org/vision) will be used for image augmentations
 - [grain](https://github.com/google/grain/) will be used for efficient data loading
 - [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.
+- [Matplotlib](https://matplotlib.org/stable/) will be used for visualization purposes
 
 ```{code-cell} ipython3
-# !pip install -U datasets grain torchvision tqdm
+# !pip install -U datasets grain torchvision tqdm matplotlib
 # !pip install -U flax optax
 ```
 
@@ -98,7 +98,7 @@ class VisionTransformer(nnx.Module):
             TransformerEncoder(hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs)
             for i in range(num_layers)
         ])
-        self.lnorm = nnx.LayerNorm(hidden_size, rngs=rngs)
+        self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)
 
         # Classification head
         self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)
@@ -116,7 +116,7 @@ class VisionTransformer(nnx.Module):
 
         # Encoder blocks
        x = self.encoder(embeddings)
-        x = self.lnorm(x)
+        x = self.final_norm(x)
 
         # fetch the first token
         x = x[:, 0]
@@ -162,9 +162,155 @@ class TransformerEncoder(nnx.Module):
         return x
 
 
-# We use a configuration to make smaller model to reduce the training time
-x = jnp.ones((4, 120, 120, 3))
-model = VisionTransformer(num_classes=10, num_layers=4, num_heads=4, img_size=120, patch_size=8)
+x = jnp.ones((4, 224, 224, 3))
+model = VisionTransformer(num_classes=1000)
+y = model(x)
+print("Predictions shape: ", y.shape)
+```
+
+Let's now load the weights pretrained on the ImageNet dataset using HuggingFace Transformers. We load all weights and check whether we have consistent results with the reference model.
+
+```{code-cell} ipython3
+from transformers import FlaxViTForImageClassification
+
+tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
+```
+
+```{code-cell} ipython3
+def vit_inplace_copy_weights(*, src_model, dst_model):
+    assert isinstance(src_model, FlaxViTForImageClassification)
+    assert isinstance(dst_model, VisionTransformer)
+
+    tf_model_params = src_model.params
+    tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)
+
+    flax_model_params = nnx.state(dst_model, nnx.Param)
+    flax_model_params_fstate = flax_model_params.flat_state()
+
+    params_name_mapping = {
+        ("cls_token",): ("vit", "embeddings", "cls_token"),
+        ("position_embeddings",): ("vit", "embeddings", "position_embeddings"),
+        **{
+            ("patch_embeddings", x): ("vit", "embeddings", "patch_embeddings", "projection", x)
+            for x in ["kernel", "bias"]
+        },
+        **{
+            ("encoder", "layers", i, "attn", y, x): (
+                "vit", "encoder", "layer", str(i), "attention", "attention", y, x
+            )
+            for x in ["kernel", "bias"]
+            for y in ["key", "value", "query"]
+            for i in range(12)
+        },
+        **{
+            ("encoder", "layers", i, "attn", "out", x): (
+                "vit", "encoder", "layer", str(i), "attention", "output", "dense", x
+            )
+            for x in ["kernel", "bias"]
+            for i in range(12)
+        },
+        **{
+            ("encoder", "layers", i, "mlp", "layers", y1, x): (
+                "vit", "encoder", "layer", str(i), y2, "dense", x
+            )
+            for x in ["kernel", "bias"]
+            for y1, y2 in [(0, "intermediate"), (3, "output")]
+            for i in range(12)
+        },
+        **{
+            ("encoder", "layers", i, y1, x): (
+                "vit", "encoder", "layer", str(i), y2, x
+            )
+            for x in ["scale", "bias"]
+            for y1, y2 in [("norm1", "layernorm_before"), ("norm2", "layernorm_after")]
+            for i in range(12)
+        },
+        **{
+            ("final_norm", x): ("vit", "layernorm", x)
+            for x in ["scale", "bias"]
+        },
+        **{
+            ("classifier", x): ("classifier", x)
+            for x in ["kernel", "bias"]
+        }
+    }
+
+    nonvisited = set(flax_model_params_fstate.keys())
+
+    for key1, key2 in params_name_mapping.items():
+        assert key1 in flax_model_params_fstate, key1
+        assert key2 in tf_model_params_fstate, (key1, key2)
+
+        nonvisited.remove(key1)
+
+        src_value = tf_model_params_fstate[key2]
+        if key2[-1] == "kernel" and key2[-2] in ("key", "value", "query"):
+            shape = src_value.shape
+            src_value = src_value.reshape((shape[0], 12, 64))
+
+        if key2[-1] == "bias" and key2[-2] in ("key", "value", "query"):
+            src_value = src_value.reshape((12, 64))
+
+        if key2[-4:] == ("attention", "output", "dense", "kernel"):
+            shape = src_value.shape
+            src_value = src_value.reshape((12, 64, shape[-1]))
+
+        dst_value = flax_model_params_fstate[key1]
+        assert src_value.shape == dst_value.value.shape, (key2, src_value.shape, key1, dst_value.value.shape)
+        dst_value.value = src_value.copy()
+        assert dst_value.value.mean() == src_value.mean(), (dst_value.value, src_value.mean())
+
+    assert len(nonvisited) == 0, nonvisited
+    nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))
+
+
+vit_inplace_copy_weights(src_model=tf_model, dst_model=model)
+```
+
+Let's check the pretrained weights of our model and compare with the reference model results:
+
+```{code-cell} ipython3
+import matplotlib.pyplot as plt
+from transformers import ViTImageProcessor
+from PIL import Image
+import requests
+
+url = "https://farm2.staticflickr.com/1152/1151216944_1525126615_z.jpg"
+image = Image.open(requests.get(url, stream=True).raw)
+
+processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
+
+inputs = processor(images=image, return_tensors="np")
+outputs = tf_model(**inputs)
+logits = outputs.logits
+
+
+model.eval()
+x = jnp.transpose(inputs["pixel_values"], axes=(0, 2, 3, 1))
+output = model(x)
+
+# model predicts one of the 1000 ImageNet classes
+ref_class_idx = logits.argmax(-1).item()
+pred_class_idx = output.argmax(-1).item()
+assert jnp.abs(logits[0, :] - output[0, :]).max() < 0.1
+
+fig, axs = plt.subplots(1, 2, figsize=(12, 8))
+axs[0].set_title(
+    f"Reference model:\n{tf_model.config.id2label[ref_class_idx]}\nP={nnx.softmax(logits, axis=-1)[0, ref_class_idx]:.3f}"
+)
+axs[0].imshow(image)
+axs[1].set_title(
+    f"Our model:\n{tf_model.config.id2label[pred_class_idx]}\nP={nnx.softmax(output, axis=-1)[0, pred_class_idx]:.3f}"
+)
+axs[1].imshow(image)
+```
+
+Now let's replace the classifier with a smaller fully-connected layer returning 20 classes instead of 1000:
+
+```{code-cell} ipython3
+model.classifier = nnx.Linear(model.classifier.in_features, 20, rngs=nnx.Rngs(0))
+
+x = jnp.ones((4, 224, 224, 3))
 y = model(x)
 print("Predictions shape: ", y.shape)
 ```
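
A note on the reshapes inside `vit_inplace_copy_weights`: the HuggingFace checkpoint stores each query/key/value projection (and the attention output projection) as a single matrix, while the destination `VisionTransformer`'s attention layer (defined outside this hunk, presumably `nnx.MultiHeadAttention`) keeps per-head kernels. A minimal illustrative sketch of the layout conversion, with the 12 heads and 64 head dimensions taken from the code above:

```python
# Illustrative only: the layout conversion performed by vit_inplace_copy_weights.
import numpy as np

hidden_size, num_heads, head_dim = 768, 12, 64

qkv_kernel = np.zeros((hidden_size, hidden_size))                   # HF checkpoint layout
qkv_kernel = qkv_kernel.reshape(hidden_size, num_heads, head_dim)   # per-head layout
out_kernel = np.zeros((hidden_size, hidden_size))                   # HF output projection
out_kernel = out_kernel.reshape(num_heads, head_dim, hidden_size)   # per-head layout

print(qkv_kernel.shape, out_kernel.shape)  # (768, 12, 64) (12, 64, 768)
```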
@@ -177,19 +323,19 @@ In the following sections we set up a image classification dataset and train thi
 
 In this tutorial we use [Food 101](https://huggingface.co/datasets/ethz/food101) dataset which consists of 101 food categories, with 101,000 images. For each class, 250 manually reviewed test images are provided as well as 750 training images. On purpose, the training images were not cleaned, and thus still contain some amount of noise. This comes mostly in the form of intense colors and sometimes wrong labels. All images were rescaled to have a maximum side length of 512 pixels.
 
-We will download the data using [HuggingFace Datasets](https://huggingface.co/docs/datasets/) and select 10 classes to reduce the dataset size and the model training time. We will be using [TorchVision](https://pytorch.org/vision) to transform input images and [`grain`](https://github.com/google/grain/) for efficient data loading.
+We will download the data using [HuggingFace Datasets](https://huggingface.co/docs/datasets/) and select 20 classes to reduce the dataset size and the model training time. We will be using [TorchVision](https://pytorch.org/vision) to transform input images and [`grain`](https://github.com/google/grain/) for efficient data loading.
 
 ```{code-cell} ipython3
 from datasets import load_dataset
 
-# Select first 10 classes to reduce the dataset size and the training time.
-train_size = 10 * 750
-val_size = 10 * 250
+# Select first 20 classes to reduce the dataset size and the training time.
+train_size = 20 * 750
+val_size = 20 * 250
 
 train_dataset = load_dataset("food101", split=f"train[:{train_size}]")
 val_dataset = load_dataset("food101", split=f"validation[:{val_size}]")
 
-# Let's create labels mapping where we map current labels between 0 and 9
+# Let's create labels mapping where we map current labels between 0 and 19
 labels_mapping = {}
 index = 0
 for i in range(0, len(val_dataset), 250):
@@ -198,6 +344,7 @@ for i in range(0, len(val_dataset), 250):
     labels_mapping[label] = index
     index += 1
 
+inv_labels_mapping = {v: k for k, v in labels_mapping.items()}
 
 print("Training dataset size:", len(train_dataset))
 print("Validation dataset size:", len(val_dataset))
@@ -248,18 +395,19 @@ import numpy as np
 from torchvision.transforms import v2 as T
 
 
-img_size = 120
+img_size = 224
 
 
 def to_np_array(pil_image):
     return np.asarray(pil_image.convert("RGB"))
 
 
 def normalize(image):
-    mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
-    std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
-    image = image.astype(np.float32) / 255.0
-    return (image - mean) / std
+    # Image preprocessing matches the one of pretrained ViT
+    mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
+    std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
+    image = image.astype(np.float32) / 255.0
+    return (image - mean) / std
 
 
 tv_train_transforms = T.Compose([
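
As a side note, the `0.5` mean/std used in `normalize` above can be cross-checked against the preprocessing config shipped with the pretrained checkpoint. A small sketch, not part of this diff, assuming the `transformers` checkpoint is reachable:

```python
# Sketch: confirm the normalization constants of the pretrained ViT checkpoint.
from transformers import ViTImageProcessor

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')
print(processor.image_mean, processor.image_std)  # expected: [0.5, 0.5, 0.5] [0.5, 0.5, 0.5]
```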
@@ -283,7 +431,7 @@ def get_transform(fn):
         batch["image"] = [
             fn(pil_image) for pil_image in batch["image"]
         ]
-        # map label index between 0 - 9
+        # map label index between 0 - 19
         batch["label"] = [
            labels_mapping[label] for label in batch["label"]
         ]
@@ -303,7 +451,7 @@ import grain.python as grain
 
 
 seed = 12
-train_batch_size = 64
+train_batch_size = 32
 val_batch_size = 2 * train_batch_size
 
 
@@ -363,15 +511,15 @@ print("Validation batch info:", val_batch["image"].shape, val_batch["image"].dty
 display_datapoints(
     *[(train_batch["image"][i], train_batch["label"][i]) for i in range(5)],
     tag="(Training) ",
-    names_map=train_dataset.features["label"].names
+    names_map={k: train_dataset.features["label"].names[v] for k, v in inv_labels_mapping.items()}
 )
 ```
 
 ```{code-cell} ipython3
 display_datapoints(
     *[(val_batch["image"][i], val_batch["label"][i]) for i in range(5)],
     tag="(Validation) ",
-    names_map=val_dataset.features["label"].names
+    names_map={k: val_dataset.features["label"].names[v] for k, v in inv_labels_mapping.items()}
 )
 ```
 
@@ -382,8 +530,8 @@ We defined training and validation datasets and the model. In this section we wi
 ```{code-cell} ipython3
 import optax
 
-num_epochs = 50
-learning_rate = 0.005
+num_epochs = 3
+learning_rate = 0.001
 momentum = 0.8
 total_steps = len(train_dataset) // train_batch_size
 
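
The hunk above only touches the hyperparameters; the optimizer construction itself is not part of this diff. A minimal sketch of how these values could be wired into an Optax optimizer (the cosine schedule and `nnx.Optimizer` usage below are assumptions, not taken from the tutorial):

```python
# Hypothetical sketch: wiring learning_rate, momentum, num_epochs and total_steps
# (defined in the cell above) into an Optax SGD optimizer with a decaying schedule.
import optax
from flax import nnx

lr_schedule = optax.cosine_decay_schedule(
    init_value=learning_rate, decay_steps=total_steps * num_epochs
)
optimizer = nnx.Optimizer(model, optax.sgd(lr_schedule, momentum=momentum))
```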
@@ -544,7 +692,6 @@ preds = model(test_images)
 ```{code-cell} ipython3
 num_samples = len(test_indices)
 names_map = train_dataset.features["label"].names
-inv_labels_mapping = {v: k for k, v in labels_mapping.items()}
 
 probas = nnx.softmax(preds, axis=1)
 pred_labels = probas.argmax(axis=1)
@@ -567,10 +714,11 @@ for i in range(num_samples):
 
 ## Further reading
 
-In this tutorial we implemented from scratch Vision Transformer model and trained it on a subset of Food 101 dataset. Trained model shows 67% classification accuracy. Next steps could be to finetune hyperparameters like the learning rate and train for more epochs.
+In this tutorial we implemented from scratch the Vision Transformer model and finetuned it on a subset of Food 101 dataset. The trained model shows almost perfect classification accuracy: 95%.
 
 - Model checkpointing and exporting using [Orbax](https://orbax.readthedocs.io/en/latest/).
 - Optimizers and the learning rate scheduling using [Optax](https://optax.readthedocs.io/en/latest/).
+- Freezing model's parameters using trainable parameters filtering: [example 1](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html#flax.nnx.optimizer.Optimizer.update) and [example 2](https://github.com/google/flax/issues/4167#issuecomment-2324245208).
 - Other Computer Vision tutorials in [jax-ai-stack](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html).
 
 ```{code-cell} ipython3
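
The new "Freezing model's parameters" bullet can be illustrated with a short sketch: restrict the optimizer to the classifier head only, following the linked Flax NNX examples. The filter names and the `wrt` argument are assumptions and depend on the installed Flax version:

```python
# Hypothetical sketch: update only parameters under the "classifier" path,
# leaving the pretrained backbone frozen.
import optax
from flax import nnx

trainable = nnx.All(nnx.Param, nnx.PathContains("classifier"))
optimizer = nnx.Optimizer(model, optax.adamw(1e-3), wrt=trainable)
```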
