mpi4jax/mpibackend4jax


mpibackend4jax

Warning

This is an early experimental package. Feedback wanted!

A Python package that provides an easy-to-use MPI backend for JAX sharding, built on top of MPIWrapper. No special operations are needed: your code stays 100% native JAX.

Installation

# Using uv (recommended)
uv add git+https://github.com/mpi4jax/mpibackend4jax

# Using pip
pip install git+https://github.com/mpi4jax/mpibackend4jax

Usage

Simply import the package before using JAX with MPI:

import mpibackend4jax as _mpi4jax  # noqa: F401
import jax

print("Setup initialize", flush=True)
jax.distributed.initialize()
print(f"{jax.process_index()}/{jax.process_count()} :", jax.local_devices())
print(f"{jax.process_index()}/{jax.process_count()} :", jax.devices())

x = jax.numpy.ones(
    (jax.device_count(),),
    device=jax.sharding.NamedSharding(
        jax.sharding.Mesh(jax.devices(), "i"), jax.sharding.PartitionSpec("i")
    ),
)

print(f"{jax.process_index()}/{jax.process_count()} :", x.sum())

Run with MPI:

mpirun -np 2 python examples/example.py
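With `-np 2`, and assuming one CPU device per process (JAX's default on CPU), `jax.device_count()` is 2. The array of ones therefore has length 2, and every process should print a total of 2. Conceptually, each process sums its local shard and the partial sums are combined across processes with an MPI all-reduce. The pure-Python model below illustrates that idea only. It is not how XLA actually implements the reduction, and the helper names are hypothetical:

```python
from typing import List


def shard_contiguous(values: List[float], num_shards: int) -> List[List[float]]:
    """Split a 1-D array into equal contiguous blocks, one per device,
    mirroring PartitionSpec("i") over a 1-D mesh."""
    if len(values) % num_shards != 0:
        raise ValueError("array length must be divisible by the number of shards")
    size = len(values) // num_shards
    return [values[k * size:(k + 1) * size] for k in range(num_shards)]


def allreduce_sum(partials: List[float]) -> List[float]:
    """Model of an MPI all-reduce: every participant receives the same total."""
    total = sum(partials)
    return [total for _ in partials]


def simulated_global_sum(num_devices: int) -> List[float]:
    """Sum a sharded array of ones; returns the value each device ends up with."""
    x = [1.0] * num_devices                    # like jnp.ones((device_count,))
    shards = shard_contiguous(x, num_devices)  # one shard per device
    partials = [sum(s) for s in shards]        # local reduction on each device
    return allreduce_sum(partials)             # cross-process combination


if __name__ == "__main__":
    print(simulated_global_sum(2))
```

Every entry of the result is the same global total, which matches the fact that all ranks print the same value.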

What it does

When you import mpibackend4jax, it automatically:

  1. Sets the MPITRAMPOLINE_LIB environment variable to point to the built libmpiwrapper.so
  2. Sets the environment variable JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi
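The two steps above amount to setting environment variables before JAX starts. The sketch below does the same thing by hand and may be useful for debugging. The helper name `configure_mpi_backend` and the library path are hypothetical; they are not part of the mpibackend4jax API, and importing the package already does this for you:

```python
import os


def configure_mpi_backend(wrapper_lib: str) -> None:
    """Hypothetical helper mirroring what importing mpibackend4jax does.

    Must run before `import jax`, because JAX reads its configuration
    environment variables when it is first imported.
    """
    # Tell MPItrampoline which MPIWrapper build to load.
    os.environ["MPITRAMPOLINE_LIB"] = wrapper_lib
    # Ask JAX's CPU backend to use MPI for cross-process collectives.
    os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi"


if __name__ == "__main__":
    # Placeholder path: point this at your own libmpiwrapper.so build.
    configure_mpi_backend("/path/to/libmpiwrapper.so")
    print(os.environ["JAX_CPU_COLLECTIVES_IMPLEMENTATION"])
```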

Requirements

  • CMake (for building MPIWrapper)
  • A working MPI implementation (e.g., OpenMPI, MPICH)
  • JAX
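Before installing, you may want to confirm the requirements above are present. The script below is a rough preflight check. It assumes your MPI implementation puts an `mpirun` or `mpiexec` launcher on `PATH` (true for typical OpenMPI and MPICH installs), and `check_requirements` is a hypothetical name, not something the package provides:

```python
import importlib.util
import shutil
from typing import Dict


def check_requirements() -> Dict[str, bool]:
    """Report whether each prerequisite appears to be available."""
    return {
        # CMake is needed to build MPIWrapper.
        "cmake": shutil.which("cmake") is not None,
        # Assumption: a working MPI exposes one of these launchers on PATH.
        "mpi_launcher": (
            shutil.which("mpirun") is not None or shutil.which("mpiexec") is not None
        ),
        # JAX only needs to be importable; this does not import it.
        "jax": importlib.util.find_spec("jax") is not None,
    }


if __name__ == "__main__":
    for name, ok in check_requirements().items():
        print(f"{name:13s} {'found' if ok else 'MISSING'}")
```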

Tested on macOS with MPICH.

Verification

You can check whether MPITrampoline is properly configured:

import mpibackend4jax

if mpibackend4jax.is_configured():
    print("MPITrampoline is properly configured!")
else:
    print("MPITrampoline configuration failed.")
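If the check fails, it can help to inspect the environment directly. The function below is a hand-written diagnostic that looks at the two variables listed under "What it does". It is an assumption-based sketch with a hypothetical name and does not reproduce the actual logic of `is_configured()`:

```python
import os
from typing import List


def mpi_env_problems() -> List[str]:
    """Return human-readable problems with the MPI-related environment."""
    problems = []
    lib = os.environ.get("MPITRAMPOLINE_LIB")
    if not lib:
        problems.append("MPITRAMPOLINE_LIB is not set")
    elif not os.path.isfile(lib):
        problems.append(f"MPITRAMPOLINE_LIB points to a missing file: {lib}")
    impl = os.environ.get("JAX_CPU_COLLECTIVES_IMPLEMENTATION")
    if impl != "mpi":
        problems.append(
            f"JAX_CPU_COLLECTIVES_IMPLEMENTATION is {impl!r}, expected 'mpi'"
        )
    return problems


if __name__ == "__main__":
    for problem in mpi_env_problems() or ["no problems found"]:
        print(problem)
```

An empty list means both variables look sane; it does not prove that MPI itself works, which only a run under `mpirun` can show.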

Acknowledgments

Special thanks to @inailuig (Clemens Giuliani) for adding MPI support in XLA, which makes this integration possible.

About

Easily use MPI as a backend for JAX native sharding
