Pmap always slower than sharding? #18188
Unanswered
IrishWhiskey asked this question in Q&A
Replies: 1 comment
-
I would recommend using sharding. If you want to write manual collectives, then you can use shard_map: https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html
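A minimal shard_map sketch along the lines of that doc (the psum reduction and single mesh axis here are assumptions for illustration, not code from this thread):

```python
# Each device reduces its own shard locally, then an explicit psum
# collective combines the partial results across the mesh axis.
# Assumes the number of devices divides the array length.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P
from jax.experimental.shard_map import shard_map

mesh = Mesh(jax.devices(), axis_names=("i",))

def local_sum(shard):
    # Runs per device on its shard; psum is the manual collective.
    return jax.lax.psum(shard.sum(), axis_name="i")

total = jax.jit(
    shard_map(local_sum, mesh=mesh, in_specs=P("i"), out_specs=P())
)(jnp.arange(8.0))
print(total)  # 28.0, replicated on every device
```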
-
I have used pmap multiple times to parallelize JAX models, but it seems to make the code slower, while sharding consistently makes it faster. What is the reason? I would have expected sharding to use pmap under the hood.
I created an example to demonstrate this phenomenon.
Consider the following piece of code:
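A minimal sketch of the setup, assuming a small two-layer MLP on synthetic data (the exact code is in the attached notebook):

```python
import jax
import jax.numpy as jnp

k1, k2, kx = jax.random.split(jax.random.PRNGKey(0), 3)
params = {
    "w1": jax.random.normal(k1, (1024, 1024)),
    "w2": jax.random.normal(k2, (1024, 1)),
}
data = jax.random.normal(kx, (8192, 1024))

@jax.jit
def loss(params, data):
    # Tiny MLP regression loss; stands in for the real model.
    hidden = jnp.tanh(data @ params["w1"])
    preds = hidden @ params["w2"]
    return jnp.mean(preds ** 2)

loss(params, data).block_until_ready()  # warm up / compile before timing
```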
After running this, I measured the execution time of loss(params, data) and got ~5 ms.
I tried to parallelize the model using pmap and found the time to be ~40 ms:
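A sketch of the pmap attempt, assuming the batch is split across local devices and the per-device losses are averaged with pmean (it continues from the snippet above):

```python
import functools

n_dev = jax.local_device_count()
# Add a leading device axis to the batch and replicate the params
# onto every device, as pmap expects.
pdata = data.reshape(n_dev, -1, data.shape[-1])
pparams = jax.device_put_replicated(params, jax.local_devices())

@functools.partial(jax.pmap, axis_name="batch")
def ploss(params, data):
    hidden = jnp.tanh(data @ params["w1"])
    preds = hidden @ params["w2"]
    # Average the per-device losses with an explicit collective.
    return jax.lax.pmean(jnp.mean(preds ** 2), axis_name="batch")

ploss(pparams, pdata).block_until_ready()  # ~40 ms per call in my runs
```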
I then tried to shard the input data instead and got an execution time of ~2 ms:
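The sharded version, sketched under the assumption that the batch axis of the input is placed on a 1-D device mesh while the same jitted loss is reused:

```python
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Continues from the first snippet: shard the batch axis across devices
# and let jit partition the computation automatically.
mesh = Mesh(jax.devices(), axis_names=("batch",))
sdata = jax.device_put(data, NamedSharding(mesh, P("batch", None)))

loss(params, sdata).block_until_ready()  # ~2 ms per call in my runs
```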
Does my code contain a bug, or is pmap really slower than sharding? Moreover, in this case the pmapped function seems to be much slower even than the original, unparallelized one.
I ran these tests on an AWS SageMaker Notebook instance of type ml.p3.8xlarge, using the tensorflow2_p310 kernel. I installed JAX by running:
pip install -U "jax[cuda]==v0.4.18" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
I have attached the Jupyter notebook I used.