MultiMesh for JAX provides a framework for creating task contexts within jitted computations, allowing different subcomputations to be placed on different GPU submeshes. These task computations can be combined inside a global jit, with data resharding across submeshes occurring automatically. MultiMesh therefore makes pipeline parallelism easy to express.
This repository provides a monorepo and associated workflows for creating a MaxText stack running with MultiMesh.
Prebuilt images are published to the MultiMesh GitHub container registry. The most recent release can be pulled with:
$ docker pull ghcr.io/nv-legate/multimesh-jax:v0.2
Images are built from a CUDA 12.8 base image on Ubuntu 22. For system compatibility, refer to the CUDA toolkit documentation.
Instructions for building images can be found here.
To build the image, numerous submodules need to be downloaded and, in some cases, patched.
To do so, run the ./bootstrap.sh script in the top folder.
We recommend using the build driver script.
For a full list of options, run build.py --help.
The most common build workflow will be:
multimesh-jax-workflows$ ./bootstrap.sh
multimesh-jax-workflows$ cd docker
multimesh-jax-workflows/docker$ ./build.py --tag <TAG> --upload --repo <REPO>
which builds an image named <REPO>:<TAG> and uploads it, assuming that <REPO> points to a valid container registry.
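The build step above can also be driven from a script. The following is a minimal Python sketch that assembles the build.py invocation and the image name it is expected to produce. The helper name build_command and the registry.example.com repository are hypothetical, and the sketch assumes only the --tag, --upload, and --repo flags shown above.

```python
from typing import List, Tuple


def build_command(repo: str, tag: str) -> Tuple[List[str], str]:
    """Return the argv for ./build.py and the image name <REPO>:<TAG>.

    Hypothetical helper: it only mirrors the flags shown in this README
    and does not inspect what build.py actually accepts.
    """
    if not repo or not tag:
        raise ValueError("both repo and tag must be non-empty")
    argv = ["./build.py", "--tag", tag, "--upload", "--repo", repo]
    return argv, f"{repo}:{tag}"


if __name__ == "__main__":
    # registry.example.com is a placeholder, not a real registry.
    argv, image = build_command("registry.example.com/multimesh-jax", "dev")
    print(" ".join(argv))
    print(image)
```

The returned list can be passed to subprocess.run(argv, cwd="docker", check=True) after ./bootstrap.sh has completed.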
The main framework integrated with MultiMesh for JAX is MaxText. A default Docker build will produce an image with a MaxText installation. Please see the README for instructions on running MaxText with the provided helper scripts.
The top-level multimesh-jax.code-workspace can be opened in VS Code, which should prompt you to Reopen in container. Select this option.
This will load VS Code in a new devcontainer with all base dependencies installed.
Startup scripts will then configure all builds and execute an initial build
of the environment. The startup scripts point Bazel and CMake to build
caches on your local system.
- The initial base image download may take a long time.
- The startup scripts may take anywhere from a few minutes to an hour, depending on how much of the build is available in the build cache.
Once the workspace is open, it contains a complete VS Code development environment.
It is highly recommended to set up a remote Bazel cache that is shared across containers. There may be an initial warmup in the first devcontainer, but subsequent builds should be quick. It should be sufficient to run:
multimesh-jax-workflows/docker $ ./start-cache.sh
A script for running MaxText is included to simplify the process of tuning parameters. A full list of options can be displayed by running:
multimesh-jax-workflows $ maxtext/run.py --help
There are several options for configuring the parallelism:
- --dp <N>: The degree of data parallelism
- --tp <N>: The degree of tensor parallelism
- --fsdp <N>: The amount of fully-sharded data parallelism
- --pp <N>: The amount of pipeline parallelism
Currently, it is recommended to run in GPU/process mode. In this case, parameters should be:
- --nodes <N>: The total number of processes (GPUs with proc/GPU). The product of DP x FSDP x PP x TP should match this number.
- --gpus 1: A single GPU per process.
- --cpus 1: A companion host for each GPU
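The constraint on --nodes can be checked before launching a job. The following is a minimal Python sketch, assuming the flags listed above and one GPU per process. The helper name run_args is hypothetical and is not part of the repository.

```python
from typing import List


def run_args(dp: int, fsdp: int, pp: int, tp: int) -> List[str]:
    """Build maxtext/run.py arguments for one-GPU-per-process mode.

    Hypothetical helper: it only uses flags documented in this README.
    """
    for name, value in (("dp", dp), ("fsdp", fsdp), ("pp", pp), ("tp", tp)):
        if value < 1:
            raise ValueError(f"--{name} must be a positive integer, got {value}")
    # With one GPU per process, --nodes must equal DP x FSDP x PP x TP.
    nodes = dp * fsdp * pp * tp
    return [
        "maxtext/run.py",
        "--dp", str(dp),
        "--fsdp", str(fsdp),
        "--pp", str(pp),
        "--tp", str(tp),
        "--nodes", str(nodes),
        "--gpus", "1",
        "--cpus", "1",
    ]


if __name__ == "__main__":
    # DP=2, PP=2, TP=2 without FSDP requires 8 processes.
    print(" ".join(run_args(dp=2, fsdp=1, pp=2, tp=2)))
```

Deriving --nodes from the four degrees, rather than passing it separately, avoids a mismatch between the process count and the parallelism layout.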
For configuring pipeline parallelism, the most important parameters are:
- --batch-size <N>: The global batch size. Most often, this will be 2x or 4x the total number of GPUs.
- --microbatch-size <N>: The microbatch size per tensor-parallel domain. Each domain will compute on a batch of shape (Microbatch Size, Sequence Length). For a batch size of 1024, data parallelism 8, and microbatch size 4, each data-parallel replica handles 1024 / 8 = 128 samples, for a total of 128 / 4 = 32 microbatches.
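The microbatch count in the example above follows from simple arithmetic. The following is a minimal Python sketch, assuming the global batch is divided evenly across data-parallel replicas only and then split into microbatches. The function name is hypothetical, and how FSDP affects the split is not covered here.

```python
def microbatches_per_replica(batch_size: int, dp: int, microbatch_size: int) -> int:
    """Number of microbatches each data-parallel replica processes.

    Assumption: the global batch is split evenly over `dp` replicas,
    and each replica's share is split into equal microbatches.
    """
    if min(batch_size, dp, microbatch_size) < 1:
        raise ValueError("all arguments must be positive integers")
    if batch_size % dp != 0:
        raise ValueError("batch size must be divisible by the data-parallel degree")
    per_replica = batch_size // dp
    if per_replica % microbatch_size != 0:
        raise ValueError("per-replica batch must be divisible by the microbatch size")
    return per_replica // microbatch_size


if __name__ == "__main__":
    # Batch size 1024, data parallelism 8, microbatch size 4.
    print(microbatches_per_replica(1024, 8, 4))  # prints 32
```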
Example scripts for running a job locally in the container are included in the repo.
A script for running a small job with TP=2 is included. The container can be launched from the top-level directory as:
docker run \
--entrypoint /opt/entrypoint.sh \
--mount type=bind,source=$(pwd)/docker/workspace/maxtext-scripts,target=/workspace \
-w /workspace \
--gpus 2 \
ghcr.io/nv-legate/multimesh-jax:v0.2 \
./validate-gpu.sh
A script for running a small job with DP=2, PP=2, TP=2 is included. To launch the job:
docker run \
--entrypoint /opt/entrypoint.sh \
--mount type=bind,source=$(pwd)/docker/workspace/maxtext-scripts,target=/workspace \
-w /workspace \
ghcr.io/nv-legate/multimesh-jax:v0.2 \
./validate-cpu.sh
Scripts are included that are intended to be used inside a Slurm job, e.g.
srun -N <N> ... run-dgx-h100.sh
Scripts are included for:
Each folder contains a template Slurm batch script for launching containers.