Does lax.fori_loop run in parallel on GPU?
#18054
Hello, I want to figure out whether lax.fori_loop runs in parallel on GPU. In my case, I wrote a custom operator using CUDA, and it's much faster than a function (on GPU) that does the same logic and is constructed with …
Hi - thanks for the question! In general the answer is no: each step of a fori_loop depends on the output of the previous step, so there's no way for the steps to run in parallel (though the computations within each step will generally run in parallel). Of course, it's possible to construct a fori_loop in which steps have no data dependence on previous steps, but I don't think the compiler has any logic for parallelizing such computations in the case of fori_loop (someone might correct me if I'm wrong on that!). It would be better to express such an operation in a way that does not encode sequential dependence; good options are jax.vmap in cases where your full computation fits into mem…
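To illustrate the point about sequential dependence, here is a minimal sketch that contrasts the two styles. The per-element function `f` is hypothetical, standing in for whatever the custom operator computes. In the fori_loop version, each iteration writes a different index, so the steps are logically independent. However, the output array is threaded through the loop carry, so the loop still encodes a step-by-step dependence. The jax.vmap version expresses the same result as one batched computation over the leading axis.

```python
import jax
import jax.numpy as jnp
from jax import lax


# Hypothetical per-element work (an assumption for illustration only).
def f(x):
    return jnp.sin(x) * x + 1.0


def with_fori_loop(xs):
    # `out` is the loop carry: iteration i+1 receives the array produced by
    # iteration i, so the loop is expressed as a sequence of dependent steps
    # even though each step writes a different index.
    def body(i, out):
        return out.at[i].set(f(xs[i]))

    return lax.fori_loop(0, xs.shape[0], body, jnp.zeros_like(xs))


# Same computation with no sequential dependence: vmap maps f over the
# leading axis and produces a single batched computation.
with_vmap = jax.vmap(f)

xs = jnp.linspace(0.0, 1.0, 8)
a = jax.jit(with_fori_loop)(xs)
b = jax.jit(with_vmap)(xs)
print(jnp.allclose(a, b))
```

Both functions return the same values. The difference is that the vmap version presents all elements to the compiler at once instead of as a chain of loop iterations.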