Understanding asynchronous dispatch and running multiple functions in parallel #26312
Unanswered
markus7800 asked this question in Q&A
-
On the GPU backend we have to wait for the completion of every loop iteration in order to copy the predicate back to the host and decide whether to execute the next loop iteration, so effectively only computations after the loop body are "asynchronous". Copying the predicate back basically forces XLA to sync the CUDA stream with the host. We do have an internal bug (b/382117736, sorry, link for Googlers only) that fixes this by launching PJRT/XLA operations in a dedicated thread pool, but I am not sure when it will be fixed.
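A hypothetical sketch (not from the thread) of the kind of loop this answer is about: a jitted `lax.while_loop` whose condition is a traced predicate. On GPU, that predicate must be copied back to the host before the next iteration can be launched, which is the sync the answer describes.

```python
import jax
from jax import lax

# Hypothetical illustration: a jitted while_loop. On GPU, each iteration's
# predicate has to be copied back to the host before the next iteration can
# be enqueued, so the loop runs in lock-step with the host even though the
# call itself returns an unmaterialized (asynchronously dispatched) array.

@jax.jit
def count_down(x):
    def cond(state):
        i, _ = state
        return i > 0  # this predicate is what forces the stream sync

    def body(state):
        i, acc = state
        return i - 1, acc + i

    return lax.while_loop(cond, body, (x, 0))

_, total = count_down(10)  # sums 10 + 9 + ... + 1
print(int(total))          # blocks until the loop has finished -> 55
```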
-
Hello,
As explained in https://jax.readthedocs.io/en/latest/async_dispatch.html, JAX does not wait for an operation to complete before returning control to the Python program.
I tested this with a simple test function, and for test1(10_000) I got the expected output: the first print statement is executed immediately because of asynchronous dispatch.

Now consider a second test setup, where I hoped to asynchronously dispatch the scan operation. Calling it with a computationally expensive step and a small number of iterations, test2(10_000, 2), versus an inexpensive step and a large number of iterations, test2(100, 10_000), gives different behaviour: seemingly, the scan operation was asynchronously dispatched only in the first case and executed sequentially in the second.

I know that I can vectorise the computation, and executing test3(100, 10_000) confirms that with vmap my GPU is able to perform both scan operations in parallel in the same time.

My question is: how would I parallelise two scan operations that use different step functions, where vmap in this form is not available?

This question is related to #673, #25630, #20916, and #23306.
In the answers, a switch operation is often recommended. But vmap over switch turns it into a select where all branches are executed. In my case, the step functions are expensive and I want to apply each of them to only a subset of the data.
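To make the select behaviour concrete, here is a hypothetical sketch (branch functions and values are my own illustration): `lax.switch` picks one branch for a scalar index, but under `vmap` every branch is evaluated for every batch element and the results are merged with a select.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Hypothetical sketch: lax.switch on a scalar index runs one branch, but
# its vmapped version lowers to a select, so both branches below execute
# for every element of the batch (costly when the branches are expensive).

def apply_branch(index, x):
    branches = [lambda v: v + 1.0, lambda v: v * 10.0]
    return lax.switch(index, branches, x)

indices = jnp.array([0, 1, 0])
xs = jnp.array([1.0, 2.0, 3.0])
out = jax.vmap(apply_branch)(indices, xs)
print(out.tolist())  # [2.0, 20.0, 4.0]
```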
On CPU, pmap would probably work to distribute the step functions to different cores, but I want to run the operations on one GPU, where pmap (or shard_map) does not help, I think.

I would appreciate any advice for my situation, or a confirmation that what I want to achieve is not possible with JAX.

Thanks!