From a4cc28f289ae8e6ba4121cb5e3ec64c5dd190f2d Mon Sep 17 00:00:00 2001
From: Skye Wanderman-Milne
Date: Fri, 23 May 2025 11:22:48 -0700
Subject: [PATCH] Rename `getting_started_with_jax_for_AI` files and URL to
 `neural_net_basics`

---
 docs/source/JAX_Vision_transformer.ipynb               |  2 +-
 docs/source/JAX_Vision_transformer.md                  |  2 +-
 docs/source/JAX_for_LLM_pretraining.ipynb              |  2 +-
 docs/source/JAX_for_LLM_pretraining.md                 |  2 +-
 docs/source/JAX_transformer_text_classification.ipynb  |  2 +-
 docs/source/JAX_transformer_text_classification.md     |  2 +-
 docs/source/JAX_visualizing_models_metrics.ipynb       |  2 +-
 docs/source/JAX_visualizing_models_metrics.md          |  2 +-
 docs/source/conf.py                                    |  2 +-
 docs/source/digits_diffusion_model.ipynb               |  2 +-
 docs/source/digits_diffusion_model.md                  |  2 +-
 docs/source/digits_vae.ipynb                           | 10 +++++-----
 docs/source/digits_vae.md                              | 10 +++++-----
 docs/source/getting_started.md                         |  2 +-
 ...d_with_jax_for_AI.ipynb => neural_net_basics.ipynb} |  2 +-
 ...started_with_jax_for_AI.md => neural_net_basics.md} |  2 +-
 16 files changed, 24 insertions(+), 24 deletions(-)
 rename docs/source/{getting_started_with_jax_for_AI.ipynb => neural_net_basics.ipynb} (99%)
 rename docs/source/{getting_started_with_jax_for_AI.md => neural_net_basics.md} (99%)

diff --git a/docs/source/JAX_Vision_transformer.ipynb b/docs/source/JAX_Vision_transformer.ipynb
index dcfc926..cf56fcc 100644
--- a/docs/source/JAX_Vision_transformer.ipynb
+++ b/docs/source/JAX_Vision_transformer.ipynb
@@ -13,7 +13,7 @@
     "\n",
     "This tutorial draws inspiration from the HuggingFace [Image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification). The original JAX-based implementation of the ViT model can be found in the [google-research/vision_transformer](https://github.com/google-research/vision_transformer/) GitHub repository.\n",
     "\n",
-    "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX."
+    "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX."
    ]
  },
  {
diff --git a/docs/source/JAX_Vision_transformer.md b/docs/source/JAX_Vision_transformer.md
index ee39d76..94d1a88 100644
--- a/docs/source/JAX_Vision_transformer.md
+++ b/docs/source/JAX_Vision_transformer.md
@@ -20,7 +20,7 @@ This tutorial guides you through developing and training a Vision Transformer (V
 
 This tutorial draws inspiration from the HuggingFace [Image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification). The original JAX-based implementation of the ViT model can be found in the [google-research/vision_transformer](https://github.com/google-research/vision_transformer/) GitHub repository.
 
-If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX.
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX.
 
 +++
 
diff --git a/docs/source/JAX_for_LLM_pretraining.ipynb b/docs/source/JAX_for_LLM_pretraining.ipynb
index 9d1b013..3cf91c1 100644
--- a/docs/source/JAX_for_LLM_pretraining.ipynb
+++ b/docs/source/JAX_for_LLM_pretraining.ipynb
@@ -20,7 +20,7 @@
     "- Train the model on Google Colab’s Cloud TPU v2\n",
     "- Profile for hyperparameter tuning\n",
     "\n",
-    "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html)."
+    "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html)."
    ]
  },
  {
diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md
index 267729b..0450864 100644
--- a/docs/source/JAX_for_LLM_pretraining.md
+++ b/docs/source/JAX_for_LLM_pretraining.md
@@ -27,7 +27,7 @@ Here, you will learn how to:
 - Train the model on Google Colab’s Cloud TPU v2
 - Profile for hyperparameter tuning
 
-If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
 
 +++ {"id": "hTmz5Cbco7n_"}
 
diff --git a/docs/source/JAX_transformer_text_classification.ipynb b/docs/source/JAX_transformer_text_classification.ipynb
index 8c3d625..c152daa 100644
--- a/docs/source/JAX_transformer_text_classification.ipynb
+++ b/docs/source/JAX_transformer_text_classification.ipynb
@@ -18,7 +18,7 @@
     "- Train the model.\n",
     "- Evaluate the model with an example.\n",
     "\n",
-    "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).\n",
+    "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).\n",
     "\n",
     "## Setup\n",
     "\n",
diff --git a/docs/source/JAX_transformer_text_classification.md b/docs/source/JAX_transformer_text_classification.md
index 260e87a..1307acf 100644
--- a/docs/source/JAX_transformer_text_classification.md
+++ b/docs/source/JAX_transformer_text_classification.md
@@ -26,7 +26,7 @@ Here, you will learn how to:
 - Train the model.
 - Evaluate the model with an example.
 
-If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
 
 ## Setup
 
diff --git a/docs/source/JAX_visualizing_models_metrics.ipynb b/docs/source/JAX_visualizing_models_metrics.ipynb
index 6ece7e0..5280d49 100644
--- a/docs/source/JAX_visualizing_models_metrics.ipynb
+++ b/docs/source/JAX_visualizing_models_metrics.ipynb
@@ -13,7 +13,7 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that yet and want the primer, start there before returning.\n",
+    "To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) - if you haven't read that yet and want the primer, start there before returning.\n",
     "\n",
     "All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them."
    ]
diff --git a/docs/source/JAX_visualizing_models_metrics.md b/docs/source/JAX_visualizing_models_metrics.md
index 2e27bc3..7819d42 100644
--- a/docs/source/JAX_visualizing_models_metrics.md
+++ b/docs/source/JAX_visualizing_models_metrics.md
@@ -18,7 +18,7 @@ kernelspec:
 
 +++
 
-To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that yet and want the primer, start there before returning.
+To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) - if you haven't read that yet and want the primer, start there before returning.
 
 All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them.
 
