From cdf96b1b82dbed6a1981df11b5b5335e7cdf7cbd Mon Sep 17 00:00:00 2001
From: Pavithra Eswaramoorthy
Date: Wed, 25 Jun 2025 17:18:25 +0200
Subject: [PATCH 1/2] Update porting pytorch model to jax

---
 docs/source/JAX_porting_PyTorch_model.ipynb | 107 ++++++++++----------
 docs/source/JAX_porting_PyTorch_model.md    |  39 +++----
 2 files changed, 78 insertions(+), 68 deletions(-)

diff --git a/docs/source/JAX_porting_PyTorch_model.ipynb b/docs/source/JAX_porting_PyTorch_model.ipynb
index ce928b3..27a61d2 100644
--- a/docs/source/JAX_porting_PyTorch_model.ipynb
+++ b/docs/source/JAX_porting_PyTorch_model.ipynb
@@ -1,56 +1,47 @@
 {
  "cells": [
   {
+   "attachments": {},
    "cell_type": "markdown",
-   "id": "b69996dc-49af-4a0e-a4e6-36d81b51f2b4",
+   "id": "f3a4e607-b508-4b1a-9f44-be1fc93ed2e5",
    "metadata": {},
    "source": [
     "# Porting a PyTorch model to JAX\n",
     "\n",
     "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)\n",
     "\n",
-    "**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.**\n",
+    "In this tutorial, we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/). For an introduction to these tools, check out [JAX neural net basics](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) and [JAX for PyTorch users](https://docs.jaxstack.ai/en/latest/JAX_for_PyTorch_users.html).\n",
     "\n",
-    "In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`."
+    "Specifically, we will port a PyTorch computer-vision model trained to classify images. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet ([MaxViT: Multi-Axis Vision Transformer](https://arxiv.org/abs/2204.01697)).\n",
+    "\n",
+    "Flax provides an API very similar to the PyTorch `torch.nn` module, which makes porting PyTorch models rather straightforward. \n",
+    "First, we'll set up the model using TorchVision, and briefly explore the model's architecture and the blocks we need to port. Next, we'll define equivalent blocks and the whole model using Flax. After that, we'll port the weights. Finally, we'll run some tests to ensure the correctness of the ported model.\n",
+    "\n",
+    "**Note: On Colab, we recommend running this tutorial on a T4 GPU instance. On Kaggle, we recommend a T4x2 or P100 instance.**"
    ]
   },
   {
-   "cell_type": "code",
-   "execution_count": 1,
-   "id": "NHqB3sNbrygd",
+   "cell_type": "markdown",
+   "id": "73dedf50-3215-4f91-845b-8b3f9c9d3e0c",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/424.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r",
-      "\u001b[2K   \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m419.8/424.2 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r",
-      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.2/424.2 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-      "\u001b[?25h\u001b[?25l   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/175.6 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r",
-      "\u001b[2K   \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m175.6/175.6 kB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
-      "\u001b[?25h"
-     ]
-    }
-   ],
    "source": [
-    "!pip install -Uq flax treescope"
+    "Let's start by installing TorchVision, and importing JAX and Flax."
    ]
   },
   {
-   "cell_type": "markdown",
-   "id": "ABCg5TvPr1pm",
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "e662b199-6f1c-40a7-a1a8-8b6b3f709798",
    "metadata": {},
+   "outputs": [],
    "source": [
-    "Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697).\n",
-    "\n",
-    "First, we set up the model using TorchVision and explore briefly the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model."
+    "#!pip install -Uq torchvision"
    ]
   },
   {
    "cell_type": "code",
    "execution_count": 2,
-   "id": "38504f77-4150-47bd-9cf9-3116fe370746",
+   "id": "8043268f-ce1c-403f-bd00-d8dede56b903",
    "metadata": {},
    "outputs": [],
    "source": [
@@ -79,18 +70,7 @@
    "execution_count": 3,
    "id": "9b1be406-d21c-410d-a2ac-9bd690e5ad60",
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stderr",
-     "output_type": "stream",
-     "text": [
-      "/usr/local/lib/python3.10/dist-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n",
-      "  return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n",
-      "Downloading: \"https://download.pytorch.org/models/maxvit_t-bc5ab103.pth\" to /root/.cache/torch/hub/checkpoints/maxvit_t-bc5ab103.pth\n",
-      "100%|██████████| 119M/119M [00:02<00:00, 53.9MB/s]\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
     "from torchvision.models import maxvit_t, MaxVit_T_Weights\n",
     "\n",
@@ -110,9 +90,34 @@
    "execution_count": 4,
    "id": "sZ9x7NpHtBcx",
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "data": {
+      "text/html": [
+       ""
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# nnx.display(torch_model)" + "nnx.display(torch_model)" ] }, { @@ -422,8 +427,9 @@ "id": "8d7e3479-bffe-4cb6-81e1-ed8f972c5bf0", "metadata": {}, "source": [ - "Below we implement one by one all the modules from the above list and add simple forward pass checks.\n", - "Let's first implement equivalent of `nn.Identity`." + "Below, we'll implement all the modules from the above list and add forward pass checks.\n", + "\n", + "Before that, let's implement an equivalent of `nn.Identity`." ] }, { @@ -1592,7 +1598,7 @@ " module.kernel.value = normal_initializer(\n", " rngs(), module.kernel.value.shape, module.kernel.value.dtype\n", " )\n", - " if module.bias.value is not None:\n", + " if module.bias is not None:\n", " module.bias.value = jnp.zeros(\n", " module.bias.value.shape, dtype=module.bias.value.dtype\n", " )\n", @@ -1773,12 +1779,13 @@ "\n", " torch_value = getattr(torch_nn_module, torch_key)\n", " nnx_param = getattr(nnx_module, nnx_key)\n", - " assert nnx_param is not None, (torch_key, nnx_key, nnx_module)\n", "\n", " if torch_value is None:\n", - " assert nnx_param.value is None, nnx_param\n", + " assert nnx_param is None, nnx_param\n", " continue\n", "\n", + " assert nnx_param is not None, (torch_key, nnx_key, nnx_module)\n", + "\n", " params_transform = module_mapping_info.get(\"params_transform\", Torch2Flax.default_params_transform)\n", " torch_value = params_transform(torch_key, torch_value)\n", "\n", @@ -2268,10 +2275,7 @@ "id": "65e57aa6-1572-4805-9207-bc8a5f9f3ab1", "metadata": {}, "source": [ - "## Further reading\n", - "\n", - "- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)\n", - "- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/getting_started.html)" + "For more Flax examples, check out [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)." ] } ], @@ -2282,6 +2286,7 @@ "provenance": [] }, "jupytext": { + "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { @@ -2299,7 +2304,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.9" + "version": "3.12.11" } }, "nbformat": 4, diff --git a/docs/source/JAX_porting_PyTorch_model.md b/docs/source/JAX_porting_PyTorch_model.md index 2c829a9..3f96e53 100644 --- a/docs/source/JAX_porting_PyTorch_model.md +++ b/docs/source/JAX_porting_PyTorch_model.md @@ -1,5 +1,6 @@ --- jupytext: + default_lexer: ipython3 formats: ipynb,md:myst text_representation: extension: .md @@ -16,17 +17,22 @@ kernelspec: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb) -**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.** +In this tutorial, we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/). For an introduction to these tools, check out [JAX neural net basics](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) and [JAX for PyTorch users](https://docs.jaxstack.ai/en/latest/JAX_for_PyTorch_users.html). -In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). 
Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`. +Specifically, we will port a PyTorch computer-vision model trained to classify images. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet ([MaxViT: Multi-Axis Vision Transformer](https://arxiv.org/abs/2204.01697)). -```{code-cell} ipython3 -!pip install -Uq flax treescope -``` +Flax provides an API very similar to the PyTorch `torch.nn` module, which makes porting PyTorch models rather straightforward. +First, we'll set up the model using TorchVision, and briefly explore the model's architecture and the blocks we need to port. Next, we'll define equivalent blocks and the whole model using Flax. After that, we'll port the weights. Finally, we'll run some tests to ensure the correctness of the ported model. + +**Note: On Colab, we recommend running this tutorial on a T4 GPU instance. On Kaggle, we recommend a T4x2 or P100 instance.** -Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697). ++++ + +Let's start by installing TorchVision, and importing JAX and Flax. -First, we set up the model using TorchVision and explore briefly the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model. +```{code-cell} ipython3 +#!pip install -Uq torchvision +``` ```{code-cell} ipython3 import jax @@ -52,7 +58,7 @@ torch_model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1) We can use `flax.nnx.display` to display the model's architecture: ```{code-cell} ipython3 -# nnx.display(torch_model) +nnx.display(torch_model) ``` We can see that there are four MaxViT blocks in the model and each block contains: @@ -218,8 +224,9 @@ Please note some differences between `torch.nn` and Flax when porting models: +++ -Below we implement one by one all the modules from the above list and add simple forward pass checks. -Let's first implement equivalent of `nn.Identity`. +Below, we'll implement all the modules from the above list and add forward pass checks. + +Before that, let's implement an equivalent of `nn.Identity`. 
```{code-cell} ipython3 class Identity(nnx.Module): @@ -1123,7 +1130,7 @@ class MaxVit(nnx.Module): module.kernel.value = normal_initializer( rngs(), module.kernel.value.shape, module.kernel.value.dtype ) - if module.bias.value is not None: + if module.bias is not None: module.bias.value = jnp.zeros( module.bias.value.shape, dtype=module.bias.value.dtype ) @@ -1268,12 +1275,13 @@ class Torch2Flax: torch_value = getattr(torch_nn_module, torch_key) nnx_param = getattr(nnx_module, nnx_key) - assert nnx_param is not None, (torch_key, nnx_key, nnx_module) if torch_value is None: - assert nnx_param.value is None, nnx_param + assert nnx_param is None, nnx_param continue + assert nnx_param is not None, (torch_key, nnx_key, nnx_module) + params_transform = module_mapping_info.get("params_transform", Torch2Flax.default_params_transform) torch_value = params_transform(torch_key, torch_value) @@ -1627,7 +1635,4 @@ cosine_dist = (expected * flax_output).sum() / (jnp.linalg.norm(flax_output) * j cosine_dist ``` -## Further reading - -- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html) -- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/getting_started.html) +For more Flax examples, check out [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html). From 8dfc27367c5994e5b8c5192f1430dd58c021d8f4 Mon Sep 17 00:00:00 2001 From: Pavithra Eswaramoorthy Date: Wed, 25 Jun 2025 17:26:27 +0200 Subject: [PATCH 2/2] Linting fixes --- docs/source/JAX_porting_PyTorch_model.ipynb | 3 +-- docs/source/JAX_porting_PyTorch_model.md | 1 - 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/docs/source/JAX_porting_PyTorch_model.ipynb b/docs/source/JAX_porting_PyTorch_model.ipynb index 27a61d2..1cf083d 100644 --- a/docs/source/JAX_porting_PyTorch_model.ipynb +++ b/docs/source/JAX_porting_PyTorch_model.ipynb @@ -2286,7 +2286,6 @@ "provenance": [] }, "jupytext": { - "default_lexer": "ipython3", "formats": "ipynb,md:myst" }, "kernelspec": { @@ -2304,7 +2303,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.11" + "version": "3.11.9" } }, "nbformat": 4, diff --git a/docs/source/JAX_porting_PyTorch_model.md b/docs/source/JAX_porting_PyTorch_model.md index 3f96e53..34d5b58 100644 --- a/docs/source/JAX_porting_PyTorch_model.md +++ b/docs/source/JAX_porting_PyTorch_model.md @@ -1,6 +1,5 @@ --- jupytext: - default_lexer: ipython3 formats: ipynb,md:myst text_representation: extension: .md
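
The porting pattern at the heart of the tutorial these two patches edit — copy each PyTorch tensor into the matching `nnx.Param`, transposing where the two libraries' layout conventions differ, then compare forward passes — can be seen in miniature on a single layer. The sketch below is illustrative and not part of the patches; it assumes `torch` and `flax` (with the NNX API) are installed, and the layer sizes and tolerance are arbitrary.

```python
import numpy as np
import torch
import jax.numpy as jnp
from flax import nnx

# One PyTorch layer and its Flax NNX counterpart.
torch_linear = torch.nn.Linear(4, 2)
flax_linear = nnx.Linear(4, 2, rngs=nnx.Rngs(0))

# torch.nn.Linear stores its weight as (out_features, in_features);
# nnx.Linear stores its kernel as (in_features, out_features) -- hence the transpose.
flax_linear.kernel.value = jnp.asarray(torch_linear.weight.detach().numpy().T)
flax_linear.bias.value = jnp.asarray(torch_linear.bias.detach().numpy())

# Forward-pass check, in the spirit of the tutorial's correctness tests.
x = np.random.randn(1, 4).astype(np.float32)
torch_out = torch_linear(torch.from_numpy(x)).detach().numpy()
flax_out = np.asarray(flax_linear(jnp.asarray(x)))
np.testing.assert_allclose(torch_out, flax_out, atol=1e-5)
```

The tutorial's `Torch2Flax` helper generalizes this per-layer copy (its per-module `params_transform` hooks cover layout changes like the transpose above), and the final cosine-similarity check on the MaxViT logits plays the role that `assert_allclose` plays here.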