Replies: 1 comment
- You might try using |
Suppose you have a dataset of particle configurations, where each configuration is an array of positions of shape `(n, d)`, where `n ∈ {0, ..., n_max}` is finite and `d` is constant. Consider a single-sample loss function `loss_sample` that computes a scalar loss for each configuration. Since `n` varies per example, we cannot simply `vmap` `loss_sample` over the whole batch of configurations.

But since `n_max` is finite, and most likely much smaller than the number of examples in the dataset (`n_max << n_configs`), it might be reasonable to `jit` the `loss_sample` function for each example (with JAX maintaining a cache of compilations), so that `loss_sample` is compiled at most `n_max` times. How can we then create a `loss_batch` that efficiently maps these `jit`-ed `loss_sample` functions over a batch of arrays with different `n`? If this can be done efficiently, would it be a viable alternative to padding/masking? See below.
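A minimal sketch of the per-shape-compilation idea, assuming a toy pairwise-distance loss as `loss_sample` (any scalar-valued function of an `(n, d)` array would behave the same) and a plain Python loop as `loss_batch`; because `jit`'s cache is keyed on input shapes, each distinct `n` is compiled only once:

```python
import jax
import jax.numpy as jnp

# Hypothetical single-configuration loss: sum of squared pairwise distances.
# Any scalar-valued function of an (n, d) array would work the same way.
@jax.jit
def loss_sample(positions):
    diffs = positions[:, None, :] - positions[None, :, :]
    return jnp.sum(diffs ** 2)

def loss_batch(configurations):
    # `configurations` is a Python list of (n_i, d) arrays with varying n_i.
    # Each new n_i triggers one compilation of `loss_sample`; after that,
    # jit's shape-keyed cache reuses the compiled executable, so there are
    # at most n_max compiles no matter how many configurations we process.
    return jnp.stack([loss_sample(x) for x in configurations])

# Example: three configurations with n = 2, 3, 2 particles in d = 3 dimensions.
key = jax.random.PRNGKey(0)
configs = [jax.random.normal(k, (n, 3))
           for k, n in zip(jax.random.split(key, 3), (2, 3, 2))]
print(loss_batch(configs))  # the n = 2 shape is compiled once and reused
```

Each call in the loop is dispatched separately from Python, so whether this beats padding + masking + a single `vmap` likely depends on the batch size and on how many distinct values of `n` actually occur; bucketing configurations by `n` and `vmap`-ping within each bucket is another option worth comparing.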