Skip to content

Reorganize ToC and getting started flow #209

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
# Theme-specific options
# https://sphinx-book-theme.readthedocs.io/en/stable/reference.html
html_theme_options = {
'show_navbar_depth': 2,
'show_toc_level': 2,
'show_navbar_depth': 1,
'show_toc_level': 1,
'repository_url': 'https://github.com/jax-ml/jax-ai-stack',
'path_to_docs': 'docs/source/',
'use_repository_button': True,
Expand Down
2 changes: 1 addition & 1 deletion docs/source/digits_diffusion_model.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"id": "Kzqlx7fpXRnJ"
},
"source": [
"# Train a diffusion model for image generation with JAX for AI\n",
"# Part 3: Train a diffusion model for image generation\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/source/digits_diffusion_model.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ kernelspec:

+++ {"id": "Kzqlx7fpXRnJ"}

# Train a diffusion model for image generation with JAX for AI
# Part 3: Train a diffusion model for image generation

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)

Expand Down
2 changes: 1 addition & 1 deletion docs/source/digits_vae.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"id": "47OmRSTR1dJU"
},
"source": [
"# Variational autoencoder (VAE) and debugging in JAX\n",
"# Part 2: Debug a variational autoencoder (VAE)\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)\n",
"\n",
Expand Down
2 changes: 1 addition & 1 deletion docs/source/digits_vae.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ kernelspec:

+++ {"id": "47OmRSTR1dJU"}

# Variational autoencoder (VAE) and debugging in JAX
# Part 2: Debug a variational autoencoder (VAE)

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)

Expand Down
16 changes: 0 additions & 16 deletions docs/source/examples.md

This file was deleted.

26 changes: 26 additions & 0 deletions docs/source/getting_started.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Getting started with JAX for ML

[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond.

## Who is this tutorial for?

This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models.

## What does this tutorial cover?

JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:

- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.
- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.

After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.

## Let's get started!

```{toctree}
:maxdepth: 1

getting_started_with_jax_for_AI
digits_vae
digits_diffusion_model
```
33 changes: 2 additions & 31 deletions docs/source/getting_started_with_jax_for_AI.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,9 @@
"id": "AEQPh3NtawWA"
},
"source": [
"# Getting started with JAX for AI\n",
"# Part 1: JAX neural net basics\n",
"\n",
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)\n",
"\n",
"[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "lN1DEDeMel9r"
},
"source": [
"## Who is this tutorial for?\n",
"\n",
"This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "1Y92oUSGeoRz"
},
"source": [
"## What does this tutorial cover?\n",
"\n",
"JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:\n",
"\n",
"- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.\n",
"- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.\n",
"\n",
"After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts."
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)"
]
},
{
Expand Down
21 changes: 1 addition & 20 deletions docs/source/getting_started_with_jax_for_AI.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,10 @@ kernelspec:

+++ {"id": "AEQPh3NtawWA"}

# Getting started with JAX for AI
# Part 1: JAX neural net basics

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)

[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond.

+++ {"id": "lN1DEDeMel9r"}

## Who is this tutorial for?

This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models.

+++ {"id": "1Y92oUSGeoRz"}

## What does this tutorial cover?

JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:

- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.
- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.

After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.

+++ {"id": "z7sAr0sderhh"}

## Example: A simple neural network with Flax
Expand Down
37 changes: 33 additions & 4 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
@@ -1,14 +1,43 @@
Jax AI Stack
JAX AI Stack
============

.. raw:: html
:file: index.html

.. toctree::
:hidden:
:maxdepth: 2
:caption: Getting started
:maxdepth: 1

install
tutorials
examples
getting_started

.. toctree::
:hidden:
:caption: Tutorials
:maxdepth: 1

JAX_visualizing_models_metrics
data_loaders
pytorch_users

.. toctree::
:hidden:
:caption: Example applications
:maxdepth: 1

JAX_for_LLM_pretraining
JAX_basic_text_classification
JAX_transformer_text_classification
JAX_machine_translation
JAX_examples_image_segmentation
JAX_image_captioning
JAX_Vision_transformer
JAX_time_series_classification

.. toctree::
:hidden:
:caption: Developer resources
:maxdepth: 1

contributing
22 changes: 0 additions & 22 deletions docs/source/tutorials.md

This file was deleted.