Rename getting_started_with_jax_for_AI files and URL to neural_net_basics #211

Merged 1 commit on May 24, 2025

2 changes: 1 addition & 1 deletion docs/source/JAX_Vision_transformer.ipynb
@@ -13,7 +13,7 @@
"\n",
"This tutorial draws inspiration from the HuggingFace [Image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification). The original JAX-based implementation of the ViT model can be found in the [google-research/vision_transformer](https://github.com/google-research/vision_transformer/) GitHub repository.\n",
"\n",
-"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX."
+"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX."
]
},
{
2 changes: 1 addition & 1 deletion docs/source/JAX_Vision_transformer.md
@@ -20,7 +20,7 @@ This tutorial guides you through developing and training a Vision Transformer (V

This tutorial draws inspiration from the HuggingFace [Image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification). The original JAX-based implementation of the ViT model can be found in the [google-research/vision_transformer](https://github.com/google-research/vision_transformer/) GitHub repository.

-If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX.
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX.

+++

2 changes: 1 addition & 1 deletion docs/source/JAX_for_LLM_pretraining.ipynb
@@ -20,7 +20,7 @@
"- Train the model on Google Colab’s Cloud TPU v2\n",
"- Profile for hyperparameter tuning\n",
"\n",
-"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html)."
+"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html)."
]
},
{
2 changes: 1 addition & 1 deletion docs/source/JAX_for_LLM_pretraining.md
@@ -27,7 +27,7 @@ Here, you will learn how to:
- Train the model on Google Colab’s Cloud TPU v2
- Profile for hyperparameter tuning

-If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).

+++ {"id": "hTmz5Cbco7n_"}

2 changes: 1 addition & 1 deletion docs/source/JAX_transformer_text_classification.ipynb
@@ -18,7 +18,7 @@
"- Train the model.\n",
"- Evaluate the model with an example.\n",
"\n",
-"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).\n",
+"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).\n",
"\n",
"## Setup\n",
"\n",
2 changes: 1 addition & 1 deletion docs/source/JAX_transformer_text_classification.md
@@ -26,7 +26,7 @@ Here, you will learn how to:
- Train the model.
- Evaluate the model with an example.

-If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).

## Setup

2 changes: 1 addition & 1 deletion docs/source/JAX_visualizing_models_metrics.ipynb
@@ -13,7 +13,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
-"To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that yet and want the primer, start there before returning.\n",
+"To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) - if you haven't read that yet and want the primer, start there before returning.\n",
"\n",
"All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them."
]
2 changes: 1 addition & 1 deletion docs/source/JAX_visualizing_models_metrics.md
@@ -18,7 +18,7 @@ kernelspec:

+++

-To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that yet and want the primer, start there before returning.
+To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html) - if you haven't read that yet and want the primer, start there before returning.

All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them.

2 changes: 1 addition & 1 deletion docs/source/conf.py
@@ -52,7 +52,7 @@
'build/jupyter_execute',
# Exclude markdown sources for notebooks:
'digits_vae.md',
-'getting_started_with_jax_for_AI.md',
+'neural_net_basics.md',
'JAX_for_PyTorch_users.md',
'JAX_porting_PyTorch_model.md',
'digits_diffusion_model.md',
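For context, the hunk above edits the notebook-source exclusion list in docs/source/conf.py. A minimal sketch of how that block plausibly reads after this change is shown below; the list name `exclude_patterns` and the enclosing brackets are assumptions based on standard Sphinx configuration, and only the entries visible in the hunk are reproduced.

```python
# docs/source/conf.py -- sketch of the affected block after the rename.
# Assumes the visible entries live in Sphinx's standard `exclude_patterns` list;
# entries outside the visible hunk are omitted.
exclude_patterns = [
    'build/jupyter_execute',
    # Exclude markdown sources for notebooks:
    'digits_vae.md',
    'neural_net_basics.md',  # previously 'getting_started_with_jax_for_AI.md'
    'JAX_for_PyTorch_users.md',
    'JAX_porting_PyTorch_model.md',
    'digits_diffusion_model.md',
]
```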
2 changes: 1 addition & 1 deletion docs/source/digits_diffusion_model.ipynb
@@ -20,7 +20,7 @@
"- Train the model (with Google Colab’s Cloud TPU v2)\n",
"- Visualize and track the model’s progress\n",
"\n",
-"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX."
+"If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX."
]
},
{
2 changes: 1 addition & 1 deletion docs/source/digits_diffusion_model.md
@@ -27,7 +27,7 @@ In this tutorial, you'll learn how to:
- Train the model (with Google Colab’s Cloud TPU v2)
- Visualize and track the model’s progress

-If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX.
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with Flax, Optax and JAX.

+++ {"id": "gwaaMmjXt7n7"}

10 changes: 5 additions & 5 deletions docs/source/digits_vae.ipynb
@@ -10,9 +10,9 @@
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)\n",
"\n",
-"This tutorial explores a simplified version of a generative model called [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) with [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset, and expands on what we learned in [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html). Along the way, you'll learn more about how JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) (`jax.jit`) actually works, and what this means for [debugging](https://jax.readthedocs.io/en/latest/debugging/index.html) [JAX programs](https://jax.readthedocs.io/en/latest/debugging.html), as we learn how to identify what can go wrong during model training.\n",
+"This tutorial explores a simplified version of a generative model called [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) with [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset, and expands on what we learned in [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html). Along the way, you'll learn more about how JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) (`jax.jit`) actually works, and what this means for [debugging](https://jax.readthedocs.io/en/latest/debugging/index.html) [JAX programs](https://jax.readthedocs.io/en/latest/debugging.html), as we learn how to identify what can go wrong during model training.\n",
"\n",
-"If you are new to JAX for AI, check out the [first tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which explains how to build a simple neural netwwork with Flax and Optax, and JAX's key features, including the NumPy-style interface with `jax.numpy`, JAX transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`."
+"If you are new to JAX for AI, check out the [first tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which explains how to build a simple neural netwwork with Flax and Optax, and JAX's key features, including the NumPy-style interface with `jax.numpy`, JAX transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`."
]
},
{
@@ -23,7 +23,7 @@
"source": [
"## Loading the data\n",
"\n",
-"As [before](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), this example uses the well-known, small and self-contained [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset:"
+"As [before](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), this example uses the well-known, small and self-contained [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset:"
]
},
{
@@ -64,7 +64,7 @@
"id": "2_Q16JRyrW7V"
},
"source": [
-"The dataset comprises 1800 images of hand-written digits, each represented by an `8x8` pixel grid, and their corresponding labels. For visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#loading-the-data) in the previous tutorial."
+"The dataset comprises 1800 images of hand-written digits, each represented by an `8x8` pixel grid, and their corresponding labels. For visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html#loading-the-data) in the previous tutorial."
]
},
{
@@ -75,7 +75,7 @@
"source": [
"## Defining the VAE with Flax\n",
"\n",
-"[Previously](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), we learned how to use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network trained for classification with an architecture that looked roughly like this:"
+"[Previously](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), we learned how to use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network trained for classification with an architecture that looked roughly like this:"
]
},
{
10 changes: 5 additions & 5 deletions docs/source/digits_vae.md
@@ -17,15 +17,15 @@ kernelspec:

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)

-This tutorial explores a simplified version of a generative model called [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) with [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset, and expands on what we learned in [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html). Along the way, you'll learn more about how JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) (`jax.jit`) actually works, and what this means for [debugging](https://jax.readthedocs.io/en/latest/debugging/index.html) [JAX programs](https://jax.readthedocs.io/en/latest/debugging.html), as we learn how to identify what can go wrong during model training.
+This tutorial explores a simplified version of a generative model called [Variational Autoencoder (VAE)](https://en.wikipedia.org/wiki/Variational_autoencoder) with [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset, and expands on what we learned in [Getting started with JAX](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html). Along the way, you'll learn more about how JAX's [JIT compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation) (`jax.jit`) actually works, and what this means for [debugging](https://jax.readthedocs.io/en/latest/debugging/index.html) [JAX programs](https://jax.readthedocs.io/en/latest/debugging.html), as we learn how to identify what can go wrong during model training.

-If you are new to JAX for AI, check out the [first tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which explains how to build a simple neural netwwork with Flax and Optax, and JAX's key features, including the NumPy-style interface with `jax.numpy`, JAX transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`.
+If you are new to JAX for AI, check out the [first tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which explains how to build a simple neural netwwork with Flax and Optax, and JAX's key features, including the NumPy-style interface with `jax.numpy`, JAX transformations for JIT compilation with `jax.jit`, automatic vectorization with `jax.vmap`, and automatic differentiation with `jax.grad`.

+++ {"id": "k19povzxp7hS"}

## Loading the data

-As [before](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), this example uses the well-known, small and self-contained [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset:
+As [before](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), this example uses the well-known, small and self-contained [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset:

```{code-cell}
:id: aIwDAfS6PtFh
@@ -47,13 +47,13 @@ print(f"{images_test.shape=}")

+++ {"id": "2_Q16JRyrW7V"}

-The dataset comprises 1800 images of hand-written digits, each represented by an `8x8` pixel grid, and their corresponding labels. For visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html#loading-the-data) in the previous tutorial.
+The dataset comprises 1800 images of hand-written digits, each represented by an `8x8` pixel grid, and their corresponding labels. For visualization of this data, refer to [loading the data](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html#loading-the-data) in the previous tutorial.

+++ {"id": "Z9TPYqipPyBp"}

## Defining the VAE with Flax

-[Previously](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), we learned how to use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network trained for classification with an architecture that looked roughly like this:
+[Previously](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), we learned how to use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network trained for classification with an architecture that looked roughly like this:

```{code-cell}
:id: HNlg-ydpr5yH
2 changes: 1 addition & 1 deletion docs/source/getting_started.md
@@ -20,7 +20,7 @@ After working through this content, you may wish to visit the [JAX documentation
```{toctree}
:maxdepth: 1

-getting_started_with_jax_for_AI
+neural_net_basics
digits_vae
digits_diffusion_model
```
docs/source/getting_started_with_jax_for_AI.ipynb → docs/source/neural_net_basics.ipynb
@@ -8,7 +8,7 @@
"source": [
"# Part 1: JAX neural net basics\n",
"\n",
-"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)"
+"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/neural_net_basics.ipynb)"
]
},
{
docs/source/getting_started_with_jax_for_AI.md → docs/source/neural_net_basics.md
@@ -15,7 +15,7 @@ kernelspec:

# Part 1: JAX neural net basics

-[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/neural_net_basics.ipynb)

+++ {"id": "z7sAr0sderhh"}
