
Add a binary mask to jax.Array #18517

Answered by jakevdp
vcharraut asked this question in Q&A
Nov 14, 2023 · 3 comments · 2 replies

The problem is that your approach attempts to construct dynamically-sized arrays: mask_size depends on the contents of the traced array mask, and so it is a dynamic value. Because of this, it cannot be used in the static size argument of jnp.nonzero(mask, size=mask_size). For more on this, see JAX Sharp Bits: Dynamic Shapes.
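To make the static-size requirement concrete, here is a minimal sketch (the function name `padded_true_indices` and the sample mask are made up for illustration, not taken from the question). It passes a size that is known at trace time, the length of the mask, and pads the unused slots with an out-of-range fill value:

```python
import jax
import jax.numpy as jnp


@jax.jit
def padded_true_indices(mask: jax.Array) -> jax.Array:
    # mask.shape[0] is a plain Python int at trace time, so it is a valid
    # static `size`. A value such as mask.sum() would be a traced value
    # and would be rejected here.
    n = mask.shape[0]
    # Slots beyond the number of True entries are filled with n, an index
    # that is deliberately out of range for the mask.
    (indices,) = jnp.nonzero(mask, size=n, fill_value=n)
    return indices


mask = jnp.array([True, False, True, True, False])
indices = padded_true_indices(mask)
print(indices)  # [0 2 3 5 5]
```

The output always has the same shape as the mask, whatever the mask contains, which is what lets it compile under jit.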

So what you need to do here is express the update you have in mind without constructing any dynamically-shaped arrays. Here's an example of the kind of approach you might use:

@jax.jit
def insert_in_replay_state(buffer_state: ReplayBufferState, samples: jax.Array, mask: jax.Array) -> ReplayBufferState:
    # Padded indices of the mask elements
    samples_size =
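The snippet above is cut off, so what follows is not a reconstruction of it. It is a separate minimal sketch of one way to do a masked insert with only fixed-shape arrays. `BufferState` is a hypothetical, simplified stand-in for `ReplayBufferState` (whose fields are not shown here), and `masked_insert` is an illustrative name. The idea: compute a destination row for every sample with a cumulative sum over the mask, send the masked-out samples to an out-of-bounds index, and let the scatter drop them. It assumes the number of kept samples does not exceed the buffer capacity.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class BufferState(NamedTuple):
    # Hypothetical stand-in for ReplayBufferState, for illustration only.
    data: jax.Array             # shape (capacity, feature_dim)
    insert_position: jax.Array  # scalar int32, next row to write


@jax.jit
def masked_insert(state: BufferState, samples: jax.Array, mask: jax.Array) -> BufferState:
    capacity = state.data.shape[0]
    # Rank of each sample among the kept ones: 0, 1, 2, ...
    # Values at masked-out positions are meaningless and are overridden below.
    offsets = jnp.cumsum(mask) - 1
    dest = (state.insert_position + offsets) % capacity
    # Route masked-out samples to an out-of-bounds row index.
    dest = jnp.where(mask, dest, capacity)
    # mode="drop" discards updates whose index is out of bounds.
    data = state.data.at[dest].set(samples, mode="drop")
    new_position = (state.insert_position + jnp.sum(mask)) % capacity
    return BufferState(data=data, insert_position=new_position)


state = BufferState(data=jnp.zeros((6, 2)), insert_position=jnp.array(4, dtype=jnp.int32))
samples = jnp.arange(10.0).reshape(5, 2)
mask = jnp.array([True, False, True, True, False])
new_state = masked_insert(state, samples, mask)
```

In this example the three kept samples land in rows 4, 5 and 0 (wrapping around the end of the buffer) and the insert position advances to 1. Every intermediate array has the same shape as `mask` or `samples`, so nothing depends on how many entries of the mask are True.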


Answer selected by vcharraut