vmap over a list of classes #17165
Unanswered
jing-alice asked this question in Q&A
Replies: 1 comment
-
Hi, thanks for the question! Your list of classes is what's known as an array-of-structs pattern. JAX transformations like `vmap` generally work better with a struct-of-arrays pattern. So rather than defining your data like this:

    data_array_of_structs = [Comp(1), Comp(2), Comp(3)]

you should instead define a structure like this:

    data_struct_of_arrays = Comp(jnp.array([1, 2, 3]))

Then, as long as `Comp` is registered as a pytree, you can map over it with:

    jax.vmap(lambda comp: comp())(data_struct_of_arrays)

Hope that helps!
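A minimal runnable sketch of the struct-of-arrays approach follows. The body of `Comp` is an assumption, since the original class was not shown: it has a hypothetical single field `x` and a hypothetical `__call__` that squares it. It uses `jax.tree_util.register_pytree_node_class` so that `vmap` sees the array stored inside `Comp` as a pytree leaf.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Comp:
    """Hypothetical stand-in for the user's class: holds one array field `x`."""

    def __init__(self, x):
        self.x = x

    def __call__(self):
        # Hypothetical computation; replace with the real method body.
        return self.x ** 2

    def tree_flatten(self):
        # Array-valued fields are the pytree children; no static aux data here.
        return (self.x,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Rebuild a Comp from its (possibly batched or traced) children.
        return cls(*children)


# Struct of arrays: a single Comp whose field holds the whole batch.
data_struct_of_arrays = Comp(jnp.array([1, 2, 3]))

# vmap maps over the leading axis of every leaf, so each call sees a scalar x.
result = jax.vmap(lambda comp: comp())(data_struct_of_arrays)
print(result)
```

Because `x` is the only child, `vmap` slices its leading axis and rebuilds a `Comp` holding one scalar per call, so `result` has the same length as the batched field.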
-
`vmap` does not currently work over a list of class instances: it automatically expands (flattens) all the attributes of each instance, which causes an error at `np.shape(x)` inside `vmap`.
For example:
I have also tried `register_pytree_node` from `jax.tree_util`, but it doesn't work.
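A plain Python list is itself a pytree, so `vmap` tries to map axis 0 of every scalar leaf inside it, which is likely where the shape error comes from. Registering the class alone does not change that; the leaves still need to be batched arrays. The sketch below assumes a hypothetical `Comp` with two scalar fields `a` and `b` and a hypothetical `__call__`. It registers the class with the function form of `register_pytree_node`, then converts an existing list of instances into one struct-of-arrays instance with `tree_map` and `jnp.stack`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_map


class Comp:
    """Hypothetical class with two array fields, `a` and `b`."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def __call__(self):
        # Hypothetical computation on the fields.
        return self.a * self.b


# Function form of pytree registration: flatten returns (children, aux_data),
# unflatten rebuilds the object from (aux_data, children).
register_pytree_node(
    Comp,
    lambda c: ((c.a, c.b), None),
    lambda aux, children: Comp(*children),
)

# Array of structs: a Python list of Comp instances with scalar fields.
data_array_of_structs = [Comp(1.0, 2.0), Comp(3.0, 4.0), Comp(5.0, 6.0)]

# Stack matching leaves across the list into one Comp with batched fields.
data_struct_of_arrays = tree_map(lambda *xs: jnp.stack(xs), *data_array_of_structs)

# Now every leaf has a leading batch axis, so vmap can map over it.
result = jax.vmap(lambda comp: comp())(data_struct_of_arrays)
print(result)
```

The stacking step only works when every instance flattens to the same structure with same-shaped leaves; if the data is built from scratch, constructing the batched `Comp` directly avoids the conversion.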