Could you provide a runnable example (preferably in a single code block)? I don't see …
You cannot use a simple `jnp.array(params1)` because `params1` is not a simple list of arrays; it is a nontrivial nested data structure (you can inspect it with `print(jax.tree_map(np.shape, params1))`).

As you found, `np.array(params1)` returns an output, but it is an array with `dtype=object`, which JAX doesn't support. NumPy actually warns you about this issue as well:
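Here is a minimal, self-contained sketch of the problem and of the stacking fix. It assumes `params1` is a list of per-model parameter dicts. The `init_params` and `apply` functions, the `"w"`/`"b"` keys, and all shapes are hypothetical stand-ins, not your actual code. It uses `jax.tree_util.tree_map`, the long-standing spelling of `jax.tree_map`, which is deprecated in recent JAX releases.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical stand-in for one model's parameters: a dict of arrays.
def init_params(key, n_in=3, n_out=2):
    w_key, b_key = jax.random.split(key)
    return {
        "w": jax.random.normal(w_key, (n_in, n_out)),
        "b": jax.random.normal(b_key, (n_out,)),
    }

keys = jax.random.split(jax.random.PRNGKey(0), 4)
params1 = [init_params(k) for k in keys]  # list of 4 pytrees, not arrays

# NumPy can only wrap each dict as an opaque Python object here.
print(np.array(params1).dtype)  # object

# Inspect the nested structure leaf by leaf.
print(jax.tree_util.tree_map(np.shape, params1))

# Stack matching leaves across the list: the result is ONE dict whose
# leaves gain a leading axis of size len(params1).
params1_stacked = jax.tree_util.tree_map(lambda *x: jnp.stack(x), *params1)
print(jax.tree_util.tree_map(np.shape, params1_stacked))

# Hypothetical per-model function; vmap maps over axis 0 of every leaf
# of params1_stacked and over axis 0 of obs.
def apply(params, obs):
    return obs @ params["w"] + params["b"]

obs = jnp.ones((4, 3))  # one observation per model
out = jax.vmap(apply, in_axes=(0, 0))(params1_stacked, obs)
print(out.shape)  # (4, 2)
```

Stacking turns a list of pytrees into a single pytree of batched arrays, which is the layout `vmap` expects for an argument whose `in_axes` entry is `0`. Arguments that should be shared across the batch get `None` instead.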