
Update tutorial: Porting PyTorch model to JAX #220

Open · wants to merge 2 commits into main

104 changes: 54 additions & 50 deletions docs/source/JAX_porting_PyTorch_model.ipynb

Large diffs are not rendered by default.

38 changes: 21 additions & 17 deletions docs/source/JAX_porting_PyTorch_model.md
@@ -16,17 +16,22 @@
kernelspec:

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)

**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.**
In this tutorial, we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/). For an introduction to these tools, check out [JAX neural net basics](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) and [JAX for PyTorch users](https://docs.jaxstack.ai/en/latest/JAX_for_PyTorch_users.html).

In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`.
Specifically, we will port a PyTorch computer-vision model trained to classify images. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet ([MaxViT: Multi-Axis Vision Transformer](https://arxiv.org/abs/2204.01697)).

```{code-cell} ipython3
!pip install -Uq flax treescope
```
Flax provides an API very similar to the PyTorch `torch.nn` module, which makes porting PyTorch models rather straightforward.
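To make that correspondence concrete, here is a minimal side-by-side sketch (assuming `torch` and `flax` are installed), including one layout difference worth keeping in mind when porting weights:

```python
import torch.nn
from flax import nnx

# The same 4-in, 8-out linear layer in both frameworks.
torch_linear = torch.nn.Linear(4, 8)
flax_linear = nnx.Linear(4, 8, rngs=nnx.Rngs(0))

# torch stores `weight` as (out_features, in_features), while Flax
# stores `kernel` as (in_features, out_features).
print(torch_linear.weight.shape)       # torch.Size([8, 4])
print(flax_linear.kernel.value.shape)  # (4, 8)
```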
First, we'll set up the model using TorchVision, and briefly explore the model's architecture and the blocks we need to port. Next, we'll define equivalent blocks and the whole model using Flax. After that, we'll port the weights. Finally, we'll run some tests to ensure the correctness of the ported model.

**Note: On Colab, we recommend running this tutorial on a T4 GPU instance. On Kaggle, we recommend a T4x2 or P100 instance.**

Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697).
+++

Let's start by installing TorchVision, and importing JAX and Flax.

First, we set up the model using TorchVision and explore briefly the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model.
```{code-cell} ipython3
#!pip install -Uq torchvision
```

```{code-cell} ipython3
import jax
Expand All @@ -52,7 +57,7 @@ torch_model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1)
We can use `flax.nnx.display` to display the model's architecture:

```{code-cell} ipython3
# nnx.display(torch_model)
nnx.display(torch_model)
```

We can see that there are four MaxViT blocks in the model and each block contains:
@@ -218,8 +223,9 @@
Please note some differences between `torch.nn` and Flax when porting models:

+++

Below we implement one by one all the modules from the above list and add simple forward pass checks.
Let's first implement equivalent of `nn.Identity`.
Below, we'll implement all the modules from the above list and add forward pass checks.

Before that, let's implement an equivalent of `nn.Identity`.

```{code-cell} ipython3
class Identity(nnx.Module):
    # Returns its input unchanged, mirroring `torch.nn.Identity`.
    def __call__(self, x):
        return x
```

@@ -1123,7 +1129,7 @@
class MaxVit(nnx.Module):
module.kernel.value = normal_initializer(
    rngs(), module.kernel.value.shape, module.kernel.value.dtype
)
if module.bias.value is not None:
if module.bias is not None:
    module.bias.value = jnp.zeros(
        module.bias.value.shape, dtype=module.bias.value.dtype
    )
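For reference, the initializer call above follows the `jax.nn.initializers` convention: an initializer is a function of `(key, shape, dtype)`, and `rngs()` draws a fresh PRNG key. A minimal sketch of the same pattern (the `stddev` and shape here are illustrative, not the tutorial's exact values):

```python
import jax
import jax.numpy as jnp
from flax import nnx

normal_initializer = jax.nn.initializers.normal(stddev=0.02)
rngs = nnx.Rngs(0)

# `rngs()` yields a PRNG key, which the initializer consumes.
kernel = normal_initializer(rngs(), (3, 3, 16, 32), jnp.float32)
print(kernel.shape)  # (3, 3, 16, 32)
```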
@@ -1268,12 +1274,13 @@
class Torch2Flax:

torch_value = getattr(torch_nn_module, torch_key)
nnx_param = getattr(nnx_module, nnx_key)
assert nnx_param is not None, (torch_key, nnx_key, nnx_module)

if torch_value is None:
    assert nnx_param.value is None, nnx_param
    assert nnx_param is None, nnx_param
    continue

assert nnx_param is not None, (torch_key, nnx_key, nnx_module)

params_transform = module_mapping_info.get("params_transform", Torch2Flax.default_params_transform)
torch_value = params_transform(torch_key, torch_value)
@@ -1627,7 +1634,4 @@
cosine_dist = (expected * flax_output).sum() / (jnp.linalg.norm(flax_output) * jnp.linalg.norm(expected))
cosine_dist
```
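Beyond the cosine similarity above, one could also assert elementwise closeness within a loose float32 tolerance (a hypothetical extra check, reusing `expected` and `flax_output` from the cells above; the tolerances are illustrative):

```python
import numpy as np

# Fails loudly if any element of the ported model's output drifts
# beyond the given absolute/relative tolerance.
np.testing.assert_allclose(
    np.asarray(flax_output), np.asarray(expected), atol=1e-4, rtol=1e-4
)
```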

## Further reading

- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)
- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/getting_started.html)
For more Flax examples, check out [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html).