Jax error #14221
-
Dear JAX developers,

I have trained a network with PyTorch and I want to get predictions from it by calling it from JAX. When I do that, I get:

`jax._src.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on the JAX Tracer object`

I have tried to solve the issue with `pure_callback`, but the problem persists. This is my code:

```python
def log_likelihood_jax(x):
    x = np.array(x)
    if len(x.shape) == 1:
        x = x.reshape(1, -1)
    x = self._normalise(x, self.x_mean, self.x_stdev)
    # Call the network with a NumPy-typed input
    y = self.model.predict(np.array(x))
    y = self._unnormalise(y, self.y_mean, self.y_stdev)
    return -jnp.array(y)

p = list
x = jnp.array(p)
result_shape = jax.core.ShapedArray(x.shape, x.dtype)
loglike = jax.pure_callback(self.log_likelihood_jax(x), result_shape, x)
```

The error is in

Any ideas? Thanks a lot!

With Kind Regards,
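For context, here is a minimal sketch of how this error typically arises. It assumes (the post does not show this) that the log-likelihood is evaluated inside a JAX transformation such as `jax.jit`. Inside a traced function the argument is a Tracer rather than a concrete array, so `np.asarray`/`np.array` on it raises `TracerArrayConversionError`. `numpy_only_model` is a hypothetical stand-in for the PyTorch model, which can only consume concrete NumPy arrays.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_only_model(x):
    # Hypothetical stand-in for a non-JAX model: it needs a concrete NumPy array.
    return np.asarray(x) * 2.0


@jax.jit
def broken(x):
    # Under jit, x is a Tracer; np.asarray(x) inside numpy_only_model cannot
    # convert it and raises TracerArrayConversionError during tracing.
    return jnp.sum(numpy_only_model(x))


try:
    broken(jnp.ones(3))
    raised = False
except jax.errors.TracerArrayConversionError:
    raised = True

print(raised)  # True
```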
Replies: 2 comments
-
I think the problem lies here:

```python
loglike = jax.pure_callback(self.log_likelihood_jax(x), result_shape, x)
#                                                  ^
```

It should have been:

```python
loglike = jax.pure_callback(self.log_likelihood_jax, result_shape, x)
#                           ^---------------------^
# first argument of jax.pure_callback is a callable
```

But this is all guesswork. It would be better if you could provide a minimal runnable example and wrap the code in a markdown code block.
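A minimal runnable sketch of the suggested fix, under a few assumptions. `numpy_only_model` and `log_likelihood_host` are hypothetical stand-ins for `self.model.predict` and `self.log_likelihood_jax`, and the model is assumed to return one value per input row. The function object itself is passed to `jax.pure_callback`, which calls it later on the host with concrete arrays. This sketch uses the public `jax.ShapeDtypeStruct` in place of `jax.core.ShapedArray`, and declares the shape of the callback's *output*. The original `result_shape` uses `x.shape`, which only matches if the output has the same shape as the input.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_only_model(x):
    # Hypothetical stand-in for the PyTorch model's predict(); it only
    # accepts concrete NumPy arrays and returns one value per row.
    return np.sum(x ** 2, axis=-1, keepdims=True)


def log_likelihood_host(x):
    # Runs on the host with a concrete array, so np.asarray is safe here.
    x = np.asarray(x)
    if x.ndim == 1:
        x = x.reshape(1, -1)
    y = numpy_only_model(x)
    # Shape and dtype must match result_shape declared below.
    return (-y).astype(np.float32)


@jax.jit
def log_likelihood(x):
    x2d = x.reshape(1, -1) if x.ndim == 1 else x
    # Describe the callback's output: one log-likelihood per row.
    result_shape = jax.ShapeDtypeStruct((x2d.shape[0], 1), jnp.float32)
    # Pass the function itself, not log_likelihood_host(x).
    return jax.pure_callback(log_likelihood_host, result_shape, x2d)


out = log_likelihood(jnp.array([1.0, 2.0, 3.0]))
print(out)  # [[-14.]]
```

Inside a class, the bound method `self.log_likelihood_jax` can be passed the same way, without calling it.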
-
Hi,
Thanks for the feedback.
It works now!
Rbt
On 31 Jan 2023, at 12:38, soraros wrote:
I think the problem lies here:

```python
loglike = jax.pure_callback(self.log_likelihood_jax(x), result_shape, x)
#                                                  ^
```

and should have been

```python
loglike = jax.pure_callback(self.log_likelihood_jax, result_shape, x)
#                           ^---------------------^
# first argument of jax.pure_callback is a callable
```

But this is all guesswork. It would be better if you could provide a minimal runnable example, and put the code in a markdown code block like the following
```python
def f(x):
...
```