-
Hey @igorvere, my understanding is that JAX does not currently have user-friendly support for multi-host multi-GPU. There is lightly tested support in the runtime for distributed jobs that use NCCL for multi-host collective operations, but there are no Python APIs to configure this at the moment.
-
For anyone interested, mpi4jax allows you to run computations across multiple GPU hosts.
-
We ended up using NCCL directly (thanks for the unsafe pointer interface) to do distributed ops between GPU hosts. I can see there is a promising pull request, #8364, which could perhaps make pmap available for a multi-host GPU setup. That said, even on a single-host multi-GPU machine I see that collective operations are prohibitively slow compared with direct NCCL calls.
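For reference, here is a minimal sketch of the kind of single-host pmap collective being discussed: each local device computes a gradient on its own shard and `jax.lax.psum` sums them across devices. The toy `loss_fn`, the array shapes, and the axis name `"devices"` are assumptions made up for illustration; this does not cover the multi-host case and is not the code from PR #8364.

```python
import functools

import jax
import jax.numpy as jnp


def loss_fn(w, x, y):
    # Hypothetical toy loss: mean squared error of a linear model.
    pred = x @ w
    return jnp.mean((pred - y) ** 2)


@functools.partial(jax.pmap, axis_name="devices")
def summed_grads(w, x, y):
    # Each device computes the gradient on its own shard of the data...
    g = jax.grad(loss_fn)(w, x, y)
    # ...then the per-device gradients are summed across all devices.
    return jax.lax.psum(g, axis_name="devices")


n = jax.local_device_count()  # devices visible to this one process
key_x, key_y = jax.random.split(jax.random.PRNGKey(0))

w = jnp.ones((n, 3))                      # one copy of the weights per device
x = jax.random.normal(key_x, (n, 4, 3))   # one (4, 3) batch per device
y = jax.random.normal(key_y, (n, 4))      # one (4,) target vector per device

grads = summed_grads(w, x, y)  # shape (n, 3); every row holds the same sum
```

Timing `summed_grads` after a warm-up call (and calling `block_until_ready()` on the result) would be one way to compare this path against direct NCCL calls on the same machine.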
-
Hello, can anybody share sample code showing how to start JAX on multiple GPU nodes, so that pmap can be used for collective operations (say, summing up gradients from each node)?
Thanks
Igor