Using vmap() with batched lists of GraphsTuples #22243
Replies: 1 comment
Fixed by using jnp.stack to stack my GraphsTuples! Updating in case anyone runs into a similar issue.
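A minimal sketch of that fix, assuming every GraphsTuple in the batch has identical array shapes (same node and edge counts), which jnp.stack requires. To keep it runnable without jraph installed, it defines a stand-in NamedTuple with the same fields as jraph.GraphsTuple; make_graph, stack_graphs, and mean_node_feature are hypothetical helpers for illustration. jax.tree_util.tree_map stacks each field across the list, producing one GraphsTuple whose leaves carry a new leading batch axis.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GraphsTuple(NamedTuple):
    # Stand-in with the same fields as jraph.GraphsTuple.
    nodes: jnp.ndarray
    edges: jnp.ndarray
    receivers: jnp.ndarray
    senders: jnp.ndarray
    globals: jnp.ndarray
    n_node: jnp.ndarray
    n_edge: jnp.ndarray


def make_graph(seed):
    # Hypothetical toy graph: 6 nodes with 2 features each, 4 edges.
    return GraphsTuple(
        nodes=jax.random.normal(jax.random.PRNGKey(seed), (6, 2)),
        edges=jnp.ones((4, 1)),
        receivers=jnp.array([1, 2, 3, 4]),
        senders=jnp.array([0, 1, 2, 3]),
        globals=jnp.zeros((1, 1)),
        n_node=jnp.array([6]),
        n_edge=jnp.array([4]),
    )


def stack_graphs(graphs):
    """Turn a list of same-shaped GraphsTuples into one GraphsTuple whose
    leaves all gain a new leading batch axis."""
    return jax.tree_util.tree_map(lambda *leaves: jnp.stack(leaves), *graphs)


def mean_node_feature(graph):
    # Hypothetical per-window computation standing in for the GNN forward pass.
    return jnp.mean(graph.nodes)


graphs = [make_graph(seed) for seed in range(2)]  # batch_size = 2
batched = stack_graphs(graphs)
print(batched.nodes.shape)  # (2, 6, 2)

# Default in_axes=0 maps over the new leading axis of every leaf.
per_window = jax.vmap(mean_node_feature)(batched)
print(per_window.shape)  # (2,)
```

vmap maps over array axes, not over Python list elements, so collapsing the list into a single pytree is what lets the default in_axes=0 apply to every leaf.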
Hi all,
To preface: I'm very new to JAX, so I apologize if this is a silly question. I've looked around for help re: GraphsTuples and vmap (especially #16641), but I haven't had much luck. I'd appreciate any help!
I'm writing a two-layer GNN that currently predicts pretty well. However, it runs somewhat slowly, so I'm looking for ways to optimize it on my GPU. I decided to implement batching during training; previously, my code trained on one time-series forecast window of GraphsTuple data at a time. Now, I want to group batch_size time-series forecast windows into some number of batches and use jax.vmap() to train the windows in parallel.

Here's a snippet of my code. I pass in my input and target batches, then use jax.vmap() to batch them into their time-series forecast windows:
Here's a very small version of how input_batch_graphs is structured. Here, batch_size = 2, and each graph contains 6 nodes, each with 2 layers:

When I try to run this code, I get the following error:
ValueError: vmap in_axes specification must be a tree prefix of the corresponding value, got specification (GraphsTuple(nodes=1, edges=1, receivers=1, senders=1, globals=1, n_node=None, n_edge=None), GraphsTuple(nodes=1, edges=1, receivers=1, senders=1, globals=1, n_node=None, n_edge=None)) for value tree PyTreeDef(([CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *]), CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *])], [CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *]), CustomNode(namedtuple[GraphsTuple], [*, *, *, *, *, *, *])])).
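The error says the in_axes structure doesn't match the arguments: the spec is a tuple of two GraphsTuples, but each argument is a Python list of GraphsTuples, so at the top level a GraphsTuple spec is compared against a list. The sketch below reproduces that mismatch with a stand-in NamedTuple (same fields as jraph.GraphsTuple), then shows that a per-field spec with None for n_node and n_edge is accepted once each list is stacked into a single GraphsTuple. loss_fn, make_graph, and stack_batched are hypothetical. It maps axis 0 because jnp.stack puts the batch axis first, and leaving n_node/n_edge unbatched assumes all graphs share the same node and edge counts.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class GraphsTuple(NamedTuple):
    # Stand-in with the same fields as jraph.GraphsTuple.
    nodes: jnp.ndarray
    edges: jnp.ndarray
    receivers: jnp.ndarray
    senders: jnp.ndarray
    globals: jnp.ndarray
    n_node: jnp.ndarray
    n_edge: jnp.ndarray


def make_graph(seed):
    # Hypothetical toy graph: 6 nodes with 2 features each, 4 edges.
    return GraphsTuple(
        nodes=jax.random.normal(jax.random.PRNGKey(seed), (6, 2)),
        edges=jnp.ones((4, 1)),
        receivers=jnp.array([1, 2, 3, 4]),
        senders=jnp.array([0, 1, 2, 3]),
        globals=jnp.zeros((1, 1)),
        n_node=jnp.array([6]),
        n_edge=jnp.array([4]),
    )


def loss_fn(input_graph, target_graph):
    # Hypothetical per-window loss on node features.
    return jnp.mean((input_graph.nodes - target_graph.nodes) ** 2)


# Per-field spec: map axis 0 of the array fields, broadcast n_node / n_edge.
spec = GraphsTuple(nodes=0, edges=0, receivers=0, senders=0, globals=0,
                   n_node=None, n_edge=None)

inputs = [make_graph(0), make_graph(1)]  # batch_size = 2
targets = [make_graph(2), make_graph(3)]

# Lists of GraphsTuples: a GraphsTuple spec is not a tree prefix of a list.
try:
    jax.vmap(loss_fn, in_axes=(spec, spec))(inputs, targets)
    mismatch_raised = False
except ValueError:
    mismatch_raised = True
print("ValueError raised:", mismatch_raised)


def stack_batched(graphs):
    # Stack every field along a new leading axis, then restore the unbatched
    # n_node / n_edge so they line up with the None entries in spec.
    stacked = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *graphs)
    return stacked._replace(n_node=graphs[0].n_node, n_edge=graphs[0].n_edge)


losses = jax.vmap(loss_fn, in_axes=(spec, spec))(
    stack_batched(inputs), stack_batched(targets))
print(losses.shape)  # (2,)
```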
Is there a way to either use in_axes or reformat my batch data structure so that I can use jax.vmap() to batch over the windows? I'd also appreciate any other tips for batching/using vmap :)

For more context, here are my jax/jraph versions:
Thanks,
Mia
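On the request for other batching tips, one common pattern, sketched here under assumptions, is to stack all forecast windows once, reshape every leaf to (num_batches, batch_size, ...), and run a jax.jit-compiled vmap over one batch at a time. The windows dict, split_into_batches, and window_loss are hypothetical; a plain dict stands in for the data, but the same tree_map calls work on a stacked GraphsTuple. Windows that don't fill a final batch are dropped here, which is a design choice rather than a requirement. (jraph also provides jraph.batch, which concatenates graphs into one larger graph instead of adding a batch axis; that is a different approach from vmap.)

```python
import jax
import jax.numpy as jnp


def split_into_batches(stacked, batch_size):
    # stacked: any pytree whose leaves share a leading "window" axis.
    num_windows = jax.tree_util.tree_leaves(stacked)[0].shape[0]
    num_batches = num_windows // batch_size
    keep = num_batches * batch_size  # drop the incomplete final batch, if any
    return jax.tree_util.tree_map(
        lambda x: x[:keep].reshape((num_batches, batch_size) + x.shape[1:]),
        stacked,
    )


def window_loss(window):
    # Hypothetical per-window loss: squared error between two node arrays.
    return jnp.mean((window["inputs"] - window["targets"]) ** 2)


# jit the vmapped function once; it is reused for every batch of the same shape.
batch_loss = jax.jit(jax.vmap(window_loss))

key_in, key_tgt = jax.random.split(jax.random.PRNGKey(0))
windows = {  # 10 windows, each with 6 nodes x 2 features
    "inputs": jax.random.normal(key_in, (10, 6, 2)),
    "targets": jax.random.normal(key_tgt, (10, 6, 2)),
}

batches = split_into_batches(windows, batch_size=4)  # 2 full batches, 2 windows dropped
num_batches = batches["inputs"].shape[0]
for i in range(num_batches):
    batch = jax.tree_util.tree_map(lambda x: x[i], batches)
    losses = batch_loss(batch)  # one loss per window in the batch
    print(i, losses.shape)  # (4,)
```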