Implementation plans for pallas.dynamic_slice and scatter_reduce ops #25281
Unanswered · olivier-peltre asked this question in Q&A
When trying to execute a Pallas kernel that calls `pl.dslice(start, size)` on any accelerator (GPU/TPU), I get a `NotImplementedError` (jaxlib == 0.4.34). Is support for `pl.dslice` coming soon?

Use case:
`scatter_reduce` ops. Any suggestions on how to move forward?

I noticed that `scatter_add` scales very badly, and I never managed to get `indices_are_sorted=True` to produce a significant difference (in previous attempts, I think I got `indices_are_sorted=False` in the compiled jaxpr even when passing it as a keyword). I will now try comparing with torch + rusty1s/pytorch_scatter to get an idea of the gains I could possibly hope for.
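One way to sanity-check whether the flag survives tracing (a sketch using only public JAX APIs, not from the original report) is to print the jaxpr of an `.at[].add` scatter and inspect the scatter-add equation's parameters:

```python
import jax
import jax.numpy as jnp

def scatter_add(values, idx):
    # indices_are_sorted is only a hint; check the jaxpr to see if it sticks.
    return jnp.zeros(4, values.dtype).at[idx].add(
        values, indices_are_sorted=True, unique_indices=False)

vals = jnp.array([0.0, 1.0, 2.0, 3.0])
idx = jnp.array([0, 0, 1, 3])
print(scatter_add(vals, idx))                  # [1. 2. 0. 3.]
# The scatter-add equation in the jaxpr lists its parameters,
# including indices_are_sorted.
print(jax.make_jaxpr(scatter_add)(vals, idx))
```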
N.B. I'm looking for an efficient way to aggregate values based on a static index array, though I understand there are many constraints that may prevent very efficient dynamic scatter-reduce ops in XLA.
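For aggregation over a static index array specifically, `jax.ops.segment_sum` exposes the sortedness hint and the (static) number of output segments directly; a minimal sketch:

```python
import jax
import jax.numpy as jnp

data = jnp.array([1.0, 2.0, 3.0, 4.0])
segment_ids = jnp.array([0, 0, 1, 2])  # static, sorted index array

# num_segments must be a static int so the output shape is known at trace time.
out = jax.ops.segment_sum(data, segment_ids, num_segments=3,
                          indices_are_sorted=True)
print(out)  # [3. 3. 4.]
```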
MWE:
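A minimal sketch of the kind of kernel involved, assuming the public Pallas API (`pl.pallas_call`, `pl.load`, `pl.dslice`); `interpret=True` evaluates it on CPU, whereas compiling for GPU/TPU is where the `NotImplementedError` was reported:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def take_window_kernel(start_ref, x_ref, o_ref):
    # Read the runtime start index, then load a dynamic 4-element window.
    start = start_ref[0]
    o_ref[...] = pl.load(x_ref, (pl.dslice(start, 4),))

def take_window(start, x):
    # interpret=True runs the kernel on CPU for illustration.
    return pl.pallas_call(
        take_window_kernel,
        out_shape=jax.ShapeDtypeStruct((4,), x.dtype),
        interpret=True,
    )(start, x)

x = jnp.arange(10, dtype=jnp.float32)
print(take_window(jnp.array([3]), x))  # the window x[3:7]
```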