Generating an array of random choices on the CPU, not the GPU #21974
-
I want to generate an array of choices using `jax.random.choice`. As I understand it, once arrays are created I can place them on my device of choice (for example the CPU) with `jax.device_put` (see here).
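A minimal sketch of this first option, assuming only the standard `jax.device_put` and `jax.devices` APIs (the array `x` is a stand-in for the real data). It also shows why the approach cannot help here: the array has to exist on the default device before it can be moved.

```python
import jax
import jax.numpy as jnp

# The CPU backend is available even when the default backend is a GPU.
cpu = jax.devices("cpu")[0]

# The array is created on the default device first (the GPU, if that is the
# default) and only then copied to the CPU, so the GPU must be able to hold it.
x = jnp.arange(10)
x_cpu = jax.device_put(x, cpu)

print(x_cpu.devices())  # the set of devices holding x_cpu
```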
Alternatively, I can force a function to be executed on the CPU by passing a CPU device to `jax.jit` (see here).
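A sketch of the second option. Newer JAX releases warn that the `device=` argument of `jax.jit` is deprecated, so this version swaps in the `jax.default_device` context manager to get the same effect; it is an alternative, not the method behind the linked page. `scaled_range` is a hypothetical example function.

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

@jax.jit
def scaled_range(x):
    # Hypothetical example function; any jitted computation works here.
    return jnp.arange(10) * x

# Inside this context, computations whose inputs are not committed to a
# specific device (here, a plain Python int) run on the CPU, and their
# outputs are placed there.
with jax.default_device(cpu):
    y = scaled_range(2)

print(y.devices())
```

Because the only input is a Python integer, nothing ties the computation to the GPU, so the whole result is produced directly in CPU memory.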
The problem is that the function I want to use is `random.choice`. Option 1 above is not viable: just creating the array on the GPU exhausts the GPU memory, so I can't create it first and then move it where I want. Option 2 doesn't work either, since jitting the function raises a `ConcretizationTypeError`.
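A small reproduction of that failure, assuming the population size is passed as a plain integer as described in the question (the values 10 and `(5,)` are placeholders). Without static arguments, `jit` traces every argument, including the integer that `random.choice` needs as a concrete value.

```python
import jax
from jax import random

key = random.PRNGKey(0)

# Jitting random.choice without static arguments: the integer `a` (the number
# of options to choose from) becomes an abstract tracer, but random.choice
# needs its concrete value at trace time. The entries of `shape` are traced too.
naive_choice = jax.jit(random.choice)

try:
    naive_choice(key, 10, (5,))
    error_name = None
except TypeError as err:  # ConcretizationTypeError is a subclass of TypeError
    error_name = type(err).__name__

print(error_name)
```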
I also can't see an option among the inputs to `random.choice` for specifying the device. One final detail: I'm not passing arrays as arguments to the function, only integers. How can I call this function so that it creates the array of choices on the CPU rather than the GPU, when my default backend is the GPU?
Replies: 2 comments 1 reply
-
Update: my current solution is to initialise the array in batches. This results in the overall large array being placed on the CPU, but it feels like there should be a more elegant solution to this problem.
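The original batching code is not shown, so here is a hypothetical sketch of what such a workaround might look like: each batch is small enough to generate on the default device, is moved to the CPU right away, and the pieces are concatenated on the CPU. `choice_in_batches` and its parameters are illustrative names, not from the thread.

```python
import jax
import jax.numpy as jnp
from jax import random

def choice_in_batches(key, max_val, total, batch_size):
    """Hypothetical helper: draw `total` samples from range(max_val) in
    batches, moving each batch to the CPU so the full array never has to
    fit in GPU memory."""
    cpu = jax.devices("cpu")[0]
    batches = []
    for start in range(0, total, batch_size):
        n = min(batch_size, total - start)
        key, subkey = random.split(key)
        # Each batch is generated on the default device (possibly the GPU)...
        batch = random.choice(subkey, max_val, (n,))
        # ...and immediately copied to the CPU.
        batches.append(jax.device_put(batch, cpu))
    # All inputs are committed to the CPU, so the concatenation runs there.
    return jnp.concatenate(batches)

samples = choice_in_batches(random.PRNGKey(0), max_val=100, total=1000, batch_size=300)
print(samples.shape, samples.devices())
```

This works, but it still round-trips every sample through the GPU, which is why a way to run `random.choice` directly on the CPU is preferable.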
-
The reason for that `ConcretizationTypeError` is that some of the arguments to `random.choice` must be static. You should be able to get your proposed `jit` version to work using something like:

```python
cpu_func = jax.jit(random.choice, static_argnums=(1, 2), device=device_cpu)
cpu_func(key, max_val, shape)
```

Hope this helps!
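A self-contained version of the suggestion above. It keeps the `static_argnums=(1, 2)` fix but, since newer JAX releases warn that `jit`'s `device=` argument is deprecated, it places the computation with the `jax.default_device` context manager instead. The values of `max_val` and `shape` are placeholders.

```python
import jax
from jax import random

device_cpu = jax.devices("cpu")[0]

# Arguments 1 (`a`, here an int) and 2 (`shape`) of random.choice are static:
# they are baked into the compiled program instead of being traced.
# Static arguments must be hashable, which ints and tuples are.
cpu_choice = jax.jit(random.choice, static_argnums=(1, 2))

max_val, shape = 1000, (4, 5)  # placeholder values

with jax.default_device(device_cpu):
    # Created inside the context, so the key also lives on the CPU.
    key = random.PRNGKey(0)
    choices = cpu_choice(key, max_val, shape)

print(choices.shape, choices.devices())
```

Each new `(max_val, shape)` pair triggers a recompilation, which is the usual trade-off of static arguments; here it is what lets `random.choice` see concrete values.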