Batched jax.lax.gather #18198
Unanswered
andreasveit asked this question in Q&A
I have a tensor of batched embeddings of shape

embs.shape = [batch_size, num_embeddings_per_example, embedding_dim]

and a tensor of batched indices into this embedding tensor

embs_indices.shape = [batch_size, num_indices_per_example]

where num_indices_per_example <= num_embeddings_per_example. I would like to gather the embeddings from embs into a new tensor of shape

output.shape = [batch_size, num_indices_per_example, embedding_dim]

similar to the PyTorch operation embs.gather(dim=-2, index=embs_indices). I have looked into jax.lax.gather, but have not found a way to specify batch dimensions. For a single example I was able to get this to work and obtain the expected result. Any recommendation for how to add a leading batch dimension to this?
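For concreteness, here is a minimal sketch of the operation being asked for, with made-up toy shapes and values; jnp.take_along_axis is one standard way to express this kind of batched gather in JAX, and is not part of the original question:

```python
import jax.numpy as jnp

# Toy shapes for illustration (hypothetical): batch_size=2,
# num_embeddings_per_example=5, embedding_dim=4, num_indices_per_example=3.
embs = jnp.arange(2 * 5 * 4, dtype=jnp.float32).reshape(2, 5, 4)
embs_indices = jnp.array([[0, 2, 4],
                          [1, 1, 3]])  # [batch_size, num_indices_per_example]

# Expand the indices to [batch_size, num_indices_per_example, 1] so they
# broadcast against embedding_dim, then gather along the embeddings axis (-2).
output = jnp.take_along_axis(embs, embs_indices[..., None], axis=-2)
print(output.shape)  # (2, 3, 4)
```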
Replies: 1 comment

In general, I'd avoid calling jax.lax.gather directly: NumPy-style indexing is much easier to use, and it lowers to the same gather operation. In your case, it would look like this:

```python
result = embs[embs_indices]
```

Regarding how to "add a leading batch dimension to this" – I'm not entirely clear on what you want to do. But if you're asking about adding a batch dimension to both the embeddings and the indices, you could construct batched inputs and map the indexing over them with jax.vmap:

```python
embs_batched = jnp.stack([embs + i for i in range(4)])
embs_indices_batched = jnp.stack([embs_indices + i for i in range(4)])
batched_result = jax.vmap(lambda a, b: a[b])(embs_batched, embs_indices_batched)
```

The vmapped indexing handles the leading batch dimension automatically.
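Applied to the shapes from the question, the same pattern looks like this; a minimal self-contained sketch with made-up array contents, in which the per-example indexing e[idx] is mapped over the leading batch axis:

```python
import jax
import jax.numpy as jnp

batch_size, num_embeddings, embedding_dim = 2, 5, 4

embs = jnp.ones((batch_size, num_embeddings, embedding_dim))
embs_indices = jnp.array([[0, 2, 4],
                          [1, 1, 3]])  # [batch_size, num_indices_per_example]

# vmap maps the single-example gather over the leading batch dimension,
# so each example's rows are selected by its own index vector.
gather_one = lambda e, idx: e[idx]  # [M, D], [N] -> [N, D]
output = jax.vmap(gather_one)(embs, embs_indices)
print(output.shape)  # (2, 3, 4)
```

Under jit, this vmapped indexing compiles to a single batched gather rather than a per-example Python loop.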