
Grad Norm Differences Across Nodes #2240

@EugenHotaj

Description


Continuing the discussion from #2172 (thanks @mirceamironenco, @ebsmothers for the fix!).

We have runs on the exact same dataset/hparams, changing only the number of nodes (8 → 2 → 1). We noticed that reducing the number of nodes makes the gradient norm go up:

Here is an 8-node run:
[Screenshot: grad-norm curve, 8-node run]

Here is a 2-node run:
[Screenshot: grad-norm curve, 2-node run]

Here is a 1-node run:
[Screenshot: grad-norm curve, 1-node run]

We can see the grad norm at initialization differs by ~4x between the 8-node and 1-node runs. With the fix in #2172, I would expect the grad norms to be similar regardless of the world size. The only difference between the runs is the global batch size (64 on 1 node, 512 on 8 nodes), but I would not expect this to cause such a big difference.
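(For reference, one mechanism by which global batch size alone can shift the grad norm: if per-sample gradients at initialization are roughly i.i.d. noise around a small mean, the norm of the averaged gradient scales like 1/sqrt(batch size). This is a toy NumPy sketch of that statistical effect, not a claim about what torchtune computes; `dim` and the Gaussian noise model are illustrative assumptions.)

```python
import numpy as np

rng = np.random.default_rng(0)
dim = 4096  # hypothetical parameter count; just needs to be large

def mean_grad_norm(batch_size, trials=20):
    """Average norm of the batch-mean gradient, modeling per-sample
    gradients at init as zero-mean i.i.d. noise."""
    norms = []
    for _ in range(trials):
        per_sample = rng.normal(size=(batch_size, dim))
        norms.append(np.linalg.norm(per_sample.mean(axis=0)))
    return float(np.mean(norms))

small_batch = mean_grad_norm(64)    # 1-node global batch
large_batch = mean_grad_norm(512)   # 8-node global batch

# Under this noise model the ratio is ~sqrt(512/64) = sqrt(8) ≈ 2.8,
# i.e. the same order of magnitude as the ~4x gap observed above.
print(small_batch / large_batch)
```

So an 8x batch-size gap alone would predict roughly a 2.8x grad-norm gap under this model; whether the remaining discrepancy points at a scaling bug is the open question.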

Is it possible there are still some issues in how we compute / scale the gradients?
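(For anyone auditing the grad-norm computation: with sharded parameters, the invariant to check is that the global L2 norm equals the square root of the all-reduced sum of per-rank *squared* norms. A minimal single-process sketch of that invariant, simulating ranks with array slices rather than torchtune's actual distributed code:)

```python
import numpy as np

rng = np.random.default_rng(1)
full_grad = rng.normal(size=1024)       # the unsharded gradient
shards = np.array_split(full_grad, 4)   # pretend 4 ranks each hold a shard

# Each "rank" computes its local squared norm; an all-reduce (here: sum)
# combines them before taking the square root. Note that summing the
# norms themselves, rather than the squared norms, would overestimate
# the global norm.
local_sq = [float(np.sum(s ** 2)) for s in shards]
global_norm = float(np.sqrt(sum(local_sq)))

print(np.isclose(global_norm, np.linalg.norm(full_grad)))  # True
```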
