You'll probably find the second approach more efficient, because it avoids mapped indexing operations over the arrays.
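The two approaches can be sketched as follows. This is a minimal illustration under hypothetical assumptions: a 4-D grid of elements with the 17 properties stored on a trailing axis, and made-up property indices `MASS` and `VEL_X`. Both versions compute the same values. The first passes each element's full property vector into the mapped function and indexes inside it, so `vmap` has to batch those indexing operations. The second slices each property once, outside the mapped function.

```python
import jax
import jax.numpy as jnp

# Hypothetical layout: a 4-D grid of elements, with the 17 physical
# properties stored along a trailing axis.
GRID = (2, 3, 4, 5)
N_PROPS = 17
# Hypothetical property indices, chosen only for illustration.
MASS, VEL_X = 0, 1

def vmap4(f):
    """Map a per-element function over all four grid axes."""
    return jax.vmap(jax.vmap(jax.vmap(jax.vmap(f))))

# Approach 1: pass each element's full property vector and index inside.
def momentum_from_cell(cell):
    # cell has shape (N_PROPS,) once all four vmaps are applied
    return cell[MASS] * cell[VEL_X]

# Approach 2: extract the needed properties first, then map over them.
def momentum_from_props(mass, vel_x):
    # mass and vel_x are scalars once all four vmaps are applied
    return mass * vel_x

state = jax.random.uniform(jax.random.PRNGKey(0), GRID + (N_PROPS,))

p1 = jax.jit(vmap4(momentum_from_cell))(state)
p2 = jax.jit(
    lambda s: vmap4(momentum_from_props)(s[..., MASS], s[..., VEL_X])
)(state)
```

Whether the difference is measurable depends on what XLA manages to fuse under `jax.jit`. Timing both versions on the real shapes, calling `.block_until_ready()` on the results, is the reliable check.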
Sup Jax,
I have an interesting use case: I have a struct that contains a 4-dimensional tensor with millions of elements, and the tensor holds 17 physical properties for each element.
A function is applied to each of those physical properties, or to a set of them, for example `gravity()`, `computeVelocity()`, `computeStrain()`, etc.
In JAX, would it be more efficient to pass references to an entire axis of the tensor and fetch the data inside the functions, or to fetch each property outside the functions (at the call site)?
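I use a 4-dimensional `vmap` for the computation, like `vmap(vmap(vmap(vmap(fun))))(a, b, c)`. To clarify: should I pass the whole tensor and index the properties inside `fun`, or extract the properties first and apply the `vmap` on them?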
How is this actually managed in JAX? Would there be any difference?
Thanks.
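On how this is managed: `vmap` traces the function once and replaces each primitive with its batched version, rather than looping over elements. A minimal sketch, using a hypothetical three-argument `fun` and a hypothetical 4-D grid shape, that prints the traced program with `jax.make_jaxpr`:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-element function of three scalar inputs, standing in for
# the `fun` in the question.
def fun(a, b, c):
    return a * b + c

# Four nested vmaps: each level maps over one leading axis.
fun4 = jax.vmap(jax.vmap(jax.vmap(jax.vmap(fun))))

shape = (2, 3, 4, 5)  # hypothetical 4-D grid
a = jnp.ones(shape)
b = jnp.full(shape, 2.0)
c = jnp.arange(120.0).reshape(shape)

out = fun4(a, b, c)

# Inspect the traced program: vmap rewrites `fun` into whole-array
# operations instead of generating a loop over elements.
jaxpr = jax.make_jaxpr(fun4)(a, b, c)
print(jaxpr)
```

The printed jaxpr should show batched `mul` and `add` operations over the full arrays rather than a loop, which is also why the nested `vmap` here gives the same result as calling `fun` on the whole arrays directly. Printing the jaxpr of each candidate layout is a quick way to see what extra indexing work, if any, one version introduces.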