
Jax error #14221

Answered by soraros
rruizdeaustri asked this question in General

I think the problem lies here:

```python
loglike = jax.pure_callback(self.log_likelihood_jax(x), result_shape, x)
#                                                   ^
```

It should have been:

```python
loglike = jax.pure_callback(self.log_likelihood_jax, result_shape, x)
#                           ^---------------------^
# first argument of jax.pure_callback is a callable
```
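To make the difference concrete, here is a minimal sketch. Since the original code wasn't posted, the `Model` class, its NumPy-based Gaussian log-likelihood, and the `log_likelihood` / `log_likelihood_wrong` wrappers are hypothetical stand-ins; only `jax.pure_callback`, `jax.ShapeDtypeStruct` and `jax.errors.TracerArrayConversionError` are real JAX names. Writing `self.log_likelihood_jax(x)` calls the method right away on a tracer under `jit`, so `pure_callback` never receives a callable (in this sketch it fails even earlier, because NumPy cannot convert a tracer). Passing the bound method itself lets JAX call it later with a concrete array.

```python
import jax
import jax.numpy as jnp
import numpy as np


class Model:
    """Hypothetical model whose likelihood is computed outside JAX, in NumPy."""

    def log_likelihood_jax(self, x):
        # Runs on the host with a concrete array (not a tracer) when used as a callback.
        x = np.asarray(x)
        # Return exactly the shape/dtype promised in result_shape below.
        return np.asarray(-0.5 * np.sum(x ** 2), dtype=np.float32)

    def log_likelihood(self, x):
        # Tell JAX the shape and dtype of the callback's output so it can trace past it.
        result_shape = jax.ShapeDtypeStruct((), jnp.float32)
        # Correct: pass the bound method (a callable), followed by its arguments.
        return jax.pure_callback(self.log_likelihood_jax, result_shape, x)

    def log_likelihood_wrong(self, x):
        result_shape = jax.ShapeDtypeStruct((), jnp.float32)
        # Wrong: the method is invoked immediately on a tracer, and its result
        # (not a callable) would be handed to pure_callback.
        return jax.pure_callback(self.log_likelihood_jax(x), result_shape, x)


model = Model()
x = jnp.arange(3.0)

print(jax.jit(model.log_likelihood)(x))  # prints -2.5

try:
    jax.jit(model.log_likelihood_wrong)(x)
except jax.errors.TracerArrayConversionError as err:
    print("original form fails:", type(err).__name__)
```

The exact error for the original form depends on what the real `log_likelihood_jax` does with its input; the fix is the same either way.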

But this is all guesswork. It would be better if you could provide a minimal runnable example and wrap the code in a markdown code block, like the following:

```python
def f(x):
  ...
```

Answer selected by froystig