Additional info: the function is called in the following manner:
I suspect it has something to do with jax.value_and_grad, since the function is called under it, but I don't know exactly how to solve this.
-
Edit: there is now a FAQ entry on this topic: https://jax.readthedocs.io/en/latest/faq.html#how-can-i-convert-a-jax-tracer-to-a-numpy-array
You cannot convert a JAX tracer to a numpy array: a JAX tracer is an abstract value that represents all possible arrays of a given shape and dtype; a numpy array is a concrete value representing one particular array.
It sounds like what you want to do is to call non-JAX code from within transformed JAX code. This is possible to do via
Can you say more about why you want to convert these values to numpy arrays? Perhaps there's a way to keep your computation in JAX rather than calling non-JAX libraries?
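For anyone reading later, here is a minimal sketch of calling NumPy-only code from inside a jitted function. It assumes `jax.pure_callback` is available in your JAX version; whether that is the mechanism the reply above had in mind is not recoverable from the thread. `numpy_only_scale` and `scaled_sum` are hypothetical names standing in for the non-JAX library call and the transformed function.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_only_scale(x):
    # Hypothetical stand-in for a library that only accepts NumPy arrays.
    x = np.asarray(x)
    return (x / (np.abs(x).max() + 1.0)).astype(x.dtype)


@jax.jit
def scaled_sum(x):
    # Inside jit, x is a tracer: only its shape and dtype are known here.
    out_spec = jax.ShapeDtypeStruct(x.shape, x.dtype)
    # pure_callback runs the host function on concrete values at run time.
    y = jax.pure_callback(numpy_only_scale, out_spec, x)
    return jnp.sum(y)


x = jnp.array([1.0, 3.0, 2.0], dtype=jnp.float32)
result = scaled_sum(x)
```

As far as I know, JAX cannot differentiate through the callback on its own, so this fits best for values that do not need gradients.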
-
Hi Jake, thanks for your reply. I need the conversion because my
One solution I can think of is to move the whole part of
out from the current function and define another new function like:
This function would not be decorated with jit and would be called outside all decorated functions. But that raises another issue: I could not jit or accelerate this function with any JAX decorator, am I right?
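A rough sketch of the split described above, under assumptions: the original loss function is not shown, so `loss_fn` is a made-up squared-error loss, and `update_prior` is a hypothetical NumPy-only step. The idea is to return the per-element loss as auxiliary output via `has_aux=True`, keep the differentiated part jitted, and convert to NumPy only after the transformed call has returned.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_fn(params, x, y):
    # Hypothetical squared-error loss; the original loss is not shown.
    elementwise_loss = (x * params - y) ** 2
    # Return per-element values as auxiliary output next to the scalar loss.
    return jnp.mean(elementwise_loss), elementwise_loss


@jax.jit
def loss_and_grad(params, x, y):
    # The JAX part stays jitted and differentiable.
    return jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)


def update_prior(loss_values):
    # Hypothetical NumPy-only step, run outside any JAX transformation.
    return np.argsort(loss_values)[::-1]


params = jnp.float32(2.0)
x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 3.0, 9.0])

(loss, elementwise_loss), grad = loss_and_grad(params, x, y)
# Outside the transformation the result is concrete, so conversion works.
loss_for_prior = np.array(elementwise_loss)
ranking = update_prior(loss_for_prior)
```

With this layout only the NumPy-only step runs without jit; the loss and gradient computation can still be compiled.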
-
Hi,
I am struggling with converting a tracer to a numpy array. I have a function that calculates a loss:
and I want to convert the output of this function to a numpy array via:
loss_for_prior = np.array(elementwise_loss)
However, it gives me an error:
What should I do in this case?
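For context, a small reproduction of this kind of failure, as a sketch only: the original loss function and error text are not shown, so `loss_fn` is a hypothetical stand-in, and the exception type is what I would expect from current JAX (`jax.errors.TracerArrayConversionError`), not a quote from the post.

```python
import jax
import jax.numpy as jnp
import numpy as np


def loss_fn(params, x):
    # Hypothetical loss; the original function body is not shown in the post.
    elementwise_loss = (x * params) ** 2
    # Under value_and_grad, elementwise_loss is a tracer, so this fails.
    loss_for_prior = np.array(elementwise_loss)
    return jnp.mean(elementwise_loss)


caught = None
try:
    jax.value_and_grad(loss_fn)(jnp.float32(2.0), jnp.array([1.0, 2.0]))
except jax.errors.TracerArrayConversionError as err:
    caught = err

print(type(caught).__name__)
```

The same `np.array(...)` call on the function's returned value, made after `value_and_grad` has finished, does not hit this error because the value is concrete by then.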