diff --git a/docs/source/conf.py b/docs/source/conf.py index aad2b2f..bff261f 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -37,8 +37,8 @@ # Theme-specific options # https://sphinx-book-theme.readthedocs.io/en/stable/reference.html html_theme_options = { - 'show_navbar_depth': 2, - 'show_toc_level': 2, + 'show_navbar_depth': 1, + 'show_toc_level': 1, 'repository_url': 'https://github.com/jax-ml/jax-ai-stack', 'path_to_docs': 'docs/source/', 'use_repository_button': True, diff --git a/docs/source/digits_diffusion_model.ipynb b/docs/source/digits_diffusion_model.ipynb index cbd45da..9c4c2ca 100644 --- a/docs/source/digits_diffusion_model.ipynb +++ b/docs/source/digits_diffusion_model.ipynb @@ -6,7 +6,7 @@ "id": "Kzqlx7fpXRnJ" }, "source": [ - "# Train a diffusion model for image generation with JAX for AI\n", + "# Part 3: Train a diffusion model for image generation\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)\n", "\n", diff --git a/docs/source/digits_diffusion_model.md b/docs/source/digits_diffusion_model.md index 635d8e1..06e1b60 100644 --- a/docs/source/digits_diffusion_model.md +++ b/docs/source/digits_diffusion_model.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "Kzqlx7fpXRnJ"} -# Train a diffusion model for image generation with JAX for AI +# Part 3: Train a diffusion model for image generation [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb) diff --git a/docs/source/digits_vae.ipynb b/docs/source/digits_vae.ipynb index 190927a..7c431ab 100644 --- a/docs/source/digits_vae.ipynb +++ b/docs/source/digits_vae.ipynb @@ -6,7 +6,7 @@ "id": "47OmRSTR1dJU" }, "source": [ - "# Variational autoencoder (VAE) and debugging in JAX\n", + "# Part 2: Debug a variational autoencoder (VAE)\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb)\n", "\n", diff --git a/docs/source/digits_vae.md b/docs/source/digits_vae.md index bf2c1a4..05f14c3 100644 --- a/docs/source/digits_vae.md +++ b/docs/source/digits_vae.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "47OmRSTR1dJU"} -# Variational autoencoder (VAE) and debugging in JAX +# Part 2: Debug a variational autoencoder (VAE) [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_vae.ipynb) diff --git a/docs/source/examples.md b/docs/source/examples.md deleted file mode 100644 index 00a18cc..0000000 --- a/docs/source/examples.md +++ /dev/null @@ -1,16 +0,0 @@ -# Example applications - -The following pages provide examples of common applications of the JAX AI stack: - -```{toctree} -:maxdepth: 1 - -JAX_for_LLM_pretraining -JAX_basic_text_classification -JAX_transformer_text_classification -JAX_machine_translation -JAX_examples_image_segmentation -JAX_image_captioning -JAX_Vision_transformer -JAX_time_series_classification -``` diff --git a/docs/source/getting_started.md b/docs/source/getting_started.md new file mode 100644 index 0000000..4e5d83e --- /dev/null +++ b/docs/source/getting_started.md @@ -0,0 +1,26 @@ +# Getting started with JAX for ML + +[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond. + +## Who is this tutorial for? + +This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models. + +## What does this tutorial cover? + +JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including: + +- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX. +- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions. + +After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts. + +## Let's get started! + +```{toctree} +:maxdepth: 1 + +getting_started_with_jax_for_AI +digits_vae +digits_diffusion_model +``` diff --git a/docs/source/getting_started_with_jax_for_AI.ipynb b/docs/source/getting_started_with_jax_for_AI.ipynb index de49209..d221fed 100644 --- a/docs/source/getting_started_with_jax_for_AI.ipynb +++ b/docs/source/getting_started_with_jax_for_AI.ipynb @@ -6,38 +6,9 @@ "id": "AEQPh3NtawWA" }, "source": [ - "# Getting started with JAX for AI\n", + "# Part 1: JAX neural net basics\n", "\n", - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)\n", - "\n", - "[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "lN1DEDeMel9r" - }, - "source": [ - "## Who is this tutorial for?\n", - "\n", - "This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1Y92oUSGeoRz" - }, - "source": [ - "## What does this tutorial cover?\n", - "\n", - "JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:\n", - "\n", - "- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.\n", - "- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.\n", - "\n", - "After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts." + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb)" ] }, { diff --git a/docs/source/getting_started_with_jax_for_AI.md b/docs/source/getting_started_with_jax_for_AI.md index 5d2b7f9..04cca88 100644 --- a/docs/source/getting_started_with_jax_for_AI.md +++ b/docs/source/getting_started_with_jax_for_AI.md @@ -13,29 +13,10 @@ kernelspec: +++ {"id": "AEQPh3NtawWA"} -# Getting started with JAX for AI +# Part 1: JAX neural net basics [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/getting_started_with_jax_for_AI.ipynb) -[JAX](http://jax.readthedocs.io) is a Python package for accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google and beyond. - -+++ {"id": "lN1DEDeMel9r"} - -## Who is this tutorial for? - -This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models. - -+++ {"id": "1Y92oUSGeoRz"} - -## What does this tutorial cover? - -JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including: - -- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX. -- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions. - -After working through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts. - +++ {"id": "z7sAr0sderhh"} ## Example: A simple neural network with Flax diff --git a/docs/source/index.rst b/docs/source/index.rst index 2ab31e5..4160827 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -1,4 +1,4 @@ -Jax AI Stack +JAX AI Stack ============ .. raw:: html @@ -6,9 +6,38 @@ Jax AI Stack .. toctree:: :hidden: - :maxdepth: 2 + :caption: Getting started + :maxdepth: 1 install - tutorials - examples + getting_started + +.. toctree:: + :hidden: + :caption: Tutorials + :maxdepth: 1 + + JAX_visualizing_models_metrics + data_loaders + pytorch_users + +.. toctree:: + :hidden: + :caption: Example applications + :maxdepth: 1 + + JAX_for_LLM_pretraining + JAX_basic_text_classification + JAX_transformer_text_classification + JAX_machine_translation + JAX_examples_image_segmentation + JAX_image_captioning + JAX_Vision_transformer + JAX_time_series_classification + +.. toctree:: + :hidden: + :caption: Developer resources + :maxdepth: 1 + contributing diff --git a/docs/source/tutorials.md b/docs/source/tutorials.md deleted file mode 100644 index 3e984c9..0000000 --- a/docs/source/tutorials.md +++ /dev/null @@ -1,22 +0,0 @@ -# Tutorials - -The following tutorials are meant as an introduction to the full stack: - -```{toctree} -:maxdepth: 1 - -getting_started_with_jax_for_AI -digits_vae -digits_diffusion_model -JAX_visualizing_models_metrics -data_loaders -pytorch_users -``` - -## Further references - -Once you've gone through this content, you can refer to package-specific -documentation for resources that go into more depth on various topics: - -- [JAX tutorials](https://jax.readthedocs.io/en/latest/tutorials.html) -- [FLAX user guides](https://flax.readthedocs.io/en/latest/guides/index.html)