
Equivalent of jnp.where for arguments of different lengths #16962

Answered by jakevdp
SNMS95 asked this question in Q&A

It looks like what you're after are the exact semantics of jax.numpy.place:

import jax.numpy as jnp
# inplace=False is required: JAX arrays are immutable, so jnp.place returns a new array.
output = jnp.place(
    jnp.full(len(mask), y[0]),
    jnp.array(mask, dtype=bool),
    jnp.array(x), inplace=False)
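To make the semantics concrete, here is a small run with made-up inputs. The question's actual mask, x and y are not shown in this thread, so these values are hypothetical, and the snippet assumes a JAX version that already ships jnp.place (0.4.15 or later).

```python
import jax.numpy as jnp

# Hypothetical example inputs, not taken from the original question.
mask = [True, False, True, False, True]
x = [10.0, 20.0, 30.0]  # one value per True entry in mask
y = [-1.0]              # y[0] is used as the fill value

# Start from an array full of y[0], then write x into the True positions in order.
output = jnp.place(
    jnp.full(len(mask), y[0]),
    jnp.array(mask, dtype=bool),
    jnp.array(x), inplace=False)
# expected: [10., -1., 20., -1., 30.]
```

Positions where the mask is True receive the values of x in order; every other position keeps the fill value y[0].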

This was only recently implemented in JAX (it will be part of the 0.4.15 release), but until that release is available, the implementation is simple enough to copy: https://github.com/google/jax/blob/d7940ee9a11d27fa5dc745e23e6f82aea25ac525/jax/_src/numpy/lax_numpy.py#L4985-L4987
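A fallback for older JAX versions can be built from jnp.where(..., size=...) and .at[...].set(..., mode="drop"). The sketch below is an approximation written for illustration, not a copy of the linked JAX code (which may differ in details). The name place_fallback is hypothetical, and the sketch assumes arr and mask have the same shape and vals is non-empty.

```python
import jax
import jax.numpy as jnp


def place_fallback(arr, mask, vals):
    """Out-of-place analogue of numpy.place (hypothetical helper, not a JAX API).

    Assumes arr and mask have the same shape and vals is non-empty.
    """
    arr = jnp.asarray(arr)
    mask = jnp.asarray(mask, dtype=bool).ravel()
    vals = jnp.ravel(jnp.asarray(vals))
    # Indices of the True entries, padded to a static length with the
    # out-of-bounds index mask.size so the shape is known under jit.
    indices = jnp.where(mask, size=mask.size, fill_value=mask.size)[0]
    # numpy.place cycles through vals, so repeat them to cover every index.
    reps = -(-indices.size // vals.size)  # ceiling division
    vals = jnp.tile(vals, reps)[:indices.size]
    # mode="drop" discards the updates aimed at the out-of-bounds padding.
    return arr.ravel().at[indices].set(vals, mode="drop").reshape(arr.shape)


# Usage: every shape above is static, so the function also works under jit.
out = jax.jit(place_fallback)(jnp.full(5, -1.0),
                              jnp.array([True, False, True, False, True]),
                              jnp.array([10.0, 20.0, 30.0]))
```

Padding the index list to a fixed length is what keeps the output shape independent of how many entries of the mask are True, which is the constraint that makes a plain boolean-mask assignment impossible under jit.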

For your particular case, you could use something like this:

mask = jnp.array(mask, dtype=bool)
x = jnp.array(x)
updates = jnp.zeros_like(x, shape=len(mask)).at[:len(x)].set(x)
indices = jnp.where(mask, size=mask.size, f…
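A different technique for the same particular case avoids the scatter entirely: a cumulative-sum gather, swapped in here for illustration rather than taken from the answer. The helper name fill_masked is hypothetical. Unlike place, it does not cycle through x; it assumes x is non-empty and has at least as many elements as mask has True entries.

```python
import jax
import jax.numpy as jnp


@jax.jit
def fill_masked(mask, x, fill):
    """Successive values of x where mask is True, `fill` elsewhere.

    Hypothetical helper (not a JAX API). Assumes x is non-empty and has at
    least as many elements as mask has True entries.
    """
    mask = jnp.asarray(mask, dtype=bool)
    x = jnp.asarray(x)
    # Running count of True entries, minus one: at each True position this
    # is the index into x of the value that belongs there.
    idx = jnp.cumsum(mask.astype(jnp.int32)) - 1
    # Positions before the first True get -1; clipping keeps every index
    # valid (those positions are replaced by `fill` below anyway).
    idx = jnp.clip(idx, 0, x.size - 1)
    return jnp.where(mask, x[idx], fill)


out = fill_masked(jnp.array([True, False, True, False, True]),
                  jnp.array([10.0, 20.0, 30.0]),
                  -1.0)
```

Because the output is produced by an elementwise jnp.where over a gather, its shape is simply mask.shape, so the function traces cleanly under jit.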

Replies: 1 comment 1 reply

1 reply
@SNMS95
Answer selected by SNMS95
Category
Q&A
2 participants