-
Hi, I tested the official tutorial for running distributed JAX with multiple GPUs. I am using Slurm, so I assume `jax.distributed.initialize()` takes no arguments.
It works fine, with an output of:
However, the same code does not seem to work for multiple CPUs across nodes.
Does anyone know whether this feature is not implemented yet, or whether I am using it wrong?
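A minimal sketch of this kind of test, assuming the tutorial is the `pmap`/`psum` example from the JAX multi-process documentation and the script is launched once per task with `srun`. Under Slurm, `jax.distributed.initialize()` with no arguments picks up the coordinator address, process count and process id from the `SLURM_*` environment. The `maybe_initialize` helper and its `SLURM_NTASKS` check are hypothetical additions so the same script also runs as a single process.

```python
import os

import jax
import jax.numpy as jnp


def maybe_initialize():
    # Hypothetical guard: only initialize when Slurm launched more than one task.
    # With no arguments, jax.distributed.initialize() auto-detects the Slurm
    # cluster settings from the SLURM_* environment variables.
    if int(os.environ.get("SLURM_NTASKS", "1")) > 1:
        jax.distributed.initialize()


maybe_initialize()

print("process", jax.process_index(), "of", jax.process_count())
print("global devices:", jax.device_count(), "local devices:", jax.local_device_count())

# One element per local device; psum reduces over every mapped device in the job,
# across all processes, so each entry should equal the global device count.
xs = jnp.ones(jax.local_device_count())
result = jax.pmap(lambda x: jax.lax.psum(x, "i"), axis_name="i")(xs)
print(result)
```

On a GPU cluster each process should print an array filled with the total number of GPUs in the job; the question above is whether the same works when the devices are CPUs on several nodes.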
-
This is not implemented for CPU at the moment. However, you may be able to use https://github.com/mpi4jax/mpi4jax as a workaround.
-
I am working on integrating JAX with the Kubeflow Training Operator to run distributed training and fine-tuning jobs on Kubernetes in CPU environments: https://www.kubeflow.org/events/gsoc-2024/#project-5-support-distributed-jax-for-training-operator
-
Thank you @PhilipVinc! I am running into the following error while running the example:
Can you please help me resolve it?