MultiMesh for JAX Workflows

MultiMesh for JAX provides a framework for creating task contexts within jitted computations, allowing different subcomputations to be placed on different GPU submeshes. These task computations can be combined inside a global jit with data resharding across submeshes occurring automatically. MultiMesh therefore enables pipeline parallelism to be easily expressed. This repository provides a monorepo and associated workflows for creating a MaxText stack running with MultiMesh.

Prebuilt Images

Prebuilt images are published to the MultiMesh GitHub container registry. The most recent release can be pulled:

$ docker pull ghcr.io/nv-legate/multimesh-jax:v0.2

Requirements

Images are built from a CUDA 12.8 base image on Ubuntu 22. For system compatibility, refer to the CUDA toolkit documentation.

Docker Builds

Instructions for building images can be found here. Building the image requires downloading and, in some cases, patching numerous submodules; to do so, run the ./bootstrap.sh script in the top-level folder. We recommend using the build.py driver script. For a full list of options, run build.py --help.

The most common build workflow will be:

multimesh-jax-workflows$ ./bootstrap.sh
multimesh-jax-workflows$ cd docker
multimesh-jax-workflows/docker$ ./build.py --tag <TAG> --upload --repo <REPO>

which builds an image named <REPO>:<TAG> and uploads it, assuming that <REPO> points to a valid container registry.

MaxText Configs

The main framework integrated with MultiMesh for JAX is MaxText. A default Docker build produces an image with a MaxText installation. Please see the README for instructions on running MaxText with the provided helper scripts.

Open as Devcontainer

The top-level multimesh-jax.code-workspace can be opened in VS Code, which should prompt you to "Reopen in Container"; select this option. VS Code will then load in a new devcontainer with all base dependencies installed. Startup scripts then configure all builds and execute an initial build of the environment. The startup scripts point Bazel and CMake to build caches on your local system.

  • The initial base image download may take a long time.
  • The startup scripts may take anywhere from a few minutes to an hour, depending on how much of the build is available in the build cache.

Once the workspace is open, it contains a complete VS Code development environment.

Remote cache

It is highly recommended to set up a remote Bazel cache shared across containers. There may be an initial warm-up in the first devcontainer, but subsequent builds should be quick. It should be sufficient to run:

multimesh-jax-workflows/docker $ ./start-cache.sh
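
The cache can then be referenced from Bazel. A minimal sketch of the client-side configuration, assuming start-cache.sh serves an HTTP cache on localhost port 9090 (the actual endpoint is determined by the script):

```shell
# Hypothetical .bazelrc fragment; the port is an assumption --
# substitute whatever endpoint start-cache.sh actually reports.
build --remote_cache=http://localhost:9090
build --remote_upload_local_results=true
```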

Driver script for MaxText

A script for running MaxText is included to simplify the process of tuning parameters. A full list of options can be obtained by running:

multimesh-jax-workflows $ maxtext/run.py --help

There are several options for configuring the parallelism:

  • --dp <N>: The degree of data parallelism
  • --tp <N>: The degree of tensor parallelism
  • --fsdp <N>: The degree of fully-sharded data parallelism
  • --pp <N>: The degree of pipeline parallelism

Currently, it is recommended to run with one process per GPU. In this case, the parameters should be:

  • --nodes <N>: The total number of processes (one per GPU). The product DP x FSDP x PP x TP should equal this number.
  • --gpus 1: A single GPU per process.
  • --cpus 1: A companion host CPU for each GPU.
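
The constraint on --nodes can be checked directly. A quick sketch with hypothetical degrees (DP=2, FSDP=1, PP=2, TP=2):

```shell
# Hypothetical parallelism degrees; the constraint is
# DP * FSDP * PP * TP == number of processes (--nodes).
DP=2 FSDP=1 PP=2 TP=2
NODES=$(( DP * FSDP * PP * TP ))
echo "$NODES"   # the value to pass as --nodes; here, 8
```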

For configuring pipeline parallelism, the most important parameters are:

  • --batch-size <N>: The global batch size. Most often, this will be 2x or 4x the total number of GPUs.
  • --microbatch-size <N>: The microbatch size per tensor-parallel domain. Each domain computes on a batch of shape (Microbatch Size, Sequence Length). For a batch size of 1024, data parallelism of 8, and a microbatch size of 4, there will be a total of 32 microbatches.
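
The microbatch count in that example follows from the values in the text above:

```shell
# Values from the example above: batch size 1024, data parallelism 8,
# microbatch size 4.
BATCH_SIZE=1024 DP=8 MICROBATCH_SIZE=4
PER_DOMAIN=$(( BATCH_SIZE / DP ))                    # 128 examples per data-parallel domain
NUM_MICROBATCHES=$(( PER_DOMAIN / MICROBATCH_SIZE ))
echo "$NUM_MICROBATCHES"   # prints 32
```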

Running smoke tests locally for MaxText

Example scripts are included in the repo for running a job locally in the container.

MaxText with 2 GPUs

A script for running a small job with TP=2 is included. The container can be launched from the top-level directory as:

docker run \
  --entrypoint /opt/entrypoint.sh \
  --mount type=bind,source=$(pwd)/docker/workspace/maxtext-scripts,target=/workspace \
  -w /workspace \
  --gpus 2 \
  ghcr.io/nv-legate/multimesh-jax:v0.2 \
  ./validate-gpu.sh

MaxText with 8 CPUs

A script for running a small job with DP=2, PP=2, TP=2 is included. To launch the job:

docker run \
  --entrypoint /opt/entrypoint.sh \
  --mount type=bind,source=$(pwd)/docker/workspace/maxtext-scripts,target=/workspace \
  -w /workspace \
  ghcr.io/nv-legate/multimesh-jax:v0.2 \
  ./validate-cpu.sh

Slurm scripts

Scripts are included that are intended to be used inside a Slurm job, e.g.

srun -N <N> ... run-dgx-h100.sh

Scripts are included for several system types; each folder contains a template Slurm batch script for launching containers.
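
The general shape of such a batch script, sketched with placeholder node and GPU counts (the actual directives belong in the per-system templates):

```shell
#!/bin/bash
# Hypothetical sbatch wrapper; adjust counts for your cluster.
#SBATCH --nodes=2
#SBATCH --gpus-per-node=8
srun run-dgx-h100.sh
```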
