Can an array be kept on a specific device inside a jitted function? #17040
The title says it all. I am trying to keep a large array in CPU (host) memory while performing computation on a TPU. I am willing to pay the cost of host-to-device transfers, but if the entire training function is compiled with JIT, is it possible to tell JAX to keep an array on the CPU and move only the necessary data to the TPU each time it is needed inside the compiled function? For more context: all device transfers should happen within the compiled function itself, not in an outer loop that calls a compiled function.
Yes, this is possible with a new Memories API we are working on. It's not ready for public consumption yet, but it will be soon.
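Because the Memories API mentioned above was not yet public, the sketch below does not use it. As a stand-in, it uses a different, already-available technique: `jax.pure_callback`, which lets a jitted function call back into Python on the host. The large array stays a plain NumPy array in host RAM, and only the requested block of rows is copied to the accelerator on each call, all from inside the compiled function. The names `HOST_TABLE`, `BLOCK`, `_fetch_rows`, and `step`, and the table and block sizes, are hypothetical and chosen only for illustration.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical large table that lives only in host memory (never device_put).
HOST_TABLE = np.arange(1_000_000, dtype=np.float32).reshape(10_000, 100)
BLOCK = 8  # hypothetical number of rows fetched per call


def _fetch_rows(start):
    # Runs on the host in ordinary Python; `start` arrives as a NumPy scalar.
    # Clamp so the returned block always has exactly BLOCK rows.
    s = max(0, min(int(start), HOST_TABLE.shape[0] - BLOCK))
    return HOST_TABLE[s:s + BLOCK]


@jax.jit
def step(params, start):
    # Inside the compiled function: ask the host for one block of rows.
    # The declared shape/dtype must match what _fetch_rows returns.
    rows = jax.pure_callback(
        _fetch_rows,
        jax.ShapeDtypeStruct((BLOCK, HOST_TABLE.shape[1]), jnp.float32),
        start,
    )
    # Only this (BLOCK, 100) slice is transferred to the device.
    return jnp.sum(rows * params)
```

For example, `step(jnp.ones((100,)), 0)` reduces rows 0 to 7 of the host table on the device. Each call round-trips through the host and XLA cannot optimize across the callback, so this trades speed for device memory, which the question says is acceptable.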