
Is it possible to code a tensordot function that has dynamic axes? #26270

Answered by jakevdp
erick-xanadu asked this question in Q&A

I don't know of any built-in operation that allows tensor dot products with dynamic axes. Keep in mind that because of JAX's static shape requirements, this would only be possible in cases where the axes being selected all have the same size, so that the output will have the same shape regardless of the dynamic value.
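
To make the shape constraint concrete: contracting different axes of a non-square array produces outputs of different shapes, and a runtime-selected branch cannot do that. Below is a small illustration using jnp.tensordot. The arrays are arbitrary examples chosen for this sketch, not taken from the question.

```python
import jax.numpy as jnp

# A non-square (2, 3) array: its two axes have different lengths.
x = jnp.zeros((2, 3))

# Contracting axis 0 needs a length-2 vector and leaves shape (3,).
out0 = jnp.tensordot(x, jnp.ones(2), axes=([0], [0]))
# Contracting axis 1 needs a length-3 vector and leaves shape (2,).
out1 = jnp.tensordot(x, jnp.ones(3), axes=([1], [0]))
print(out0.shape, out1.shape)  # (3,) (2,)

# With a square (3, 3) array, both contractions give shape (3,), so the
# choice of axis can be deferred to runtime without changing the shape.
sq = jnp.zeros((3, 3))
v = jnp.ones(3)
print(jnp.tensordot(sq, v, axes=([0], [0])).shape,
      jnp.tensordot(sq, v, axes=([1], [0])).shape)  # (3,) (3,)
```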

Your best bet is probably to do it manually using jax.lax.cond or similar; for example:

import jax
import jax.numpy as jnp

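@jax.jit
def dynamic_dot(x, y, axis):
  # Taken when axis == 0: contract axis 0 of x with y.
  def f1(x, y):
    return x.T @ y
  # Taken otherwise: contract axis 1 of x with y.
  def f2(x, y):
    return x @ y
  # x is square, so both branches return the same shape and lax.cond
  # can choose between them at runtime based on the traced `axis`.
  return jax.lax.cond(axis == 0, f1, f2, x, y)

x = jnp.arange(9).reshape(3, 3)
y = jnp.ones(3)

print(dynamic_dot(x, y, axis=0))  # [ 9. 12. 15.]
print(dynamic_dot(x, y, axis=1))  # [ 3. 12. 21.]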
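
The same pattern can extend beyond two choices. The sketch below is not a built-in JAX API: dynamic_tensordot is a hypothetical helper that builds one jnp.tensordot branch per axis of x (each with a static axis) and lets jax.lax.switch pick one at runtime. It assumes every axis of x has the same length, per the shape constraint described earlier, and checks that at trace time.

```python
import jax
import jax.numpy as jnp


@jax.jit
def dynamic_tensordot(x, y, axis):
    """Contract axis `axis` of `x` (may be traced) with axis 0 of `y`.

    Hypothetical helper: assumes all axes of `x` have the same length,
    so that every branch returns an array of the same shape.
    """
    # Shapes are static under jit, so this check runs at trace time.
    if len(set(x.shape)) != 1 or y.shape[0] != x.shape[0]:
        raise ValueError("all axes of x and axis 0 of y must have equal length")
    # One branch per axis of x; `i=i` binds the loop value to each lambda.
    branches = [
        lambda a, b, i=i: jnp.tensordot(a, b, axes=([i], [0]))
        for i in range(x.ndim)
    ]
    # lax.switch clamps `axis` into [0, x.ndim - 1] and runs that branch.
    return jax.lax.switch(axis, branches, x, y)


x = jnp.arange(27.0).reshape(3, 3, 3)
y = jnp.ones(3)
for axis in range(3):
    # Since y is all ones, this equals x.sum(axis=axis).
    print(dynamic_tensordot(x, y, axis).shape)  # (3, 3)
```

Like lax.cond, lax.switch traces every branch, so compile time grows with the number of candidate axes. Under jax.vmap with a batched index, both are lowered to a select that evaluates every branch.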

Answer selected by erick-xanadu