-
Hello!

Description

I'm not sure whether this question has already been answered, but I could not find an answer to my issue. Basically, I am trying to apply a mask to an array to filter values. I know masking itself is already possible. What I actually want is to add transitions to a replay buffer in reinforcement learning, but some of my transitions are invalid, and the invalid ones are marked by a boolean mask vector.

Code

    import jax
    import jax.numpy as jnp
    from functools import partial
    import flax

    num_envs = 3


    @flax.struct.dataclass
    class ReplayBufferState:
        """Contains data related to a replay buffer."""
        data: jnp.ndarray
        insert_position: int
        sample_position: int


    def init_replay_buffer_state(data_shape):
        data = jnp.zeros(data_shape, dtype=jnp.float32)
        return ReplayBufferState(data, 0, 0)


    @partial(jax.jit)
    def insert_in_replay_state(buffer_state: ReplayBufferState, samples: jax.Array, mask: jax.Array) -> ReplayBufferState:
        mask_size = jnp.sum(mask)
        indices_mask = jnp.nonzero(mask, size=mask_size)[0]

        # Apply the mask to the samples to keep the valid ones
        new_samples = jnp.take(samples, indices_mask, axis=0)
        samples_size = len(new_samples)

        # Current buffer state
        data = buffer_state.data
        insert_idx = buffer_state.insert_position
        size_buffer = buffer_state.sample_position

        # Insert the new samples in the buffer
        data = jax.lax.dynamic_update_slice_in_dim(data, new_samples, insert_idx, axis=0)
        insert_idx = (insert_idx + samples_size) % size_buffer
        sample_idx = jnp.minimum(buffer_state.sample_position + samples_size, size_buffer)

        return buffer_state.replace(
            data=data,
            insert_position=insert_idx,
            sample_position=sample_idx,
        )


    # Create a buffer state
    buffer_state = init_replay_buffer_state((1000, 10))

    # Dummy samples
    samples = jnp.ones((num_envs, 10))

    # Mask to insert only the first two samples
    mask = jnp.array([True, True, False])

    buffer_state = insert_in_replay_state(buffer_state, samples, mask)

Output
I know one solution could be to pass `mask_size` as a static argument:

    @partial(jax.jit, static_argnames="mask_size")
    def insert_in_replay_state(buffer_state: ReplayBufferState, samples: jax.Array, mask: jax.Array, mask_size: int) -> ReplayBufferState:
        indices_mask = jnp.nonzero(mask, size=mask_size)[0]

But in my code I already call the insert function from inside a jitted function, so I don't have access to any Python value there.
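To make the limitation above concrete, here is a minimal sketch. The helper names `valid_indices` and `outer` are hypothetical and are not part of the original post. It shows the static-argument route working when the count is a concrete Python int, and failing as soon as the count is computed inside another jitted function.

```python
import jax
import jax.numpy as jnp
from functools import partial


# Hypothetical helper: mask_size is a static Python int,
# so the output shape is known when the function is traced.
@partial(jax.jit, static_argnames="mask_size")
def valid_indices(mask: jax.Array, mask_size: int) -> jax.Array:
    return jnp.nonzero(mask, size=mask_size)[0]


# Hypothetical caller: inside jit, jnp.sum(mask) is a tracer rather than
# a Python int, so it cannot be passed as the static mask_size argument.
@jax.jit
def outer(mask: jax.Array) -> jax.Array:
    return valid_indices(mask, mask_size=jnp.sum(mask))


mask = jnp.array([True, True, False])

# Works: the count is computed eagerly and converted to a Python int.
indices = valid_indices(mask, mask_size=int(mask.sum()))

# Fails while tracing: the count depends on the contents of a traced array.
try:
    outer(mask)
    outer_failed = False
except Exception:  # the exact exception type may differ between JAX versions
    outer_failed = True
```

Note that the working call also recompiles for every distinct `mask_size`, which is another reason the static-argument route tends to be a poor fit for per-step masks.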
-
Hi - thanks for the question! The code you pasted calls
-
The problem is that your approach attempts to construct dynamically-sized arrays: `mask_size` depends on the contents of the traced array `mask`, and so it is a dynamic value. Because of this, it cannot be used in the static `size` argument of `jnp.nonzero(mask, size=mask_size)`. For more on this, see JAX Sharp Bits: Dynamic Shapes.

So what you need to do here is express the update you have in mind without constructing any dynamically-shaped arrays. Here's an example of the kind of approach you might use:

    @partial(jax.jit)
    def insert_in_replay_state(buffer_state: ReplayBufferState, samples: jax.Array, mask: jax.Array) -> ReplayBufferState:
        # Padded indices of the mask elements
        samples_size = jnp.sum(mask)
        mask_indices = jnp.where(mask, size=len(mask), fill_value=len(mask))

        # Current buffer state
        data = buffer_state.data
        insert_idx = buffer_state.insert_position
        size_buffer = buffer_state.sample_position

        # Create a copy of the buffer with samples inserted at insert_idx
        data_indices = insert_idx + jnp.arange(len(mask))
        update_mask = jnp.arange(len(mask))[:, None] < samples_size
        data = data.at[data_indices].set(jnp.where(update_mask, samples[mask_indices], data[data_indices]))

        insert_idx = (insert_idx + samples_size) % size_buffer
        sample_idx = jnp.minimum(buffer_state.sample_position + samples_size, size_buffer)

        return buffer_state.replace(
            data=data,
            insert_position=insert_idx,
            sample_position=sample_idx,
        )
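For readers who want to try the idea in isolation, here is a self-contained sketch of the same fixed-shape technique. It is a sketch under several assumptions and is not the code from the reply above. The function name `masked_insert` is hypothetical, it works on plain arrays instead of `ReplayBufferState`, it takes the capacity from `data.shape[0]` (the posted snippets read `size_buffer` from `sample_position`, which starts at 0), it wraps write positions around the end of the buffer, and it assumes the batch is no larger than the buffer.

```python
import jax
import jax.numpy as jnp


@jax.jit
def masked_insert(data, insert_idx, samples, mask):
    n = mask.shape[0]          # static: batch size is known at trace time
    capacity = data.shape[0]   # static: assumed to be the buffer capacity
    num_valid = jnp.sum(mask)  # dynamic: number of valid samples

    # Row indices of the valid samples, padded to the fixed length n.
    (valid_rows,) = jnp.nonzero(mask, size=n, fill_value=0)
    compacted = samples[valid_rows]  # valid rows first, padding rows after

    # Always touch n buffer rows, but only overwrite the first num_valid of them.
    targets = (insert_idx + jnp.arange(n)) % capacity
    keep = (jnp.arange(n) < num_valid)[:, None]
    new_rows = jnp.where(keep, compacted, data[targets])

    data = data.at[targets].set(new_rows)
    return data, (insert_idx + num_valid) % capacity


data = jnp.zeros((5, 2))
samples = jnp.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])
mask = jnp.array([True, False, True])

# Only rows 0 and 2 of samples are written, at buffer positions 0 and 1.
data, insert_idx = masked_insert(data, 0, samples, mask)
```

Every intermediate array has a shape that depends only on `n` and `capacity`, so the function traces once per input shape. The data-dependent part is confined to the values of `keep`.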
-
I've played a little with the idea of replay buffers in JAX here; you might find it interesting.