Is it possible to code a tensordot function that has dynamic axes? #26270
I found that `jax.lax.dot_general` also requires static dimension numbers, and I was wondering whether there is already an implementation of tensordot with dynamic axes. Thanks!
I don't know of any built-in operation that allows tensor dot products with dynamic axes. Keep in mind that because of JAX's static shape requirements, this would only be possible in cases where the axes being selected all have the same size, so that the output will have the same shape regardless of the dynamic value. Your best bet would probably be to do it manually using a `cond` or similar; for example:

```python
import jax
import jax.numpy as jnp

@jax.jit
def dynamic_dot(x, y, axis):
  # Contract axis 0 of x with y (equivalent to x.T @ y).
  def f1(x, y):
    return x.T @ y
  # Contract axis 1 of x with y (equivalent to x @ y).
  def f2(x, y):
    return x @ y
  # `axis` is traced under jit, so the branch is selected with lax.cond.
  return jax.lax.cond(axis == 0, f1, f2, x, y)

x = jnp.arange(9).reshape(3, 3)
y = jnp.ones(3)

print(dynamic_dot(x, y, axis=0))  # [ 9. 12. 15.]
print(dynamic_dot(x, y, axis=1))  # [ 3. 12. 21.]
```
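The `cond` approach above switches between two axes. As a sketch of how it might extend to an array with more dimensions, assuming `jax.lax.switch` in place of `cond` and a hypothetical helper `make_dynamic_tensordot` (not part of JAX), one branch can be built per candidate axis, each passing a *static* `axes` argument to `jnp.tensordot`:

```python
import jax
import jax.numpy as jnp


def _contract_axis(i):
  # Branch that contracts the static axis `i` of x with the only axis of y.
  return lambda x, y: jnp.tensordot(x, y, axes=([i], [0]))


def make_dynamic_tensordot(ndim):
  """Hypothetical helper: returns a jitted function that contracts a
  runtime-chosen axis of an `ndim`-dimensional x against a 1-D y."""
  branches = [_contract_axis(i) for i in range(ndim)]

  @jax.jit
  def dynamic_tensordot(x, y, axis):
    # `axis` is traced, so it cannot be passed to tensordot directly.
    # lax.switch runs the branch selected by `axis` (clamped to
    # [0, ndim - 1]). Every branch must produce the same shape and dtype,
    # which holds only if all candidate axes of x have length len(y).
    return jax.lax.switch(axis, branches, x, y)

  return dynamic_tensordot


x = jnp.arange(27.0).reshape(3, 3, 3)
y = jnp.ones(3)
f = make_dynamic_tensordot(x.ndim)
for axis in range(x.ndim):
  # Contracting with a vector of ones is just a sum over the chosen axis.
  print(axis, jnp.allclose(f(x, y, axis), x.sum(axis=axis)))
```

Because `axis` is a traced argument rather than a static one, calling `f` with a different axis value should reuse the same compiled function. Note that every branch is traced, so each one must be shape-valid for the given inputs, which is the same equal-size constraint described above.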