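You may be able to address this by running the transformation inside a `jax.default_device(cpu)` context:

```python
import jax

cpu = jax.devices("cpu")[0]  # first CPU device

# `func` and `a` stand for the function and input from the question.
with jax.default_device(cpu):
    result = jax.jacrev(func)(a, cpu)
```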
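A self-contained sketch of the same idea is shown below. The `func` here is a hypothetical stand-in that stacks a list of arrays built from `a`; unlike the call above, it takes only `a`, not the device. Creating the input inside the context keeps it on the CPU, and inspecting `result.devices()` is one way to confirm where the Jacobian ended up.

```python
import jax
import jax.numpy as jnp

def func(a):
    # Hypothetical stand-in for the original function: build a list of
    # arrays from `a`, stack them vertically, and reduce to a vector.
    parts = [a, 2.0 * a, jnp.sin(a)]
    stacked = jnp.vstack(parts)   # shape (3, n)
    return stacked.sum(axis=0)    # shape (n,)

# First CPU device; available even when a GPU is the default backend.
cpu = jax.devices("cpu")[0]

with jax.default_device(cpu):
    # Arrays created here without an explicit device go to `cpu`, and the
    # eager forward and reverse-mode work done by jacrev is dispatched there.
    a = jnp.arange(4.0)
    result = jax.jacrev(func)(a)

print(result.shape)                            # (4, 4)
print({d.platform for d in result.devices()})  # {'cpu'}
```

Since each output element depends only on the matching input element, the Jacobian should be diagonal with entries `3 + cos(a)`.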
Hi,
I have a function where a list of JAX arrays needs to be vertically stacked. The problem is that if I use `jax.numpy.vstack` to do the stacking, I run out of GPU memory. On the other hand, if I use ordinary NumPy, the stacking takes place on the CPU (where I have enough memory), but it causes a tracer error if I try to apply any transformation.

I tried wrapping `onp.vstack` with a `pure_callback` plus a `custom_jvp` rule. However, the result of the forward pass seems to end up in GPU memory anyway.

Finally, I tried to put the list of arrays onto the CPU using `device_put()` and then use `jax.numpy.vstack`. But I am not sure whether that large matrix would always stay on the CPU if I perform transformations on the function (like reverse-mode autodiff).

Is there a way to ensure that a particular computation always takes place on the CPU, while remaining JAX-transformable?
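For the `device_put()` approach, one way to check placement under transformations is to commit the inputs to the CPU and inspect where the outputs of a jitted gradient land. The sketch below assumes a hypothetical `loss` that stacks the list and sums the squares. JAX generally runs a computation on the device its committed inputs live on, so the gradients (and the stacked intermediate inside the jitted computation) should stay on the CPU.

```python
import numpy as onp
import jax
import jax.numpy as jnp

# First CPU device; exists even when a GPU is the default backend.
cpu = jax.devices("cpu")[0]

def loss(xs):
    # Hypothetical objective: stack the list vertically and reduce to a scalar.
    stacked = jnp.vstack(xs)           # shape (6, 3) for the inputs below
    return jnp.sum(stacked ** 2)

# Build the inputs in host memory with NumPy, then commit them to the CPU device.
xs = [onp.full((2, 3), float(i), dtype=onp.float32) for i in range(3)]
xs_cpu = jax.device_put(xs, cpu)       # device_put maps over the list (a pytree)

# Committed CPU inputs should pin the compiled forward and backward pass to `cpu`.
grads = jax.jit(jax.grad(loss))(xs_cpu)

for g in grads:
    print({d.platform for d in g.devices()})  # {'cpu'}
```

The gradient of `sum(x ** 2)` is `2 * x`, so each entry of `grads` should equal twice the matching input.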