Since when is this form of dynamic indexing possible under jit? #15223
-
Out of curiosity: Has this always been possible? I thought this is where

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(idxs, data):
    res = 0
    for i in idxs:
        res += data[i]
    return res

f(jnp.array([0, 1, 0]), jnp.array([1, 3]))
```
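A minimal sketch, assuming only the public `jax.jit` and `jax.numpy` APIs: because `idxs` has a static shape of `(3,)`, the Python `for` loop above is unrolled at trace time into three indexing operations with traced indices. The result should therefore match a single vectorized gather, `data[idxs].sum()`. The name `f_vectorized` is hypothetical and only introduced for comparison.

```python
import jax
import jax.numpy as jnp


@jax.jit
def f(idxs, data):
    # Iterating over a traced array is unrolled at trace time, since
    # idxs.shape is known statically; each data[i] uses a traced index.
    res = 0
    for i in idxs:
        res += data[i]
    return res


@jax.jit
def f_vectorized(idxs, data):
    # Hypothetical comparison: one gather over all indices, then a sum.
    return data[idxs].sum()


idxs = jnp.array([0, 1, 0])
data = jnp.array([1, 3])
print(f(idxs, data))             # data[0] + data[1] + data[0] = 1 + 3 + 1
print(f_vectorized(idxs, data))
```

The unrolled version emits one indexing op per element of `idxs`, so for long index arrays the vectorized form keeps the traced program small.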
Replies: 3 comments 1 reply
-
I think yes, as the shape of
-
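I feel like you can create some surprising behaviour, imo:

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(parent_array):
    # now it works: ys is a JAX array, so ys[parent] accepts a traced index
    ys = jnp.array([1, 2, 3])
    # now it doesn't: a plain Python list needs a concrete int index
    # (uncomment to reproduce the failure)
    # ys = [1, 2, 3]
    res = 0
    for i in range(len(parent_array)):
        parent = parent_array[i]
        res += ys[parent]
    return res

f(jnp.array([0, 1, 2]))
```

I somehow expected this to either not work or work both times.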
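A sketch of why the two variants differ, assuming only `jax.jit`, `jax.numpy` and `jax.errors`: indexing a JAX array goes through JAX's own `__getitem__`, which handles a traced integer, while indexing a Python list calls the tracer's `__index__`, which needs a concrete value and raises `TracerIntegerConversionError`. The helper `make_f` is hypothetical and only exists to run both variants side by side.

```python
import jax
import jax.numpy as jnp


def make_f(ys):
    # Hypothetical helper: rebuilds the jitted function from the comment
    # with a caller-chosen container for ys (closed over as a constant).
    @jax.jit
    def f(parent_array):
        res = 0
        for i in range(len(parent_array)):
            parent = parent_array[i]  # i is a Python int, parent is traced
            res += ys[parent]         # indexing ys with a tracer
        return res
    return f


parents = jnp.array([0, 1, 2])

# JAX array: __getitem__ handles a traced integer index.
print(make_f(jnp.array([1, 2, 3]))(parents))  # 1 + 2 + 3

# Python list: list.__getitem__ calls parent.__index__(), which needs a
# concrete Python int, so tracing raises an error.
try:
    make_f([1, 2, 3])(parents)
except jax.errors.TracerIntegerConversionError as err:
    print("list indexing failed:", type(err).__name__)
```

Converting the list with `jnp.asarray(ys)` before indexing gives the working behaviour in both cases.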
-
I believe this kind of dynamic indexing (`x[i]` where `x` and `i` are traced, and `i` is integer-typed) has always been supported in JIT, but dynamic iteration (`for i in idxs` where `idxs` is traced) has only been possible since #8043.
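A small sketch of the distinction above, assuming nothing beyond `jax.jit` and `jax.numpy`: indexing with a traced `i` reuses the same compiled function for any value of `i`, while iterating over a traced `idxs` relies on its static length, so a new length triggers a fresh trace. The functions `index_once` and `sum_at` and the counter `trace_count` are hypothetical; the counter only increments while JAX traces the function, because Python side effects do not run on cached calls.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter; Python code runs only during tracing


@jax.jit
def index_once(x, i):
    # Dynamic indexing: i is a traced integer scalar, so x[i] is resolved
    # at run time inside the compiled program.
    return x[i]


@jax.jit
def sum_at(idxs, data):
    global trace_count
    trace_count += 1
    res = 0
    # Dynamic iteration: unrolled using the static length of idxs.
    for i in idxs:
        res += data[i]
    return res


x = jnp.array([10, 20, 30])
print(index_once(x, 2))  # same compiled code for any integer i

data = jnp.array([1, 3])
sum_at(jnp.array([0, 1, 0]), data)
sum_at(jnp.array([1, 1, 1]), data)  # same shape and dtype: cached, no retrace
sum_at(jnp.array([0, 1]), data)     # new length: traced again
print(trace_count)
```

Only the change in `idxs.shape` causes a retrace here; changing the values inside `idxs` or `i` does not.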