Sharded data comes back to GPU_0 after pmap #27290
Unanswered
hyunwoooh5 asked this question in Q&A
Replies: 1 comment
I would suggest using … Here are some docs on how to do that:
https://docs.jax.dev/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html
https://docs.jax.dev/en/latest/notebooks/explicit-sharding.html
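A minimal sketch of data parallelism in the style of the linked docs, which use `jax.jit` with `NamedSharding` rather than `pmap`. It assumes a 1-D mesh over all devices, a batch size divisible by the device count, and a toy linear model. `loss_fn`, `train_step`, the axis name `"data"` and the learning rate are hypothetical placeholders and do not come from the original code.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical 1-D mesh over every available device; the axis name "data" is arbitrary.
mesh = Mesh(np.array(jax.devices()), axis_names=("data",))

# Split the batch along its leading axis; replicate the parameters on every device.
batch_sharding = NamedSharding(mesh, P("data"))
replicated = NamedSharding(mesh, P())

n = jax.device_count()
x = jax.device_put(jnp.ones((8 * n, 4)), batch_sharding)   # toy inputs
y = jax.device_put(jnp.zeros((8 * n,)), batch_sharding)    # toy targets
w = jax.device_put(jnp.zeros((4,)), replicated)            # toy parameters

def loss_fn(w, x, y):
    # Hypothetical mean-squared-error loss for a linear model.
    return jnp.mean((x @ w - y) ** 2)

def train_step(w, x, y):
    # The compiler inserts the cross-device reduction needed for the gradient.
    grads = jax.grad(loss_fn)(w, x, y)
    return w - 0.1 * grads

# out_shardings pins the updated parameters to a replicated layout.
train_step = jax.jit(train_step, out_shardings=replicated)

w = train_step(w, x, y)
print(x.sharding)  # batch stays split across the mesh
print(w.sharding)  # parameters stay replicated
```

Pinning `out_shardings` keeps the updated parameters replicated between steps instead of leaving their placement to the compiler, and printing `.sharding` at any point is a quick way to confirm the batch is still distributed.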
Hi. I'm trying to implement data parallelism with JAX, and I found that the entire training set ends up on GPU_0 after pmap.
Here is my code:
The data is sharded when I run create_sharded_batch(), but after training it all ends up on the same device.
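Because the training code itself is not shown, here is a hedged sketch of how one might check where arrays live after a `pmap` step. The `step` function is a hypothetical stand-in for the real training step. Every `jax.Array` exposes `.devices()` and `.addressable_shards`, so printing them at each stage can show where the data collapses onto a single device.

```python
import jax
import jax.numpy as jnp

n = jax.local_device_count()

# Hypothetical per-device step; pmap maps over the leading axis, one slice per device.
@jax.pmap
def step(x):
    return x * 2.0

# The leading axis must equal the number of local devices.
batch = jnp.arange(n * 3, dtype=jnp.float32).reshape(n, 3)
out = step(batch)

# Inspect placement: the output of pmap is distributed, one shard per device.
print(out.sharding)
print(sorted(d.id for d in out.devices()))
for shard in out.addressable_shards:
    print(shard.device)
```

If arrays are spread across devices right after the pmapped call but sit on a single device later, that would suggest the collapse happens in code that runs outside the pmapped function.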
Question is