I am well aware that JAX and XLA do not support dynamically-shaped arrays. However, it's still not obvious to me whether that means all types of dynamic slicing are forbidden. In particular, I am trying to populate a statically-shaped array with values copied from another statically-shaped array, but where the indices of the source slice are dynamic. For example:

```python
import numpy as np
import jax.numpy as jnp
from jax import jit
from functools import partial
from jax import Array

def extract_slice(a: Array, indices: Array, row: int, max_size: int):
    x = np.zeros(max_size)
    start = indices[row]
    stop = indices[row+1]
    count = stop - start
    x[0:count] = a[start:stop]
    return x

extract_slice_jit = partial(jit, static_argnums={3})(extract_slice)

arr = np.arange(20)
ind = np.array([0, 5, 9, 12, 17, 19, 20])
print(extract_slice(arr, ind, 2, max_size=5))
# output: [ 9. 10. 11.  0.  0.]
print(extract_slice_jit(arr, ind, 2, max_size=5))
# raises IndexError
```

I could solve this with a loop, but that seems inelegant:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import jit
from functools import partial
from jax import Array

@partial(jit, static_argnums={3})
def extract_slice_jit(a: Array, indices: Array, row: int, max_size: int):
    x = jnp.zeros(max_size)
    start = indices[row]
    stop = indices[row+1]
    count = stop - start
    # copy one element per iteration; count is traced, so the trip count is dynamic
    x = jax.lax.fori_loop(0, count, lambda i, x: x.at[i].set(a[start+i]), x)
    return x

arr = jnp.arange(20)
ind = jnp.array([0, 5, 9, 12, 17, 19, 20])
print(extract_slice_jit(arr, ind, 2, max_size=5))
# output: [ 9. 10. 11.  0.  0.]
```

Is this really the best way to handle this kind of assignment?
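For reference, not every dynamic slice is ruled out: a traced *start* index combined with a static *length* can be expressed under `jit` with `jax.lax.dynamic_slice`, because the output shape depends only on the length. The minimal sketch below assumes a fixed window length of 5; the `window` helper is a hypothetical name used only for illustration.

```python
import jax
import jax.numpy as jnp


@jax.jit
def window(a, start):
    # `start` is traced; the slice size (5,) is a static Python tuple,
    # so the output shape is known at trace time.
    return jax.lax.dynamic_slice(a, (start,), (5,))


arr = jnp.arange(20)
print(window(arr, 9))   # [ 9 10 11 12 13]
print(window(arr, 18))  # start is clamped to 15: [15 16 17 18 19]
```

Note the second call: when the requested window would run past the end of the array, `dynamic_slice` clamps the start index instead of raising or padding, so it cannot by itself express a slice whose length varies from row to row.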
You can't do this operation via standard indexing assignments, because `a[start:stop]` would be an intermediate array whose length depends on traced values, i.e. a dynamically-shaped array. But you can do this operation using other approaches; for example, this should work:

```python
import jax.numpy as jnp
from jax import Array

def extract_slice(a: Array, indices: Array, row: int, max_size: int):
    x = jnp.zeros(max_size)
    start = indices[row]
    stop = indices[row+1]
    count = stop - start
    a = jnp.roll(a, -start)[:max_size]  # move relevant entries to front
    # keep the first `count` entries, zeros everywhere else
    return jnp.where(jnp.arange(max_size) < count, a, x)
```