Hello, I started looking into JAX today, and while reading the Common Gotchas notebook I ran into something that confused me. Around the 14th line, non-array inputs are discussed. We are given a bad example of passing a list to a JAX function here:
We are then told that a better way of doing this would be: Is this a typo, or is there something I missed? I assume the first example was supposed to be
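For context, here is a minimal sketch of the behaviour that section describes, assuming a recent JAX release: a `jax.numpy` reduction rejects a plain Python list with a `TypeError`, but accepts the same data once it has been explicitly converted with `jnp.array`. The variable names are illustrative only.

```python
import jax.numpy as jnp

x = [1, 2, 3]  # illustrative plain Python list

# jax.numpy functions deliberately refuse plain lists/tuples as inputs.
try:
    jnp.sum(x)
    list_rejected = False
except TypeError as err:
    list_rejected = True
    print("TypeError:", err)

# Explicitly converting the list to an array first is accepted.
total = jnp.sum(jnp.array(x))
print(total)  # 6
```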
The workaround is to convert lists to arrays before passing them to JAX functions.
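A rough sketch of that workaround, assuming a jitted function that expects an array argument. The function name `sum_values` and the sample data are hypothetical; the point is only that the conversion happens once, at the call site, so the JAX function receives a single array.

```python
import jax
import jax.numpy as jnp

@jax.jit
def sum_values(values):
    # Hypothetical helper: `values` is expected to already be a JAX array.
    return jnp.sum(values)

data = [1.0, 2.0, 3.0, 4.0]  # plain Python list coming from user code

# Convert before calling, rather than passing the list itself.
result = sum_values(jnp.asarray(data))
print(result)  # 10.0
```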
I think the point they are trying to make is that you shouldn't pass in lists and convert them to arrays inside of a function.
You should read `jnp.sum(jnp.array(x))` as passing an array, `jnp.array(x)`, to the function `jnp.sum` (there is no need to wrap it in a lambda).
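One way to see why the conversion should happen outside the function is to compare the traced inputs with `jax.make_jaxpr`. This is a sketch, assuming a recent JAX version; `permissive_sum` is an illustrative name for a function that accepts a list and converts it internally. Because a list is treated as a pytree, each of its elements becomes a separate traced input, whereas a pre-converted array is a single input.

```python
import jax
import jax.numpy as jnp

x = list(range(10))

# Anti-pattern (illustrative): accept a list and convert it inside the function.
def permissive_sum(values):
    return jnp.sum(jnp.array(values))

# The list is flattened as a pytree, so every element is its own input.
jaxpr_list = jax.make_jaxpr(permissive_sum)(x)

# Preferred: convert first, then pass one array straight to jnp.sum.
jaxpr_array = jax.make_jaxpr(jnp.sum)(jnp.array(x))

print(len(jaxpr_list.jaxpr.invars))   # 10
print(len(jaxpr_array.jaxpr.invars))  # 1
```

With larger lists the first version produces a correspondingly larger program, which is the silent performance cost the notebook warns about.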