diff --git a/docs/source/conf.py b/docs/source/conf.py
index bff261f..36bc79a 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -52,7 +52,7 @@
     'build/jupyter_execute',
     # Exclude markdown sources for notebooks:
     'digits_vae.md',
-    'getting_started_with_jax_for_AI.md',
+    'neural_net_basics.md',
     'JAX_for_PyTorch_users.md',
     'JAX_porting_PyTorch_model.md',
     'digits_diffusion_model.md',
diff --git a/docs/source/digits_diffusion_model.ipynb b/docs/source/digits_diffusion_model.ipynb
index 9c4c2ca..4b4088a 100644
--- a/docs/source/digits_diffusion_model.ipynb
+++ b/docs/source/digits_diffusion_model.ipynb
@@ -20,7 +20,7 @@
     "- Train the model (with Google Colab’s Cloud TPU v2)\n",
     "- Visualize and track the model’s progress\n",
     "\n",
-    "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX."
+    "If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX."
    ]
  },
  {
diff --git a/docs/source/digits_diffusion_model.md b/docs/source/digits_diffusion_model.md
index 06e1b60..9bb021b 100644
--- a/docs/source/digits_diffusion_model.md
+++ b/docs/source/digits_diffusion_model.md
@@ -27,7 +27,7 @@ In this tutorial, you'll learn how to:
 - Train the model (with Google Colab’s Cloud TPU v2)
 - Visualize and track the model’s progress
 
-If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX.
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX.
 
 +++ {"id": "gwaaMmjXt7n7"}
 
diff --git a/docs/source/digits_vae.ipynb b/docs/source/digits_vae.ipynb
index 7c431ab..6adbad5 100644
--- a/docs/source/digits_vae.ipynb
+++ b/docs/source/digits_vae.ipynb
@@ -10,9 +10,9 @@
     "\n",
     "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)\n",
     "\n",
-    "This tutorial explores a simplified version of a generative model called [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) with [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset, and expands on what we learned in [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html). Along the way, you'll learn more about how JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) (`jax.jit`) actually works, and what this means for [debugging](https://jax.readthedocs.io/en/latest/debugging/index.html) [JAX programs](https://jax.readthedocs.io/en/latest/debugging.html), as we learn how to identify what can go wrong during model training.\n",
+    "This tutorial explores a simplified version of a generative model called [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) with [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset, and expands on what we learned in [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html). Along the way, you'll learn more about how JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) (`jax.jit`) actually works, and what this means for [debugging](https://jax.readthedocs.io/en/latest/debugging/index.html) [JAX programs](https://jax.readthedocs.io/en/latest/debugging.html), as we learn how to identify what can go wrong during model training.\n",
     "\n",
-    "If you are new to JAX for AI, check out the [first tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which explains how to build a simple neural netwwork with Flax and Optax, and JAX's key features, including the NumPy-style interface with `jax.numpy`, JAX transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`."
+    "If you are new to JAX for AI, check out the [first tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which explains how to build a simple neural netwwork with Flax and Optax, and JAX's key features, including the NumPy-style interface with `jax.numpy`, JAX transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`."
    ]
  },
  {
@@ -23,7 +23,7 @@
    "source": [
     "## Loading the data\n",
     "\n",
-    "As [before](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), this example uses the well-known, small and self-contained [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset:"
+    "As [before](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), this example uses the well-known, small and self-contained [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset:"
    ]
  },
  {
@@ -64,7 +64,7 @@
     "id": "2_Q16JRyrW7V"
    },
    "source": [
-    "The dataset comprises 1800 images of hand-written digits, each represented by an `8x8` pixel grid, and their corresponding labels. For visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#loading-the-data) in the previous tutorial."
+    "The dataset comprises 1800 images of hand-written digits, each represented by an `8x8` pixel grid, and their corresponding labels. For visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html#loading-the-data) in the previous tutorial."
    ]
  },
  {
@@ -75,7 +75,7 @@
    "source": [
     "## Defining the VAE with Flax\n",
     "\n",
-    "[Previously](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), we learned how to use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network trained for classification with an architecture that looked roughly like this:"
+    "[Previously](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), we learned how to use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network trained for classification with an architecture that looked roughly like this:"
    ]
  },
  {
diff --git a/docs/source/digits_vae.md b/docs/source/digits_vae.md
index 05f14c3..2732a2a 100644
--- a/docs/source/digits_vae.md
+++ b/docs/source/digits_vae.md
@@ -17,15 +17,15 @@ kernelspec:
 
 [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)
 
-This tutorial explores a simplified version of a generative model called [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) with [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset, and expands on what we learned in [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html). Along the way, you'll learn more about how JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) (`jax.jit`) actually works, and what this means for [debugging](https://jax.readthedocs.io/en/latest/debugging/index.html) [JAX programs](https://jax.readthedocs.io/en/latest/debugging.html), as we learn how to identify what can go wrong during model training.
+This tutorial explores a simplified version of a generative model called [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) with [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset, and expands on what we learned in [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html). Along the way, you'll learn more about how JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) (`jax.jit`) actually works, and what this means for [debugging](https://jax.readthedocs.io/en/latest/debugging/index.html) [JAX programs](https://jax.readthedocs.io/en/latest/debugging.html), as we learn how to identify what can go wrong during model training.
 
-If you are new to JAX for AI, check out the [first tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which explains how to build a simple neural netwwork with Flax and Optax, and JAX's key features, including the NumPy-style interface with `jax.numpy`, JAX transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`.
+If you are new to JAX for AI, check out the [first tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which explains how to build a simple neural netwwork with Flax and Optax, and JAX's key features, including the NumPy-style interface with `jax.numpy`, JAX transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`.
 
 +++ {"id": "k19povzxp7hS"}
 
 ## Loading the data
 
-As [before](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), this example uses the well-known, small and self-contained [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset:
+As [before](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), this example uses the well-known, small and self-contained [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset:
 
 ```{code-cell}
 :id: aIwDAfS6PtFh
@@ -47,13 +47,13 @@ print(f"{images_test.shape=}")
 
 +++ {"id": "2_Q16JRyrW7V"}
 
-The dataset comprises 1800 images of hand-written digits, each represented by an `8x8` pixel grid, and their corresponding labels. For visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#loading-the-data) in the previous tutorial.
+The dataset comprises 1800 images of hand-written digits, each represented by an `8x8` pixel grid, and their corresponding labels. For visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html#loading-the-data) in the previous tutorial.
 
 +++ {"id": "Z9TPYqipPyBp"}
 
 ## Defining the VAE with Flax
 
-[Previously](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), we learned how to use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network trained for classification with an architecture that looked roughly like this:
+[Previously](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), we learned how to use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network trained for classification with an architecture that looked roughly like this:
 
 ```{code-cell}
 :id: HNlg-ydpr5yH
diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md
index 4e5d83e..eef4f8c 100644
--- a/docs/source/getting_started.md
+++ b/docs/source/getting_started.md
@@ -20,7 +20,7 @@ After working through this content, you may wish to visit the [JAX documentation
 ```{toctree}
 :maxdepth: 1
 
-getting_started_with_jax_for_AI
+neural_net_basics
 digits_vae
 digits_diffusion_model
 ```
diff --git a/docs/source/getting_started_with_jax_for_AI.ipynb b/docs/source/neural_net_basics.ipynb
similarity index 99%
rename from docs/source/getting_started_with_jax_for_AI.ipynb
rename to docs/source/neural_net_basics.ipynb
index d221fed..2b783fb 100644
--- a/docs/source/getting_started_with_jax_for_AI.ipynb
+++ b/docs/source/neural_net_basics.ipynb
@@ -8,7 +8,7 @@
    "source": [
     "# Part 1: JAX neural net basics\n",
     "\n",
-    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)"
+    "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/neural_net_basics.ipynb)"
    ]
  },
  {
diff --git a/docs/source/getting_started_with_jax_for_AI.md b/docs/source/neural_net_basics.md
similarity index 99%
rename from docs/source/getting_started_with_jax_for_AI.md
rename to docs/source/neural_net_basics.md
index 04cca88..0589656 100644
--- a/docs/source/getting_started_with_jax_for_AI.md
+++ b/docs/source/neural_net_basics.md
@@ -15,7 +15,7 @@ kernelspec:
 
 # Part 1: JAX neural net basics
 
-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/neural_net_basics.ipynb)
 
 +++ {"id": "z7sAr0sderhh"}
 