
vmap on list? #14021

Answered by jakevdp
dui1234 asked this question in General
Jan 15, 2023 · 2 comments · 4 replies

You cannot simply call jnp.array(params1), because params1 is not a flat list of arrays; it is a nontrivial nested data structure:

print(jax.tree_util.tree_map(np.shape, params1))
[[((4, 128), (128,)), (), ((128, 128), (128,)), (), ((128, 2), (2,))],
 [((4, 128), (128,)), (), ((128, 128), (128,)), (), ((128, 2), (2,))]]

As you found, np.array(params1) does return an output, but it is an array with dtype=object, which JAX doesn't support. NumPy itself warns about this as well:

_ = np.array(params1)
VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of
  lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you me…
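
A common way to make a list of identically structured parameter trees usable with vmap is to stack the matching leaves into a single pytree with a leading batch axis. The sketch below is not necessarily what the truncated answer above goes on to recommend; init_params and apply are hypothetical stand-ins for the model in the question, with shapes copied from the printed structure.

```python
import jax
import jax.numpy as jnp

def init_params(key):
    # Hypothetical stand-in for the question's model: (W, b) tuples for
    # dense layers and empty tuples () for parameter-free activations.
    k1, k2, k3 = jax.random.split(key, 3)
    return [
        (jax.random.normal(k1, (4, 128)), jnp.zeros(128)),
        (),
        (jax.random.normal(k2, (128, 128)), jnp.zeros(128)),
        (),
        (jax.random.normal(k3, (128, 2)), jnp.zeros(2)),
    ]

def apply(params, x):
    # Hypothetical forward pass for a single parameter set.
    (w1, b1), _, (w2, b2), _, (w3, b3) = params
    h = jax.nn.relu(x @ w1 + b1)
    h = jax.nn.relu(h @ w2 + b2)
    return h @ w3 + b3

# A list of two parameter trees, mirroring params1 in the question.
keys = jax.random.split(jax.random.PRNGKey(0), 2)
params1 = [init_params(k) for k in keys]

# Stack corresponding leaves so each leaf gains a leading axis of size 2.
stacked = jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *params1)

# Map over the leading axis of every parameter leaf; share the input x.
x = jnp.ones((4,))
out = jax.vmap(apply, in_axes=(0, None))(stacked, x)
print(out.shape)  # (2, 2)
```

This only works when every tree in the list has the same structure and leaf shapes, which is the case for the two entries printed above.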

Answer selected by dui1234