Warning: this is an early, experimental package. Feedback wanted!
A Python package that provides an easy-to-use MPI backend for JAX sharding, built on top of MPIWrapper. No special operations: 100% native JAX.
```bash
# Using uv (recommended)
uv add git+https://github.com/mpi4jax/mpibackend4jax

# Using pip
pip install git+https://github.com/mpi4jax/mpibackend4jax
```
Simply import the package before using JAX with MPI:
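```python
# Imported before jax so its MPI settings are in place when JAX loads.
import mpibackend4jax as _mpi4jax  # noqa: F401
import jax

print("Initializing jax.distributed", flush=True)
jax.distributed.initialize()

# Devices attached to this process vs. all devices across every process
print(f"{jax.process_index()}/{jax.process_count()} :", jax.local_devices())
print(f"{jax.process_index()}/{jax.process_count()} :", jax.devices())

# One element per device, sharded over a 1-D mesh with axis name "i"
x = jax.numpy.ones(
    (jax.device_count(),),
    device=jax.sharding.NamedSharding(
        jax.sharding.Mesh(jax.devices(), "i"), jax.sharding.PartitionSpec("i")
    ),
)

# The sum reduces over all shards, so every process prints jax.device_count()
print(f"{jax.process_index()}/{jax.process_count()} :", x.sum())
```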
Run with MPI:
```bash
mpirun -np 2 python examples/example.py
```
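`mpirun -np 2` starts two copies of the same script. When debugging a launch, it can help to confirm that each copy received its own rank before JAX is involved at all. The helper below is a hypothetical sketch, not part of mpibackend4jax; it assumes the launcher exports Open MPI's `OMPI_COMM_WORLD_RANK`/`OMPI_COMM_WORLD_SIZE` or the MPICH (Hydra) `PMI_RANK`/`PMI_SIZE` variables, and other launchers may use different names.

```python
import os


def launcher_rank_and_size(env=None):
    """Hypothetical helper: read (rank, size) from MPI launcher environment variables.

    Assumes Open MPI (OMPI_COMM_WORLD_*) or MPICH's Hydra launcher (PMI_*).
    Returns None when neither pair is present, e.g. when the script was
    started without mpirun.
    """
    env = os.environ if env is None else env
    for rank_key, size_key in (
        ("OMPI_COMM_WORLD_RANK", "OMPI_COMM_WORLD_SIZE"),  # Open MPI
        ("PMI_RANK", "PMI_SIZE"),  # MPICH (Hydra)
    ):
        if rank_key in env and size_key in env:
            return int(env[rank_key]), int(env[size_key])
    return None


if __name__ == "__main__":
    print("launcher rank/size:", launcher_rank_and_size())
```

With `-np 2`, the two processes should report ranks 0 and 1 with a size of 2; if they instead report `None` or a size of 1, the problem lies in the launch rather than in JAX.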
When you import `mpibackend4jax`, it automatically:
- Sets `MPITRAMPOLINE_LIB` to point to the built `libmpiwrapper.so`
- Sets `JAX_CPU_COLLECTIVES_IMPLEMENTATION=mpi`
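Both settings are plain environment variables, so they only take effect if they are in place before JAX reads its configuration, which is why the package is imported first. The following is a minimal sketch of the idea, assuming the wrapper library has already been built; the function name `configure_mpi_env` and the library path are hypothetical, and this is not the package's actual code.

```python
import os


def configure_mpi_env(wrapper_lib, env=None):
    """Hypothetical sketch of the two settings listed above.

    wrapper_lib: path to a built libmpiwrapper.so (assumed to exist).
    env: mapping to modify; defaults to the real process environment.
    """
    env = os.environ if env is None else env
    # MPItrampoline loads this wrapper, which forwards calls to the system MPI
    env["MPITRAMPOLINE_LIB"] = wrapper_lib
    # Ask JAX's CPU backend to use MPI for cross-process collectives
    env["JAX_CPU_COLLECTIVES_IMPLEMENTATION"] = "mpi"
    return env
```

Passing a plain dict, e.g. `configure_mpi_env("/path/to/libmpiwrapper.so", {})`, lets you inspect the result without touching the real environment.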
Requirements:
- CMake (for building MPIWrapper)
- A working MPI implementation (e.g., OpenMPI, MPICH)
- JAX
Tested on macOS with MPICH.
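Before installing, a quick way to confirm these prerequisites are visible is to look for the relevant executables on `PATH` and check that `jax` is importable. This is a hypothetical helper, not part of the package; treating `mpicc` and `mpirun` as markers of a working MPI is an assumption (MPICH and Open MPI both ship them, but finding them on `PATH` does not prove the installation works).

```python
import importlib.util
import shutil


def missing_prerequisites():
    """Hypothetical check: return the names of prerequisites that cannot be found."""
    missing = []
    # cmake builds MPIWrapper; mpicc/mpirun suggest an MPI installation is present
    for tool in ("cmake", "mpicc", "mpirun"):
        if shutil.which(tool) is None:
            missing.append(tool)
    # JAX only needs to be importable, so look it up without importing it
    if importlib.util.find_spec("jax") is None:
        missing.append("jax")
    return missing


if __name__ == "__main__":
    print("missing:", missing_prerequisites() or "nothing")
```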
You can check if MPITrampoline is properly configured:
```python
import mpibackend4jax

if mpibackend4jax.is_configured():
    print("MPITrampoline is properly configured!")
else:
    print("MPITrampoline configuration failed.")
```
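If `is_configured()` returns `False`, it can help to see which of the two settings described above is missing. This is a hypothetical diagnostic sketch: it only inspects the environment variables this README mentions and says nothing about how `is_configured()` itself is implemented.

```python
import os


def diagnose_mpi_env(env=None):
    """Hypothetical diagnostic: list problems with the MPI-related settings."""
    env = os.environ if env is None else env
    problems = []
    lib = env.get("MPITRAMPOLINE_LIB")
    if not lib:
        problems.append("MPITRAMPOLINE_LIB is not set")
    elif not os.path.isfile(lib):
        problems.append(f"MPITRAMPOLINE_LIB points to a missing file: {lib}")
    if env.get("JAX_CPU_COLLECTIVES_IMPLEMENTATION") != "mpi":
        problems.append("JAX_CPU_COLLECTIVES_IMPLEMENTATION is not 'mpi'")
    return problems
```

Call it right after `import mpibackend4jax`; an empty list means both variables are set and the wrapper library file exists.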
Special thanks to @inailuig (Clemens Giuliani) for adding MPI support in XLA, which makes this integration possible.