I am currently trying to understand how `dot_general` works. I noticed something that strikes me as a bit strange: I would have expected the result to have a shape of (2, 5, 7, 3), i.e. with the final batch dimension still in the same place. From a performance standpoint it may well make sense to prepend all the batch dimensions, but surely JAX could just put the dimension "back" where it belongs after the computation. Is there a reason why it works this way?
JAX follows StableHLO/OpenXLA semantics for `dot_general`, which place all batch dimensions at the front of the output shape. As for why it does this: it's a straightforward convention. One can imagine other straightforward conventions, but this is the one that was chosen by the designers of XLA. I suspect the reason is that in ML computations over C-ordered arrays, batch dimensions are typically laid out in the leading dimension.
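A minimal sketch of the behaviour (the shapes below are assumed for illustration, since the original snippet isn't shown in the question): the batch dimension ends up at the front of the output no matter where it sits in the inputs, and a single `moveaxis` puts it back where the question expects it.

```python
import jax.numpy as jnp
from jax import lax

# Assumed shapes, chosen to match the (2, 5, 7, 3) expectation in the question.
lhs = jnp.ones((2, 5, 7, 4))   # batch dim of size 7 at axis 2, contracting dim of size 4 at axis 3
rhs = jnp.ones((7, 4, 3))      # batch dim of size 7 at axis 0, contracting dim of size 4 at axis 1

# dimension_numbers = ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch))
out = lax.dot_general(lhs, rhs, dimension_numbers=(((3,), (1,)), ((2,), (0,))))
print(out.shape)  # (7, 2, 5, 3): batch dims first, then lhs free dims, then rhs free dims

# To get the batch dimension back where it was in lhs, transpose explicitly:
print(jnp.moveaxis(out, 0, 2).shape)  # (2, 5, 7, 3)
```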