How does jax implement jax.lax.psum and jax.lax.all_to_all? #18018
anirudhitagi asked this question in Q&A
Hello all,
I am new to JAX and am exploring collective communications, so I would like to understand how JAX implements `jax.lax.psum`, `jax.lax.all_gather`, and `jax.lax.all_to_all`. Are they implemented as ring, binary-tree, halving-doubling, or something else?
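For reference, here is a minimal sketch of what each of the three primitives computes. It shows their semantics, not the communication algorithm underneath. It assumes a recent JAX install and simulates four devices on the host CPU with the `--xla_force_host_platform_device_count` XLA flag instead of a TPU v3 slice. The function name `collectives` and the axis name `"i"` are arbitrary choices for illustration.

```python
import os

# Assumption: simulate 4 devices on the host CPU instead of a TPU v3 slice.
# The flag must be set before JAX initializes its backends.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import numpy as np

n = 4
devices = jax.devices("cpu")[:n]


def collectives(x):
    # Inside pmap, x is the shard held by one device, with shape (n,).
    # psum: every device receives the elementwise sum of all shards.
    total = jax.lax.psum(x, axis_name="i")
    # all_gather: every device receives all shards stacked, shape (n, n).
    gathered = jax.lax.all_gather(x, axis_name="i")
    # all_to_all (tiled=False requires x.shape[split_axis] == n): device d sends
    # element j of its shard to device j, so device j ends up with element j
    # from every device, which amounts to a global transpose.
    swapped = jax.lax.all_to_all(x, axis_name="i", split_axis=0, concat_axis=0)
    return total, gathered, swapped


f = jax.pmap(collectives, axis_name="i", devices=devices)
x = np.arange(n * n, dtype=np.float32).reshape(n, n)  # row d is placed on device d
total, gathered, swapped = f(x)

print(total)     # every row equals x.sum(axis=0)
print(gathered)  # shape (n, n, n); gathered[d] equals x for every d
print(swapped)   # equals x.T
```

The simulated-CPU setup only exercises the semantics. The performance characteristics of the real collectives on TPU interconnect will be different.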
Does the implementation also depend on the hardware it is deployed on? For context, I am experimenting with TPU v3s.
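One way to probe the hardware question is to inspect what JAX hands to XLA. The sketch below uses the same four-simulated-CPU-device assumption, and `global_sum` is a hypothetical helper name. It lowers a `psum` under `jax.pmap` and prints two things: the module JAX emits (StableHLO in recent versions), where `psum` appears as an `all_reduce` op, and the HLO that XLA compiles for the current platform. The choice of ring, tree, or another algorithm is not expressed at the JAX level. As far as I can tell, that choice is made inside the XLA backend and runtime for the target hardware (for example, through NCCL on NVIDIA GPUs). Running the same script on a TPU v3 host without the flag would show how the compiled HLO differs.

```python
import os

# Assumption: 4 simulated CPU devices; on a real TPU v3 host, drop this flag.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp

n = 4
devices = jax.devices("cpu")[:n]


def global_sum(x):
    # Hypothetical helper: sum each device's shard across the "i" axis.
    return jax.lax.psum(x, axis_name="i")


f = jax.pmap(global_sum, axis_name="i", devices=devices)
x = jnp.ones((n, 8), dtype=jnp.float32)

lowered = f.lower(x)
# Module emitted by JAX before XLA compiles it; psum shows up as an all_reduce op.
lowered_text = lowered.as_text()
# HLO after XLA's compilation passes for this specific platform.
compiled_text = lowered.compile().as_text()

print(devices[0].platform)
print(compiled_text)

out = f(x)  # every element is n, since each of the n devices contributes 1.0
```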
Any documentation or explanation would be much appreciated!
Thanks and Regards,
AI