Skip to content

[WIP] Add "The stack" section to left nav #212

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ myst-nb
myst-parser[linkify]
sphinx-book-theme
sphinx-copybutton
sphinx-design

# Packages required for notebook execution
matplotlib
Expand Down
1 change: 1 addition & 0 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
extensions = [
'myst_nb',
'sphinx_copybutton',
'sphinx_design',
]

templates_path = ['_templates']
Expand Down
14 changes: 14 additions & 0 deletions docs/source/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,20 @@ JAX AI Stack
install
getting_started

.. toctree::
:hidden:
:caption: The stack
:maxdepth: 1

stack_overview
stack_jax
stack_flax
stack_optax
stack_orbax_checkpoint
stack_orbax_export
stack_grain
stack_chex

.. toctree::
:hidden:
:caption: Tutorials
Expand Down
1 change: 1 addition & 0 deletions docs/source/stack_chex.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Chex: test utilities
24 changes: 24 additions & 0 deletions docs/source/stack_flax.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Flax NNX: neural nets

Flax NNX provides **neural net functionality** on top of JAX, such as a module
abstraction and pre-defined layers, via a **Pythonic object-oriented API**. NNX
allows you to write stateful model code that can still take advantage of JAX's
function transforms and other features.

NNX has native integration with [Optax](stack_optax).

Main Flax NNX site:
**[flax.readthedocs.io{material-regular}`open_in_new`](https://flax.readthedocs.io/)**

**If you'd like to learn more about NNX** beyond what's covered in the
[](getting_started) guide, we recommend starting with **[Flax
basics{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/nnx_basics.html)**.

The Flax NNX docs cover many other useful topics including:

* [Function
transforms{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/transforms.html)
* [Parallelism{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/flax_gspmd.html)
* [Performance
considerations{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/guides/performance.html)
* And much more!
1 change: 1 addition & 0 deletions docs/source/stack_grain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Grain: data loading
18 changes: 18 additions & 0 deletions docs/source/stack_jax.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# JAX: array computing

JAX is the foundation of the JAX AI Stack! It provides **high-performance array
computing** functionality over accelerators via a simple **NumPy-like API and
function transformations**.

Main JAX site: **[jax.dev{material-regular}`open_in_new`](https://jax.dev)**

**If you'd like to learn more about JAX** beyond what's covered in the
[](getting_started) guide, we recommend starting with the **[JAX
tutorials{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/tutorials.html)**.

The JAX docs cover many other useful topics including:

* [Performance profiling{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/profiling.html)
* [Multi-host JAX programs{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/multi_process.html)
* [Custom GPU + TPU kernels with Pallas{material-regular}`open_in_new`](https://docs.jax.dev/en/latest/pallas/index.html)
* And much more!
12 changes: 12 additions & 0 deletions docs/source/stack_optax.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# Optax: optimizers

Optax provides **gradient processing and optimization** functionality on top of
JAX, including optimizers and losses.

Main Optax site:
**[optax.readthedocs.io{material-regular}`open_in_new`](https://optax.readthedocs.io/en/latest/index.html)**

**If you'd like to learn more about Optax** beyond what's covered in the
[](getting_started) guide, we recommend starting with the **[Optax getting
started{material-regular}`open_in_new`](https://optax.readthedocs.io/en/latest/getting_started.html)**
guide.
1 change: 1 addition & 0 deletions docs/source/stack_orbax_checkpoint.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Orbax: checkpointing
1 change: 1 addition & 0 deletions docs/source/stack_orbax_export.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Orbax: model export
26 changes: 26 additions & 0 deletions docs/source/stack_overview.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# Stack overview

The JAX AI Stack is comprised of the following packages:

* [JAX{material-regular}`open_in_new`](https://jax.dev): high-performance array
computing
* [Flax
NNX{material-regular}`open_in_new`](https://flax.readthedocs.io/en/latest/):
object-oriented neural nets
* [Optax{material-regular}`open_in_new`](https://optax.readthedocs.io/en/latest/index.html):
optimizers
* [Orbax{material-regular}`open_in_new`](https://orbax.readthedocs.io/en/latest/):
checkpointing and model export
* [Grain{material-regular}`open_in_new`](https://google-grain.readthedocs.io/en/latest/):
JAX-native data loading
* [Chex{material-regular}`open_in_new`](https://chex.readthedocs.io/en/latest/):
JAX test utilities

The `jax-ai-stack` metapackage installs compatible versions of all of these
libraries, as well as shared compatible versions of shared dependencies.

In addition, there is an optional `jax-ai-stack[tfds]` installation that
includes [TensorFlow
Datasets{material-regular}`open_in_new`](https://www.tensorflow.org/datasets),
for those who wish to use TFDS for data loading. This includes a compatible
version of TensorFlow as well.