Your main issue is that you used pmap, e.g.:

```python
from functools import partial
from jax import pmap

@partial(pmap, axis_name='devices')
def train_step(params, inputs, bboxes, labels, opt_state, clip_state, step):
    (loss_val, params), grads = train_forward(params, inputs, bboxes, labels, step)
    # .....
```

For pjit, you need to specify how the data and parameters are sharded:

```python
from functools import partial
from jax.experimental.pjit import pjit
from jax.sharding import NamedSharding, Mesh, PartitionSpec as P

# replicated_sharding, data_sharding, bboxes_sharding and labels_sharding are
# NamedSharding objects built from a device Mesh (see the sketch below).
@partial(pjit,
         in_shardings=(replicated_sharding, data_sharding, bboxes_sharding, labels_sharding,
                       replicated_sharding, replicated_sharding, replicated_sharding),
         out_shardings=(replicated_sharding, replicated_sharding,
                        replicated_sharding, replicated_sharding))
def train_step(params, inputs, bboxes, labels, opt_state, clip_state, step):
    (loss_val, params), grads = train_forward(params, inputs, bboxes, labels, step)
```
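The sharding objects above are not defined in the snippet; here is a minimal sketch of how they could be built, assuming a single mesh axis named 'devices' spanning all local devices and that the batch dimension is the leading axis of inputs, bboxes and labels (the names below are illustrative, not from the original code):

```python
import jax
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# One mesh axis 'devices' over all local devices (assumed layout).
devices = np.array(jax.devices())
mesh = Mesh(devices, axis_names=('devices',))

# Parameters, optimizer state, clip state and step are replicated on every device.
replicated_sharding = NamedSharding(mesh, P())

# Batched arrays are split along their leading (batch) dimension.
data_sharding = NamedSharding(mesh, P('devices'))
bboxes_sharding = NamedSharding(mesh, P('devices'))
labels_sharding = NamedSharding(mesh, P('devices'))
```

With this layout the batch is split across devices while the parameters and optimizer state stay replicated, which is the usual data-parallel setup; placing each batch with `jax.device_put(inputs, data_sharding)` before calling `train_step` also avoids an extra reshard inside the compiled function.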
-
Hi guys, I attempted to train a model using data parallelism with JAX, but the speedup was a meager 1.14x. Below is a portion of my code. Could you help me identify where I went wrong?