-
Do you know the current best practices for implementing multi-host data parallelism with JAX?
-
I want to use JAX in a multi-host, multi-process environment. I found the tutorial at "https://jax.readthedocs.io/en/latest/multi_process.html", but it doesn't seem to explain how to load different data on each host. Suppose I have 100 datasets, labeled "data_001" to "data_100", and 10 processes with 10 devices. I want the first process to load "data_001", "data_011", and so on, the second process to load "data_002", "data_012", and so on. Each process would then train the model on its own datasets, and the gradients would be averaged across processes at the end.
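To make the setup concrete, here is a minimal sketch of what I have in mind, not a confirmed best practice. The helper names `shard_names_for_process` and `make_train_step`, the axis name `"devices"`, and the plain SGD update are my own assumptions for illustration. The sketch assumes a round-robin split over the process index and uses `jax.pmap` with `jax.lax.pmean`, which averages gradients across all devices on every step rather than only once at the end.

```python
def shard_names_for_process(process_index, process_count, num_shards=100):
    """Return the dataset names assigned to one process (round-robin split).

    Hypothetical helper: the names follow the "data_001" ... "data_100"
    pattern from the question.
    """
    names = [f"data_{i:03d}" for i in range(1, num_shards + 1)]
    # Process k takes every process_count-th name, starting at offset k.
    return names[process_index::process_count]


def make_train_step(loss_fn, learning_rate=1e-2):
    """Build a pmapped SGD step that averages gradients over all devices.

    Hypothetical helper: loss_fn(params, batch) must return a scalar loss.
    """
    import jax  # imported here so the sharding helper above runs without JAX

    def step(params, batch):
        grads = jax.grad(loss_fn)(params, batch)
        # Average gradients over every device taking part in the pmap.
        grads = jax.lax.pmean(grads, axis_name="devices")
        return jax.tree_util.tree_map(
            lambda p, g: p - learning_rate * g, params, grads
        )

    return jax.pmap(step, axis_name="devices")


if __name__ == "__main__":
    # Single-machine illustration of the split for 10 processes.
    for k in range(3):
        print(k, shard_names_for_process(k, 10)[:3])
```

In a real multi-process run, each process would call `jax.distributed.initialize()` first and then pass `jax.process_index()` and `jax.process_count()` to `shard_names_for_process`, so that every host reads only its own files and feeds them to its local devices.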