MultiMesh for JAX provides a framework for creating task contexts within jitted computations, allowing different subcomputations to be placed on different GPU submeshes. These task computations can be combined inside a global jit, with data resharded across submeshes automatically. This makes pipeline parallelism easy to express with MultiMesh.
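To make the idea of submeshes concrete, the pure-Python sketch below partitions a flat list of device ids into disjoint two-dimensional submeshes, one per pipeline stage, so that each stage's subcomputation can have its own data- and tensor-parallel axes. The helper name split_into_submeshes and its dp/tp parameters are hypothetical and only illustrate the concept. The actual MultiMesh API for defining submeshes and tasks is described in the user documentation.

```python
from typing import List


def split_into_submeshes(device_ids: List[int], num_stages: int,
                         dp: int, tp: int) -> List[List[List[int]]]:
    """Hypothetical helper (not a MultiMesh API): partition devices into
    num_stages disjoint (dp x tp) submeshes, one per pipeline stage."""
    per_stage = dp * tp
    if len(device_ids) != num_stages * per_stage:
        raise ValueError(
            f"need {num_stages * per_stage} devices, got {len(device_ids)}")
    submeshes = []
    for s in range(num_stages):
        # Contiguous block of devices owned exclusively by stage s.
        chunk = device_ids[s * per_stage:(s + 1) * per_stage]
        # Arrange the block row-major as a (dp, tp) grid of device ids.
        submeshes.append([chunk[r * tp:(r + 1) * tp] for r in range(dp)])
    return submeshes


# 8 GPUs -> 2 pipeline stages, each a 2 x 2 (data x tensor) submesh.
stages = split_into_submeshes(list(range(8)), num_stages=2, dp=2, tp=2)
print(stages)  # [[[0, 1], [2, 3]], [[4, 5], [6, 7]]]
```

Because the submeshes are disjoint, the output of stage 0 has to move from devices 0-3 to devices 4-7 before stage 1 can consume it; this is the cross-submesh resharding that MultiMesh performs automatically inside the global jit.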
Standard JAX SPMD sharding idioms can be used within each task, enabling full N-dimensional parallelism.
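As an illustration of those standard idioms, the sketch below uses only public JAX sharding APIs (jax.sharding.Mesh, NamedSharding, PartitionSpec) to shard a batch across a one-dimensional device mesh and run a jitted layer on it. It does not use MultiMesh's task API, whose names are not given here; the axis name "data", the shapes, and the layer function are illustrative assumptions. It also runs on a single CPU device.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Plain JAX (not MultiMesh): a 1D mesh over all visible devices.
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=("data",))

# Shard the batch dimension of x over the "data" axis; replicate w.
x = jax.device_put(jnp.ones((8 * len(devices), 4)),
                   NamedSharding(mesh, P("data", None)))
w = jax.device_put(jnp.ones((4, 2)), NamedSharding(mesh, P()))


@jax.jit
def layer(x, w):
    # The compiler propagates x's batch sharding to the output.
    return jnp.tanh(x @ w)


y = layer(x, w)
print(y.shape, y.sharding)
```

Inside a MultiMesh task, the same kind of sharding annotations would presumably refer to that task's submesh rather than to all devices; see the documentation for the exact API.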
This repository contains a PjRt plugin and Python helper APIs.
The easiest way to get started is to use the prebuilt container:
$ docker pull ghcr.io/nv-legate/multimesh-jax:v0.2
Containers can also be built using the MultiMesh for JAX workflows.
User documentation, including an API reference, Jupyter tutorials, and an architecture overview, is available on the docs page.
The recommended way to run the examples is through Docker. To launch a Jupyter notebook for running on CPU in the container, run the following from the repository root so that docs/notebooks is mounted at /opt/notebooks:
docker run \
--mount type=bind,source=$(pwd)/docs/notebooks,target=/opt/notebooks \
-w /opt/notebooks \
-p 8675:8675 \
ghcr.io/nv-legate/multimesh-jax:v0.2 \
jupyter notebook --allow-root --ip 0.0.0.0 --port=8675
The notebook will then be available at the URL printed by Jupyter, which is usually http://127.0.0.1:8675/... or http://localhost:8675/...
If GPUs are available, Docker can instead be launched as:
docker run \
--mount type=bind,source=$(pwd)/docs/notebooks,target=/opt/notebooks \
-w /opt/notebooks \
-p 8675:8675 \
--gpus <N> \
ghcr.io/nv-legate/multimesh-jax:v0.2 \
jupyter notebook --allow-root --ip 0.0.0.0 --port=8675
where <N> is the number of GPUs to expose to the container (Docker also accepts --gpus all).
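Once the notebook is running, a quick sanity check (a sketch, not part of the MultiMesh tooling) is to ask JAX which devices it can see from a notebook cell. With --gpus <N>, one would expect N GPU devices to be listed; the exact platform name reported may depend on which PjRt backend is active.

```python
import jax

# Report the default backend and the devices JAX can see in this process.
print("backend:", jax.default_backend())
print("device count:", jax.device_count())
for d in jax.devices():
    print(d)
```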
The prebuilt container is built from a CUDA 12.8 base image on Ubuntu 22. For host driver and system compatibility requirements, refer to the CUDA Toolkit documentation.
The main framework integrated with MultiMesh for JAX is MaxText.
Because MaxText has many options for specifying models, running and configuring it can be challenging.
To help run transformer models, a helper script with a basic set of options
for configuring parallelism has been added to the MultiMesh for JAX workflows. Running
run.py --help
lists the full set of options.
Known limitations:
- When running with MPMD sharding, Orbax checkpoints may not work correctly because some processes will not have addressable shards.
- When running with external libraries such as TransformerEngine, parallelism is only valid with one process per GPU, not one process per node.
- The plugin is experimental and may not work correctly if used outside the tutorials or the documented MaxText examples.
In the future, standard JAX and jaxlib installations should be compatible with MultiMesh. For now, patches must be applied to JAX and jaxlib for them to work with the MultiMesh for JAX client.