Hi, I still need to test this, but if I am reading this correctly:
Lines 79 to 94 in f90a42b
```python
for group in self.param_groups:
    params = group["params"]
    params_pad = params + [torch.empty_like(params[-1])] * (dist.get_world_size() - len(params) % dist.get_world_size())
    for base_i in range(len(params))[::dist.get_world_size()]:
        if base_i + dist.get_rank() < len(params):
            p = params[base_i + dist.get_rank()]
            if p.grad is None:
                # continue
                p.grad = torch.zeros_like(p)  # Force synchronization
            state = self.state[p]
            if len(state) == 0:
                state["momentum_buffer"] = torch.zeros_like(p)
            update = muon_update(p.grad, state["momentum_buffer"], beta=group["momentum"])
            p.mul_(1 - group["lr"] * group["weight_decay"])
            p.add_(update.reshape(p.shape), alpha=-group["lr"])
        dist.all_gather(params_pad[base_i:base_i + dist.get_world_size()], params_pad[base_i + dist.get_rank()])
```
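To make the rank-to-parameter assignment concrete, here is a tiny standalone sketch (the `world_size` and parameter count are hypothetical values, not from the repo) of which parameter indices each rank ends up updating:

```python
# Hypothetical sizes, just to illustrate the round-robin sharding above.
world_size = 4   # stands in for dist.get_world_size()
num_params = 10  # stands in for len(params)

for rank in range(world_size):
    owned = [base_i + rank
             for base_i in range(0, num_params, world_size)
             if base_i + rank < num_params]
    print(f"rank {rank} updates params {owned}")

# Output:
# rank 0 updates params [0, 4, 8]
# rank 1 updates params [1, 5, 9]
# rank 2 updates params [2, 6]
# rank 3 updates params [3, 7]
```

So each rank only ever touches the `momentum_buffer` of its own slice of the parameter list.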
It seems that each GPU updates the `state["momentum_buffer"]` of its assigned parameters in-place, and the parameters themselves are synced with `all_gather()`, but `state["momentum_buffer"]` is never synced across ranks. Unless there is some automatic guarantee in place, what may happen is that training works fine as long as it is never interrupted, but once you save and load a checkpoint (typically written from rank 0), only the `state["momentum_buffer"]` of the parameters handled by `cuda:0` is correct...
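If that reading is right, one possible mitigation would be to make the buffers consistent right before checkpointing. Below is a minimal, untested sketch; `sync_momentum_buffers` is a hypothetical helper name, and it assumes the same rank-to-parameter assignment as the loop quoted above:

```python
import torch
import torch.distributed as dist

def sync_momentum_buffers(optimizer):
    """Hypothetical helper: make every rank's momentum buffers consistent
    by broadcasting each buffer from the rank that owns the parameter
    (param index base_i + offset is updated by rank `offset` in step()).
    Run on all ranks once, right before saving optimizer.state_dict()."""
    world_size = dist.get_world_size()
    for group in optimizer.param_groups:
        params = group["params"]
        for base_i in range(0, len(params), world_size):
            for offset in range(world_size):
                i = base_i + offset
                if i >= len(params):
                    break
                p = params[i]
                state = optimizer.state[p]
                # Non-owner ranks may never have initialized this buffer.
                if "momentum_buffer" not in state:
                    state["momentum_buffer"] = torch.zeros_like(p)
                dist.broadcast(state["momentum_buffer"], src=offset)
```

Alternatively, the buffers could be gathered inside `step()` itself, the same way the params are, at the cost of extra communication on every step.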