Commit b33a446

Small tweaks (#152)
* Updates in Colab (small tweaks)
* More tweaks
* Oops, removing second Colab button.
* Following the instructions for once!
1 parent 6e8aeaf commit b33a446

2 files changed: +179 -105 lines changed

docs/source/JAX_porting_PyTorch_model.ipynb

Lines changed: 157 additions & 97 deletions
Large diffs are not rendered by default.

docs/source/JAX_porting_PyTorch_model.md

Lines changed: 22 additions & 8 deletions
@@ -16,8 +16,14 @@ kernelspec:

 [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)

+**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.**
+
 In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`.

+```{code-cell} ipython3
+!pip install -Uq flax treescope
+```
+
 Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697).

 First, we set up the model using TorchVision and explore briefly the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model.
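
Editorial sketch (not part of the commit): the "port the weights, then test" steps described in the context paragraph above can be illustrated on a single linear layer. This is a minimal sketch under stated assumptions: it uses plain NumPy stand-ins instead of real `torch.nn.Linear` and `flax.nnx.Linear` objects, and the helper names (`torch_linear`, `flax_linear`, `port_linear_weights`, `cosine_distance`) are hypothetical. The one library fact it relies on is the layout difference: PyTorch stores a linear weight as `(out_features, in_features)`, while Flax stores the kernel as `(in_features, out_features)`.

```python
import numpy as np


def torch_linear(x, weight, bias):
    """Stand-in for a PyTorch linear layer: weight is (out_features, in_features)."""
    return x @ weight.T + bias


def flax_linear(x, kernel, bias):
    """Stand-in for a Flax linear layer: kernel is (in_features, out_features)."""
    return x @ kernel + bias


def port_linear_weights(weight, bias):
    """Convert PyTorch-layout parameters to Flax layout by transposing the weight."""
    return weight.T.copy(), bias.copy()


def cosine_distance(a, b):
    """1 - cosine similarity of the flattened arrays; close to 0 for matching outputs."""
    a, b = a.ravel(), b.ravel()
    return 1.0 - float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))


# Random parameters and inputs stand in for a trained layer and a real batch.
rng = np.random.default_rng(0)
weight = rng.normal(size=(4, 3))  # (out_features=4, in_features=3)
bias = rng.normal(size=(4,))
x = rng.normal(size=(2, 3))       # batch of 2 inputs

kernel, ported_bias = port_linear_weights(weight, bias)
expected = torch_linear(x, weight, bias)
actual = flax_linear(x, kernel, ported_bias)

print(kernel.shape)
print(np.allclose(expected, actual))
print(cosine_distance(expected, actual))
```

For the full MaxViT model the same idea would presumably be applied layer by layer, with convolution and attention parameters needing their own layout handling; the comparison at the end mirrors the kind of output check that the final hunk header (`cosine_dist`) suggests the tutorial performs.
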
@@ -44,11 +50,10 @@ torch_model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)
 ```

 We can use `flax.nnx.display` to display the model's architecture:
-```python
-nnx.display(torch_model)
-```

-+++
+```{code-cell} ipython3
+# nnx.display(torch_model)
+```

 We can see that there are four MaxViT blocks in the model and each block contains:
 - MaxViT layers: two layers for blocks 0, 1, 3 and five layers for the block 4
@@ -81,9 +86,18 @@ print(output.shape) # (2, 1000)

 We can download an image of a Pembroke Corgy dog from [TorchVision's gallery](https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true) together with [ImageNet classes dictionary](https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json):

-```bash
-wget "https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true" -O dog1.jpg
-wget "https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json" -O imagenet_class_index.json
+```{code-cell} ipython3
+%%bash
+if [ -f "dog1.jpg" ]; then
+echo "dog1.jpg already exists."
+else
+wget -nv "https://github.com/pytorch/vision/blob/main/gallery/assets/dog1.jpg?raw=true" -O dog1.jpg
+fi
+if [ -f "imagenet_class_index.json" ]; then
+echo "imagenet_class_index.json already exists."
+else
+wget -nv "https://raw.githubusercontent.com/pytorch/vision/refs/heads/main/gallery/assets/imagenet_class_index.json" -O imagenet_class_index.json
+fi
 ```

 ```{code-cell} ipython3
@@ -1615,5 +1629,5 @@ cosine_dist

 ## Further reading

-- [Flax documentation: Core Exampels](https://flax.readthedocs.io/en/latest/examples/core_examples.html)
+- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)
 - [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html)
