Skip to content

Commit 8bd9c43

Browse files
authored
Reorganize ToC and getting started flow (#209)
* Add new top-level headers to ToC that aren't pages via `:caption:` * Add a new "Getting started" section separate from "Tutorials" * Organize the first three getting started tutorials into a 3-part flow
1 parent df96647 commit 8bd9c43

11 files changed

+68
-99
lines changed

docs/source/conf.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,8 @@
3737
# Theme-specific options
3838
# https://sphinx-book-theme.readthedocs.io/en/stable/reference.html
3939
html_theme_options = {
40-
'show_navbar_depth': 2,
41-
'show_toc_level': 2,
40+
'show_navbar_depth': 1,
41+
'show_toc_level': 1,
4242
'repository_url': 'https://github.com/jax-ml/jax-ai-stack',
4343
'path_to_docs': 'docs/source/',
4444
'use_repository_button': True,

docs/source/digits_diffusion_model.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"id": "Kzqlx7fpXRnJ"
77
},
88
"source": [
9-
"# Train a diffusion model for image generation with JAX for AI\n",
9+
"# Part 3: Train a diffusion model for image generation\n",
1010
"\n",
1111
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)\n",
1212
"\n",

docs/source/digits_diffusion_model.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ kernelspec:
1313

1414
+++ {"id": "Kzqlx7fpXRnJ"}
1515

16-
# Train a diffusion model for image generation with JAX for AI
16+
# Part 3: Train a diffusion model for image generation
1717

1818
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)
1919

docs/source/digits_vae.ipynb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
"id": "47OmRSTR1dJU"
77
},
88
"source": [
9-
"# Variational autoencoder (VAE) and debugging in JAX\n",
9+
"# Part 2: Debug a variational autoencoder (VAE)\n",
1010
"\n",
1111
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)\n",
1212
"\n",

docs/source/digits_vae.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ kernelspec:
1313

1414
+++ {"id": "47OmRSTR1dJU"}
1515

16-
# Variational autoencoder (VAE) and debugging in JAX
16+
# Part 2: Debug a variational autoencoder (VAE)
1717

1818
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)
1919

docs/source/examples.md

Lines changed: 0 additions & 16 deletions
This file was deleted.

docs/source/getting_started.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
# Getting started with JAX for ML
2+
3+
[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond.
4+
5+
## Who is this tutorial for?
6+
7+
This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models.
8+
9+
## What does this tutorial cover?
10+
11+
JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:
12+
13+
- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.
14+
- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.
15+
16+
After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.
17+
18+
## Let's get started!
19+
20+
```{toctree}
21+
:maxdepth: 1
22+
23+
getting_started_with_jax_for_AI
24+
digits_vae
25+
digits_diffusion_model
26+
```

docs/source/getting_started_with_jax_for_AI.ipynb

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,9 @@
66
"id": "AEQPh3NtawWA"
77
},
88
"source": [
9-
"# Getting started with JAX for AI\n",
9+
"# Part 1: JAX neural net basics\n",
1010
"\n",
11-
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)\n",
12-
"\n",
13-
"[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond."
14-
]
15-
},
16-
{
17-
"cell_type": "markdown",
18-
"metadata": {
19-
"id": "lN1DEDeMel9r"
20-
},
21-
"source": [
22-
"## Who is this tutorial for?\n",
23-
"\n",
24-
"This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models."
25-
]
26-
},
27-
{
28-
"cell_type": "markdown",
29-
"metadata": {
30-
"id": "1Y92oUSGeoRz"
31-
},
32-
"source": [
33-
"## What does this tutorial cover?\n",
34-
"\n",
35-
"JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:\n",
36-
"\n",
37-
"- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.\n",
38-
"- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.\n",
39-
"\n",
40-
"After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts."
11+
"[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)"
4112
]
4213
},
4314
{

docs/source/getting_started_with_jax_for_AI.md

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -13,29 +13,10 @@ kernelspec:
1313

1414
+++ {"id": "AEQPh3NtawWA"}
1515

16-
# Getting started with JAX for AI
16+
# Part 1: JAX neural net basics
1717

1818
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)
1919

20-
[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond.
21-
22-
+++ {"id": "lN1DEDeMel9r"}
23-
24-
## Who is this tutorial for?
25-
26-
This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models.
27-
28-
+++ {"id": "1Y92oUSGeoRz"}
29-
30-
## What does this tutorial cover?
31-
32-
JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:
33-
34-
- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.
35-
- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.
36-
37-
After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts.
38-
3920
+++ {"id": "z7sAr0sderhh"}
4021

4122
## Example: A simple neural network with Flax

docs/source/index.rst

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,43 @@
1-
Jax AI Stack
1+
JAX AI Stack
22
============
33

44
.. raw:: html
55
:file: index.html
66

77
.. toctree::
88
:hidden:
9-
:maxdepth: 2
9+
:caption: Getting started
10+
:maxdepth: 1
1011

1112
install
12-
tutorials
13-
examples
13+
getting_started
14+
15+
.. toctree::
16+
:hidden:
17+
:caption: Tutorials
18+
:maxdepth: 1
19+
20+
JAX_visualizing_models_metrics
21+
data_loaders
22+
pytorch_users
23+
24+
.. toctree::
25+
:hidden:
26+
:caption: Example applications
27+
:maxdepth: 1
28+
29+
JAX_for_LLM_pretraining
30+
JAX_basic_text_classification
31+
JAX_transformer_text_classification
32+
JAX_machine_translation
33+
JAX_examples_image_segmentation
34+
JAX_image_captioning
35+
JAX_Vision_transformer
36+
JAX_time_series_classification
37+
38+
.. toctree::
39+
:hidden:
40+
:caption: Developer resources
41+
:maxdepth: 1
42+
1443
contributing

docs/source/tutorials.md

Lines changed: 0 additions & 22 deletions
This file was deleted.

0 commit comments

Comments
 (0)