diff --git a/docs/source/JAX_porting_PyTorch_model.ipynb b/docs/source/JAX_porting_PyTorch_model.ipynb index ce928b3..1cf083d 100644 --- a/docs/source/JAX_porting_PyTorch_model.ipynb +++ b/docs/source/JAX_porting_PyTorch_model.ipynb @@ -1,56 +1,47 @@ { "cells": [ { + "attachments": {}, "cell_type": "markdown", - "id": "b69996dc-49af-4a0e-a4e6-36d81b51f2b4", + "id": "f3a4e607-b508-4b1a-9f44-be1fc93ed2e5", "metadata": {}, "source": [ "# Porting a PyTorch model to JAX\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb)\n", "\n", - "**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.**\n", + "In this tutorial, we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/). For an introduction to these tools, check out [JAX neural net basics](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) and [JAX for PyTorch users](https://docs.jaxstack.ai/en/latest/JAX_for_PyTorch_users.html).\n", "\n", - "In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`." + "Specifically, we will port a PyTorch computer-vision model trained to classify images. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet ([MaxViT: Multi-Axis Vision Transformer](https://arxiv.org/abs/2204.01697)).\n", + "\n", + "Flax provides an API very similar to the PyTorch `torch.nn` module, which makes porting PyTorch models rather straightforward. \n", + "First, we'll set up the model using TorchVision, and briefly explore the model's architecture and the blocks we need to port. Next, we'll define equivalent blocks and the whole model using Flax. After that, we'll port the weights. Finally, we'll run some tests to ensure the correctness of the ported model.\n", + "\n", + "**Note: On Colab, we recommend running this tutorial on a T4 GPU instance. On Kaggle, we recommend a T4x2 or P100 instance.**" ] }, { - "cell_type": "code", - "execution_count": 1, - "id": "NHqB3sNbrygd", + "cell_type": "markdown", + "id": "73dedf50-3215-4f91-845b-8b3f9c9d3e0c", "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/424.2 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r", - "\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m \u001b[32m419.8/424.2 kB\u001b[0m \u001b[31m14.1 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\r", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m424.2/424.2 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h\u001b[?25l \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m0.0/175.6 kB\u001b[0m \u001b[31m?\u001b[0m eta \u001b[36m-:--:--\u001b[0m\r", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m175.6/175.6 kB\u001b[0m \u001b[31m10.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25h" - ] - } - ], "source": [ - "!pip install -Uq flax treescope" + "Let's start by installing TorchVision, and importing JAX and Flax." ] }, { - "cell_type": "markdown", - "id": "ABCg5TvPr1pm", + "cell_type": "code", + "execution_count": 1, + "id": "e662b199-6f1c-40a7-a1a8-8b6b3f709798", "metadata": {}, + "outputs": [], "source": [ - "Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697).\n", - "\n", - "First, we set up the model using TorchVision and explore briefly the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model." + "#!pip install -Uq torchvision" ] }, { "cell_type": "code", "execution_count": 2, - "id": "38504f77-4150-47bd-9cf9-3116fe370746", + "id": "8043268f-ce1c-403f-bd00-d8dede56b903", "metadata": {}, "outputs": [], "source": [ @@ -79,18 +70,7 @@ "execution_count": 3, "id": "9b1be406-d21c-410d-a2ac-9bd690e5ad60", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/torch/functional.py:534: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at ../aten/src/ATen/native/TensorShape.cpp:3595.)\n", - " return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]\n", - "Downloading: \"https://download.pytorch.org/models/maxvit_t-bc5ab103.pth\" to /root/.cache/torch/hub/checkpoints/maxvit_t-bc5ab103.pth\n", - "100%|██████████| 119M/119M [00:02<00:00, 53.9MB/s]\n" - ] - } - ], + "outputs": [], "source": [ "from torchvision.models import maxvit_t, MaxVit_T_Weights\n", "\n", @@ -110,9 +90,34 @@ "execution_count": 4, "id": "sZ9x7NpHtBcx", "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "# nnx.display(torch_model)" + "nnx.display(torch_model)" ] }, { @@ -422,8 +427,9 @@ "id": "8d7e3479-bffe-4cb6-81e1-ed8f972c5bf0", "metadata": {}, "source": [ - "Below we implement one by one all the modules from the above list and add simple forward pass checks.\n", - "Let's first implement equivalent of `nn.Identity`." + "Below, we'll implement all the modules from the above list and add forward pass checks.\n", + "\n", + "Before that, let's implement an equivalent of `nn.Identity`." ] }, { @@ -1592,7 +1598,7 @@ " module.kernel.value = normal_initializer(\n", " rngs(), module.kernel.value.shape, module.kernel.value.dtype\n", " )\n", - " if module.bias.value is not None:\n", + " if module.bias is not None:\n", " module.bias.value = jnp.zeros(\n", " module.bias.value.shape, dtype=module.bias.value.dtype\n", " )\n", @@ -1773,12 +1779,13 @@ "\n", " torch_value = getattr(torch_nn_module, torch_key)\n", " nnx_param = getattr(nnx_module, nnx_key)\n", - " assert nnx_param is not None, (torch_key, nnx_key, nnx_module)\n", "\n", " if torch_value is None:\n", - " assert nnx_param.value is None, nnx_param\n", + " assert nnx_param is None, nnx_param\n", " continue\n", "\n", + " assert nnx_param is not None, (torch_key, nnx_key, nnx_module)\n", + "\n", " params_transform = module_mapping_info.get(\"params_transform\", Torch2Flax.default_params_transform)\n", " torch_value = params_transform(torch_key, torch_value)\n", "\n", @@ -2268,10 +2275,7 @@ "id": "65e57aa6-1572-4805-9207-bc8a5f9f3ab1", "metadata": {}, "source": [ - "## Further reading\n", - "\n", - "- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)\n", - "- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/getting_started.html)" + "For more Flax examples, check out [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html)." ] } ], diff --git a/docs/source/JAX_porting_PyTorch_model.md b/docs/source/JAX_porting_PyTorch_model.md index 2c829a9..34d5b58 100644 --- a/docs/source/JAX_porting_PyTorch_model.md +++ b/docs/source/JAX_porting_PyTorch_model.md @@ -16,17 +16,22 @@ kernelspec: [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_porting_PyTorch_model.ipynb) -**Note: On Colab we recommend running this on a T4 GPU instance. On Kaggle we recommend a T4x2 or P100 instance.** +In this tutorial, we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/). For an introduction to these tools, check out [JAX neural net basics](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) and [JAX for PyTorch users](https://docs.jaxstack.ai/en/latest/JAX_for_PyTorch_users.html). -In this tutorial we will learn how to port a PyTorch model to JAX and [Flax](https://flax.readthedocs.io/en/latest/nnx_basics.html). Flax provides an API very similar to the PyTorch `torch.nn` module and porting PyTorch models is rather straightforward. To install Flax, we can simply execute the following command: `pip install -U flax treescope`. +Specifically, we will port a PyTorch computer-vision model trained to classify images. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet ([MaxViT: Multi-Axis Vision Transformer](https://arxiv.org/abs/2204.01697)). -```{code-cell} ipython3 -!pip install -Uq flax treescope -``` +Flax provides an API very similar to the PyTorch `torch.nn` module, which makes porting PyTorch models rather straightforward. +First, we'll set up the model using TorchVision, and briefly explore the model's architecture and the blocks we need to port. Next, we'll define equivalent blocks and the whole model using Flax. After that, we'll port the weights. Finally, we'll run some tests to ensure the correctness of the ported model. + +**Note: On Colab, we recommend running this tutorial on a T4 GPU instance. On Kaggle, we recommend a T4x2 or P100 instance.** -Say we have a trained PyTorch computer-vision model to classify images that we would like to port to JAX. We will use [`TorchVision`](https://pytorch.org/vision/stable/index.html) to provide a [MaxVit](https://pytorch.org/vision/stable/models/maxvit.html) model trained on ImageNet (MaxViT: Multi-Axis Vision Transformer, https://arxiv.org/abs/2204.01697). ++++ + +Let's start by installing TorchVision, and importing JAX and Flax. -First, we set up the model using TorchVision and explore briefly the model's architecture and the blocks we need to port. Next, we define equivalent blocks and the whole model using Flax. After that, we port the weights. Finally, we run some tests to ensure the correctness of the ported model. +```{code-cell} ipython3 +#!pip install -Uq torchvision +``` ```{code-cell} ipython3 import jax @@ -52,7 +57,7 @@ torch_model = maxvit_t(weights=MaxVit_T_Weights.IMAGENET1K_V1) We can use `flax.nnx.display` to display the model's architecture: ```{code-cell} ipython3 -# nnx.display(torch_model) +nnx.display(torch_model) ``` We can see that there are four MaxViT blocks in the model and each block contains: @@ -218,8 +223,9 @@ Please note some differences between `torch.nn` and Flax when porting models: +++ -Below we implement one by one all the modules from the above list and add simple forward pass checks. -Let's first implement equivalent of `nn.Identity`. +Below, we'll implement all the modules from the above list and add forward pass checks. + +Before that, let's implement an equivalent of `nn.Identity`. ```{code-cell} ipython3 class Identity(nnx.Module): @@ -1123,7 +1129,7 @@ class MaxVit(nnx.Module): module.kernel.value = normal_initializer( rngs(), module.kernel.value.shape, module.kernel.value.dtype ) - if module.bias.value is not None: + if module.bias is not None: module.bias.value = jnp.zeros( module.bias.value.shape, dtype=module.bias.value.dtype ) @@ -1268,12 +1274,13 @@ class Torch2Flax: torch_value = getattr(torch_nn_module, torch_key) nnx_param = getattr(nnx_module, nnx_key) - assert nnx_param is not None, (torch_key, nnx_key, nnx_module) if torch_value is None: - assert nnx_param.value is None, nnx_param + assert nnx_param is None, nnx_param continue + assert nnx_param is not None, (torch_key, nnx_key, nnx_module) + params_transform = module_mapping_info.get("params_transform", Torch2Flax.default_params_transform) torch_value = params_transform(torch_key, torch_value) @@ -1627,7 +1634,4 @@ cosine_dist = (expected * flax_output).sum() / (jnp.linalg.norm(flax_output) * j cosine_dist ``` -## Further reading - -- [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html) -- [JAX AI Stack tutorials](https://jax-ai-stack.readthedocs.io/en/latest/getting_started.html) +For more Flax examples, check out [Flax documentation: Core Examples](https://flax.readthedocs.io/en/latest/examples/core_examples.html).