Does Muon save the state["momentum_buffer"] correctly in a distributed setting? #46

@EIFY

Hi, I still need to test this but if I am reading this correctly

Muon/muon.py

Lines 79 to 94 in f90a42b

```python
for group in self.param_groups:
    params = group["params"]
    params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
    for base_i in range(len(params))[::dist.get_world_size()]:
        if base_i + dist.get_rank() < len(params):
            p = params[base_i + dist.get_rank()]
            if p.grad is None:
                # continue
                p.grad = torch.zeros_like(p)  # Force synchronization
            state = self.state[p]
            if len(state) == 0:
                state["momentum_buffer"] = torch.zeros_like(p)
            update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
            p.mul_(1 - group["lr"] * group["weight_decay"])
            p.add_(update.reshape(p.shape), alpha=-group["lr"])
        dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
```

It seems that each GPU updates the state["momentum_buffer"] of its assigned parameters in-place, and the parameters themselves are synced with all_gather(), but state["momentum_buffer"] isn't synced. Unless there is some automatic guarantee in place, what may happen is that training works fine as long as it is never interrupted, but once you save & load a checkpoint, only the state["momentum_buffer"] of the parameters handled by cuda:0 is correct...
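
If that reading is right, one possible workaround would be to broadcast each momentum buffer from its owning rank before checkpointing. A minimal sketch, assuming the ownership rule implied by the loop above (parameter i is handled by rank i % world_size) and using a hypothetical helper name sync_momentum_buffers:

```python
import torch
import torch.distributed as dist

def sync_momentum_buffers(optimizer):
    # Hypothetical helper, not part of Muon: broadcast each parameter's
    # momentum_buffer from the rank that updates it (index % world_size,
    # per the sharding in the loop above) so every rank -- in particular
    # the one writing the checkpoint -- holds a correct copy before
    # optimizer.state_dict() is called.
    world_size = dist.get_world_size()
    for group in optimizer.param_groups:
        for i, p in enumerate(group["params"]):
            owner = i % world_size
            state = optimizer.state[p]
            if "momentum_buffer" not in state:
                # Non-owner ranks never created this parameter's state,
                # so allocate a correctly shaped receive buffer first.
                state["momentum_buffer"] = torch.zeros_like(p)
            dist.broadcast(state["momentum_buffer"], src=owner)
```

Calling something like this right before torch.save(optimizer.state_dict(), ...) would at least make the saved buffers consistent across ranks. I haven't tested it, so treat it as a sketch of the idea rather than a fix.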
