You should jit your call to `jax.random.choice`:

```python
from functools import partial

import jax
import numpy as np
from jax import jit


# `a` and `shape` are static: jax.random.choice needs a concrete int `a`
# and a concrete shape, and jit compiles once per distinct (a, shape) pair.
@partial(jit, static_argnums=(1, 2))
def choice(key, a: int, shape: tuple[int, ...]):
    return jax.random.choice(key, a=a, shape=shape, replace=False)


def sample_batch(self, key):
    # pretty sure you need conversion to np.ndarray here
    idxs = np.array(choice(key, a=self.size, shape=(self.batch_size,)))
    return [
        self.obs_buf[idxs],
        self.next_obs_buf[idxs],
        self.acts_buf[idxs],
        self.rews_buf[idxs],
        self.done_buf[idxs],
        # for N-step Learning
        idxs,
    ]
```

BTW, it would help people help you better if you could provide a minimal reproducible example.
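A self-contained sketch of the same idea, in case it helps to run it in isolation. The `ReplayBuffer` class, its sizes, and the choice to keep only two buffers are hypothetical stand-ins for your own buffer. Two things worth checking against your code: the key should be split on every call (reusing one key returns the same batch every time), and because `a` is static, each distinct value of `self.size` compiles a new version of `choice`. If `self.size` grows every step while the buffer fills, you may still be paying for repeated compilation.

```python
from functools import partial

import jax
import numpy as np


@partial(jax.jit, static_argnums=(1, 2))
def choice(key, a: int, shape: tuple[int, ...]):
    # Compiled once per distinct (a, shape) pair, then reused.
    return jax.random.choice(key, a, shape=shape, replace=False)


class ReplayBuffer:
    """Hypothetical minimal buffer: only observations and rewards are stored."""

    def __init__(self, obs_dim: int, size: int, batch_size: int):
        self.obs_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.size = size
        self.batch_size = batch_size

    def sample_batch(self, key):
        # Convert the JAX result to NumPy before indexing the NumPy buffers.
        idxs = np.asarray(choice(key, self.size, (self.batch_size,)))
        return self.obs_buf[idxs], self.rews_buf[idxs], idxs


buffer = ReplayBuffer(obs_dim=4, size=1000, batch_size=32)
key = jax.random.PRNGKey(0)
for _ in range(3):
    # Split the key each step; reusing the same key would repeat the batch.
    key, subkey = jax.random.split(key)
    obs, rews, idxs = buffer.sample_batch(subkey)
```

Only the first call with a given `(size, batch_size)` pays the compilation cost; later calls reuse the compiled function.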
Hi,
I have a function that returns a batch of samples from a dataset:
and this function is called inside another function:
This works perfectly fine.
However, once I replace `idxs = np.random.choice(self.size, size=self.batch_size, replace=False)` with `jax.random.choice` and slightly change the `get_samples` function as:

The speed of the entire code drops sharply. What should I do in this case?
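The reply asks for a minimal reproducible example. One possible sketch times the three ways of drawing indices side by side: NumPy, un-jitted `jax.random.choice`, and a jitted wrapper. The sizes (`SIZE`, `BATCH`, `STEPS`) and helper names are hypothetical. The warm-up call keeps compilation out of the measurement, and `np.asarray` waits for JAX's asynchronous dispatch to finish before the clock stops. No particular numbers are claimed here; they depend on hardware, backend, and JAX version.

```python
import time
from functools import partial

import jax
import numpy as np

SIZE, BATCH, STEPS = 10_000, 256, 50  # hypothetical sizes; adjust to your buffer


@partial(jax.jit, static_argnums=(1, 2))
def choice_jit(key, a: int, shape: tuple[int, ...]):
    return jax.random.choice(key, a, shape=shape, replace=False)


def sample_numpy(key):
    # The key is ignored; NumPy uses its own global RNG state.
    return np.random.choice(SIZE, size=BATCH, replace=False)


def sample_jax_eager(key):
    # Un-jitted: each call dispatches the underlying JAX ops one at a time.
    return np.asarray(jax.random.choice(key, SIZE, shape=(BATCH,), replace=False))


def sample_jax_jit(key):
    # Jitted: compiled once for this (SIZE, (BATCH,)) pair, then reused.
    return np.asarray(choice_jit(key, SIZE, (BATCH,)))


def seconds_per_call(fn, steps=STEPS):
    keys = jax.random.split(jax.random.PRNGKey(0), steps)
    fn(keys[0])  # warm-up, so compilation time is not counted
    start = time.perf_counter()
    for k in keys:
        fn(k)  # np.asarray inside fn blocks until the result is ready
    return (time.perf_counter() - start) / steps


timings = {
    name: seconds_per_call(fn)
    for name, fn in [
        ("numpy", sample_numpy),
        ("jax eager", sample_jax_eager),
        ("jax jit", sample_jax_jit),
    ]
}
for name, t in timings.items():
    print(f"{name:>10}: {t * 1e6:8.1f} us/call")
```

Posting output like this, together with how `self.size` and the key are updated between calls, would make it much easier to tell whether the slowdown comes from eager dispatch, recompilation, or something else.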