I guess you just need to use jax.device_put to move the result onto the device where the next jitted computation should run.
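A minimal sketch of that suggestion follows. It assumes a CPU stage followed by an accelerator stage; `stage_one`, `combine` and `pick_accelerator` are hypothetical names, and the code falls back to the default device when no GPU backend is present so it still runs on a CPU-only install. Because jit places a computation on the device its committed inputs live on, committing the input to the CPU runs stage one there, and one explicit `jax.device_put` moves its output to the accelerator before the second stage.

```python
import jax
import jax.numpy as jnp

def pick_accelerator():
    """Return the first GPU device, or the default device if no GPU backend exists."""
    try:
        return jax.devices("gpu")[0]
    except RuntimeError:
        return jax.devices()[0]

cpu = jax.devices("cpu")[0]
accel = pick_accelerator()

@jax.jit
def stage_one(x):
    # First part of the computation; runs on the device its input is committed to.
    return jnp.sin(x) * 2.0

@jax.jit
def combine(x, y):
    # Second part; both arguments must be committed to the same device.
    return x + y

# Commit the input to the CPU so stage_one executes there.
x_cpu = jax.device_put(jnp.arange(4.0), cpu)
partial = stage_one(x_cpu)

# y lives on the accelerator.
y = jax.device_put(jnp.ones(4), accel)

# Explicitly transfer the CPU result to the accelerator before combining.
partial_accel = jax.device_put(partial, accel)
result = combine(partial_accel, y)
```

Skipping the `jax.device_put(partial, accel)` line is what produces the "incompatible devices" error on a machine that has a GPU.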
How would one combine the results of jitted computations that run on different devices?
I have a use case where I would like to run part of the computation on the CPU and transfer the result to the GPU to run the rest.
More generally, not just the CPU: let's say I have a HW accelerator backend that performs some compute. How do I pass its result to an xPU for further compute?
The code below fails with this error: "Received incompatible devices for jitted computation. Got argument x of combine with shape float32[] and device ids [0] on platform GPU and argument y of combine with shape float32[] and device ids [0] on platform CPU"
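The error arises because `x` and `y` are each committed to a different device, and jit does not silently move committed arrays between devices. One way to handle this generically is a small wrapper that commits every argument to a chosen device before calling the jitted function. The sketch below assumes that approach; `run_on` is a hypothetical helper, not a JAX API, and the NumPy array stands in (hypothetically) for a result handed back by some other accelerator runtime as a host array. `jax.device_put` accepts whole pytrees, so a single call moves all arguments.

```python
import jax
import jax.numpy as jnp
import numpy as np

def run_on(device, fn, *args):
    """Hypothetical helper: commit every argument to `device`, then call `fn`.

    jit will not move arrays already committed to different devices, so the
    transfer is made explicit with jax.device_put before the call.
    """
    args = jax.device_put(args, device)  # maps over the whole tuple of arguments
    return fn(*args)

@jax.jit
def combine(x, y):
    return x * y

cpu = jax.devices("cpu")[0]
target = jax.devices()[0]  # default backend: GPU/TPU if present, else CPU

# Two float32[] scalars, as in the error message, committed to possibly different devices.
x = jax.device_put(jnp.float32(3.0), target)
y = jax.device_put(jnp.float32(4.0), cpu)
out = run_on(target, combine, x, y)

# Hypothetical stand-in for a result produced outside JAX and returned as a
# host NumPy array; device_put accepts it the same way.
external = np.array([1.0, 2.0], dtype=np.float32)
out_ext = run_on(target, combine, external, jnp.float32(2.0))
```

Choosing the target device once and moving data explicitly keeps transfers visible, which makes it easier to spot accidental round trips between host and accelerator.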