Creating multiple workers to do DQN #17755
Unanswered
raymondchua
asked this question in Q&A
Replies: 1 comment
-
Hi, it seems like you can do this with `jax.pmap`. For example, here is a simple structure I use:

```python
import jax
from flax.jax_utils import replicate, unreplicate
from flax.training.common_utils import shard

def compute_fn(params, input):
    grad = <compute grad here>
    # note that `axis_name` should be the same as the one passed to `pmap`
    grad = jax.lax.pmean(grad, axis_name="batch")
    return grad

# make the model parameters available on all devices
params = <model>
params = replicate(params)

# make a `pmap` version of `compute_fn`
p_compute_fn = jax.pmap(compute_fn, axis_name="batch", donate_argnums=(0,))

# shard the data across all devices
data = <input data>
data = shard(data)

grad = p_compute_fn(params, data)

# `grad` is still replicated across the devices; it can be retrieved with
grad = unreplicate(grad)
```

This is what I usually do. Maybe there is a better approach out there. Hope this helps.
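For reference, here is a minimal self-contained sketch of the same pattern that should run as-is. The toy linear model, the squared-error loss, and the synthetic batch are assumptions purely for illustration; `replicate`, `unreplicate`, and `shard` are Flax utilities.

```python
import jax
import jax.numpy as jnp
from flax.jax_utils import replicate, unreplicate
from flax.training.common_utils import shard

def loss_fn(params, batch):
    # toy linear model with a squared-error loss (illustrative assumption)
    inputs, targets = batch
    preds = inputs @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)

def compute_fn(params, batch):
    grads = jax.grad(loss_fn)(params, batch)
    # average gradients across devices; axis_name must match the pmap below
    return jax.lax.pmean(grads, axis_name="batch")

n_devices = jax.local_device_count()

# toy parameters, replicated so every device holds its own copy
params = {"w": jnp.zeros((4, 1)), "b": jnp.zeros((1,))}
params = replicate(params)

p_compute_fn = jax.pmap(compute_fn, axis_name="batch")

# synthetic batch; the global batch size must be divisible by n_devices
inputs = jnp.ones((8 * n_devices, 4))
targets = jnp.zeros((8 * n_devices, 1))
batch = shard((inputs, targets))  # adds a leading device axis

grads = p_compute_fn(params, batch)
grads = unreplicate(grads)  # take the first device's copy
```

Because `pmean` already averages the gradients inside the pmapped function, every device ends up with identical gradients, so `unreplicate` simply returns the copy held by the first device.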
-
Hi all, I hope this is the right place to ask this question. I would like to be able to collect experience from 64 parallel environment copies into a single replay buffer, for example using 8 parallel workers with shared model parameters and averaging gradients between workers. Does anyone know if this can be done in JAX using `vmap` and `jit`? Perhaps similar to this: #11565 (reply in thread)