Does lax.fori_loop run in parallel on GPU? #18054

Answered by jakevdp
Dong-Jiahuan asked this question in Q&A
Hi - thanks for the question! In general the answer is no: each step of a fori_loop depends on the output of the previous step, so there's no way for steps to be run in parallel (though the computations within each step will generally be run in parallel). Of course, it's possible to construct a fori_loop in which steps have no data dependence on previous steps, but I don't think the compiler has any logic for parallelizing such computations in the case of fori_loop (someone might correct me if I'm wrong on that!).
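To make the dependence concrete, here is a minimal sketch (not taken from the original question; the `body` function and its update rule are made up for illustration). Each iteration consumes the carry produced by the previous one, so the ten steps have to run one after another, while the elementwise arithmetic inside a single step can still be spread across the device.

```python
import jax.numpy as jnp
from jax import lax

def body(i, x):
    # Hypothetical update rule: step i needs the carry x produced by step i - 1,
    # so the iterations cannot be reordered or run concurrently.
    # The elementwise work on the vector x inside one step can still run in parallel.
    return 0.5 * x + 1.0

x0 = jnp.zeros(4, dtype=jnp.float32)

# Ten sequential steps of x <- 0.5 * x + 1.0, starting from zeros.
result = lax.fori_loop(0, 10, body, x0)
```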

It would be better to express such an operation in a way that does not encode sequential dependence; good options are jax.vmap in cases where your full computation fits into mem…
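As a rough illustration of the `jax.vmap` suggestion, the sketch below writes the same independent per-element computation twice: once as a `fori_loop` whose steps happen not to depend on each other, and once with `jax.vmap`. The function `step_fn` and the input `xs` are assumptions invented for this example, not part of the original question.

```python
import jax
import jax.numpy as jnp
from jax import lax

def step_fn(x):
    # Hypothetical per-element computation; it does not depend on any other element.
    return jnp.sin(x) ** 2 + 1.0

xs = jnp.arange(8, dtype=jnp.float32)

def with_fori_loop(xs):
    # Loop form: iteration i fills output slot i. The steps are logically
    # independent, but the loop still encodes them as a sequence.
    def body(i, out):
        return out.at[i].set(step_fn(xs[i]))
    return lax.fori_loop(0, xs.shape[0], body, jnp.zeros_like(xs))

def with_vmap(xs):
    # Batched form: no ordering between elements is encoded at all.
    return jax.vmap(step_fn)(xs)

loop_result = jax.jit(with_fori_loop)(xs)
vmap_result = jax.jit(with_vmap)(xs)
```

Both versions should produce the same values; the `vmap` form simply hands the compiler a single batched computation rather than a loop.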

Answer selected by Dong-Jiahuan