Closed
Labels: discussion
Continuing the discussion from #2172 (thanks @mirceamironenco, @ebsmothers for the fix!).
We have runs on the exact same dataset / hparams, changing only the number of nodes (8 -> 2 -> 1). We noticed that when we reduce the number of nodes, the gradient norm goes up.
The grad norm at initialization differs by ~4x between the 8-node and 1-node runs. With the fix in #2172, I would expect the grad norms to be similar regardless of the world size. The only difference between the runs is the global batch size (64 on 1 node, 512 on 8 nodes), but I would not expect that to cause such a large difference.
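One possible (speculative) contributing factor: if the loss is mean-reduced over the global batch and the per-example gradients at initialization are dominated by roughly zero-mean noise, the norm of the averaged gradient shrinks like 1/sqrt(global batch size), which alone would predict a sqrt(512/64) ≈ 2.8x gap between the two runs. A toy numpy sketch of that effect (purely illustrative; the dimension, noise model, and trial counts are all assumptions, not torchtune code):

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 2000  # toy stand-in for the number of model parameters

def batch_grad_norm(batch_size, trials=20):
    """Average norm of the mean-reduced gradient over a batch of
    noisy per-example gradients with (assumed) zero mean at init."""
    norms = []
    for _ in range(trials):
        per_example_grads = rng.standard_normal((batch_size, dim))
        g = per_example_grads.mean(axis=0)  # mean reduction over the batch
        norms.append(np.linalg.norm(g))
    return float(np.mean(norms))

n64 = batch_grad_norm(64)    # "1 node" global batch
n512 = batch_grad_norm(512)  # "8 node" global batch
print(n64 / n512)  # roughly sqrt(512 / 64) ≈ 2.83
```

So even with correct gradient scaling, some batch-size dependence of the grad norm is expected at initialization; the question is whether the observed ~4x gap is fully explained by this or points to a residual scaling bug.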
Is it possible there are still some issues in how we compute / scale the gradients?
jianqunppl