I am trying to optimize a piece of code by JIT-compiling it. The code is as follows:
What I am trying to achieve is to select elements from x or y based on the mask. If x were broadcastable to the shape of mask, I could have used
It looks like what you're after are the exact semantics of `jax.numpy.place`:

```python
import jax.numpy as jnp

# inplace=False is required because JAX arrays are immutable.
output = jnp.place(
    jnp.full(len(mask), y[0]),
    jnp.array(mask, dtype=bool),
    jnp.array(x), inplace=False)
```

This was only recently implemented in JAX (it will be part of the 0.4.15 release), but until that's available, the implementation is pretty simple: https://github.com/google/jax/blob/d7940ee9a11d27fa5dc745e23e6f82aea25ac525/jax/_src/numpy/lax_numpy.py#L4985-L4987

For your particular case, you could use something like this:

```python
import jax.numpy as jnp

mask = jnp.array(mask, dtype=bool)
x = jnp.array(x)
# Pad x with zeros up to len(mask) so the scatter below has a static size.
updates = jnp.zeros_like(x, shape=len(mask)).at[:len(x)].set(x)
# Indices of the True entries, padded with the out-of-bounds index mask.size.
indices = jnp.where(mask, size=mask.size, fill_value=mask.size)[0]
# Scatter x into those positions; mode='drop' discards the padded indices.
output = jnp.full(len(mask), y[0]).at[indices].set(updates, mode='drop')
```
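To check that the workaround actually compiles under `jax.jit`, here is a minimal sketch that wraps it in a jitted function and compares the result with `numpy.place` on a small made-up input. The function name `select_masked` and the sample arrays are hypothetical; like the snippet, it assumes every unselected position takes the value `y[0]` and that `len(x)` does not exceed `len(mask)`.

```python
import numpy as np
import jax
import jax.numpy as jnp


# Hypothetical wrapper around the workaround; all shapes are static under jit.
@jax.jit
def select_masked(mask, x, y):
    mask = jnp.asarray(mask, dtype=bool)
    x = jnp.asarray(x)
    # Pad x to len(mask) so the update array has a fixed size.
    updates = jnp.zeros_like(x, shape=len(mask)).at[:len(x)].set(x)
    # Positions of True entries, padded with the out-of-bounds index mask.size.
    indices = jnp.where(mask, size=mask.size, fill_value=mask.size)[0]
    # Padded (out-of-bounds) indices are silently dropped by mode='drop'.
    return jnp.full(len(mask), y[0]).at[indices].set(updates, mode='drop')


# Made-up sample inputs for illustration.
mask = np.array([True, False, True, True, False])
x = np.array([10.0, 20.0, 30.0])
y = np.full(5, -1.0)

out = select_masked(mask, x, y)

# Reference result computed with NumPy's in-place np.place.
expected = np.full(len(mask), y[0])
np.place(expected, mask, x)
```

This works because `len(mask)` and `mask.size` come from static shapes rather than traced values, so every intermediate array has a size known at trace time. Boolean indexing such as `out[mask]`, by contrast, produces a data-dependent shape and fails under `jit`.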
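Once JAX 0.4.15 or later is available, the built-in `jnp.place` can be called inside a jitted function directly. The sketch below assumes such a version is installed; `place_with_fill` and the sample inputs are hypothetical illustrations, and the result is checked against `numpy.place`.

```python
import numpy as np
import jax
import jax.numpy as jnp


# Hypothetical helper; assumes JAX >= 0.4.15, where jnp.place is available.
@jax.jit
def place_with_fill(mask, x, fill):
    # Array with the same shape as mask, filled with the scalar `fill`.
    base = jnp.full(mask.shape, fill, dtype=x.dtype)
    # inplace=False is required: JAX arrays are immutable, so a new array is returned.
    return jnp.place(base, mask, x, inplace=False)


# Made-up sample inputs for illustration.
mask = np.array([True, False, True, True, False])
x = np.array([10.0, 20.0, 30.0], dtype=np.float32)

out = place_with_fill(mask, x, -1.0)

# Reference result computed with NumPy's in-place np.place.
expected = np.full(mask.shape, -1.0, dtype=np.float32)
np.place(expected, mask, x)
```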