
Commit 397f00d

Merge pull request #939 from tianshijing/main
Added Muon Optimizer
2 parents 6fc0fd8 + b4d2ec4 commit 397f00d

5 files changed (+116, -3 lines)

scripts/run_finetune_with_custom_optim.sh

Lines changed: 6 additions & 0 deletions
@@ -252,6 +252,12 @@ elif [ "${optim}" == "adadelta" ]; then
 elif [ "${optim}" == "adagrad" ]; then
   optim_suffix_args="--use_customized_optim 1"
   optim_suffix_args+=" --customized_optim ${optim}"
+elif [ "${optim}" == "muon" ]; then
+  optim_suffix_args="--use_customized_optim 1"
+  optim_suffix_args+=" --optim_beta1 ${beta1}"
+  optim_suffix_args+=" --optim_beta2 ${beta2}"
+  optim_suffix_args+=" --optim_weight_decay ${weight_decay}"
+  optim_suffix_args+=" --customized_optim ${optim}"
 elif [ "${optim}" == "adamw_schedule_free" ]; then
   optim_suffix_args="--use_customized_optim 1"
   optim_suffix_args+=" --customized_optim ${optim}"
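For illustration only, a hypothetical sketch of how the flags assembled above could be consumed on the Python side; the argparse parser is a stand-in, not LMFlow's actual argument handling (which lives in src/lmflow/args.py), and the hyperparameter values are placeholders:

import argparse

# Stand-in parser; flag names mirror the ones built by the script above.
parser = argparse.ArgumentParser()
parser.add_argument("--use_customized_optim", type=int, default=0)
parser.add_argument("--customized_optim", type=str, default="adamw")
parser.add_argument("--optim_beta1", type=float, default=0.9)
parser.add_argument("--optim_beta2", type=float, default=0.999)
parser.add_argument("--optim_weight_decay", type=float, default=0.0)

flags = ("--use_customized_optim 1 --optim_beta1 0.9 --optim_beta2 0.999 "
         "--optim_weight_decay 0.01 --customized_optim muon")
args = parser.parse_args(flags.split())
print(args.customized_optim, args.optim_beta1, args.optim_weight_decay)  # muon 0.9 0.01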

src/lmflow/args.py

Lines changed: 2 additions & 1 deletion
@@ -49,6 +49,7 @@ class OptimizerNames():
     NOVOGRAD = "novograd"
     ADADELTA = "adadelta"
     ADAGRAD = "adagrad"
+    MUON = "muon"
     ADAMW_SCHEDULE_FREE = "adamw_schedule_free"
     SGD_SCHEDULE_FREE = "sgd_schedule_free"
 
@@ -1479,4 +1480,4 @@ def get_pipeline_args_class(pipeline_name: str):
 
 
 def split_args(args):
-    return [elem.strip() for elem in args.split(",")] if isinstance(args, str) else args
+    return [elem.strip() for elem in args.split(",")] if isinstance(args, str) else args

src/lmflow/optim/muon.py

Lines changed: 98 additions & 0 deletions
@@ -0,0 +1,98 @@ (new file)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
import torch
import torch.nn as nn
import math
import os
import torch.distributed as dist
from torch import Tensor


def zeropower_via_newtonschulz5(G: Tensor, steps: int) -> Tensor:
    """
    Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
    quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
    of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
    zero even beyond the point where the iteration no longer converges all the way to one everywhere
    on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
    where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
    performance at all relative to UV^T, where USV^T = G is the SVD.
    """
    assert G.ndim >= 2  # batched Muon implementation by @scottjmaddox, and put into practice in the record by @YouJiacheng
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.bfloat16()
    if G.size(-2) > G.size(-1):
        X = X.mT

    # Ensure spectral norm is at most 1
    X = X / (X.norm(dim=(-2, -1), keepdim=True) + 1e-7)
    # Perform the NS iterations
    for _ in range(steps):
        A = X @ X.mT
        B = b * A + c * A @ A  # quintic computation strategy adapted from suggestion by @jxbz, @leloykun, and @YouJiacheng
        X = a * X + B @ X

    if G.size(-2) > G.size(-1):
        X = X.mT
    return X


class Muon(torch.optim.Optimizer):
    """
    Adam-style optimizer that orthogonalizes each matrix-shaped (ndim >= 2) update
    via the Newton-Schulz iteration above before applying it.
    """
    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0, ns_steps=5):
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, ns_steps=ns_steps)
        super().__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """
        Performs a single optimization step.

        Args:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                state = self.state[p]

                # Initialize state
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p)
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Update momentum and squared gradient
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

                # Compute the update
                denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
                step_size = group['lr'] / bias_correction1

                # Orthogonalize the update
                update = exp_avg / denom
                if update.ndim >= 2:
                    update = zeropower_via_newtonschulz5(update, steps=group['ns_steps'])

                # Apply the update
                p.add_(update, alpha=-step_size)

                # Apply weight decay
                if group['weight_decay'] != 0:
                    p.add_(p, alpha=-group['lr'] * group['weight_decay'])

        return loss
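
As a quick sanity check, the following minimal sketch (the toy model, seed, and hyperparameter values are illustrative, not part of the commit) runs one Muon step and inspects the singular values produced by the Newton-Schulz orthogonalization:

import torch
import torch.nn as nn
from lmflow.optim.muon import Muon, zeropower_via_newtonschulz5

torch.manual_seed(0)
model = nn.Linear(16, 8)  # toy stand-in for a fine-tuned model
opt = Muon(model.parameters(), lr=1e-3, betas=(0.9, 0.999), weight_decay=0.01)

loss = model(torch.randn(4, 16)).pow(2).mean()
loss.backward()
opt.step()       # the 2-D weight gets an orthogonalized update, the 1-D bias does not
opt.zero_grad()

# Singular values of the orthogonalized matrix land roughly in [0.5, 1.5],
# as described in the zeropower_via_newtonschulz5 docstring.
print(torch.linalg.svdvals(zeropower_via_newtonschulz5(torch.randn(8, 16), steps=5).float()))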

src/lmflow/optim/optimizers.py

Lines changed: 2 additions & 1 deletion
@@ -19,5 +19,6 @@
 from lmflow.optim.adam import Adam
 from lmflow.optim.adadelta import Adadelta
 from lmflow.optim.adagrad import AdaGrad
+from lmflow.optim.muon import Muon
 from lmflow.optim.adamw_schedule_free import AdamWScheduleFree
-from lmflow.optim.sgd_schedule_free import SGDScheduleFree
+from lmflow.optim.sgd_schedule_free import SGDScheduleFree

src/lmflow/pipeline/finetuner.py

Lines changed: 8 additions & 1 deletion
@@ -344,6 +344,13 @@ def get_optimizer_cls_and_kwargs(
         adagrad_kwargs = {
         }
         optimizer_kwargs.update(adagrad_kwargs)
+    elif args.customized_optim == OptimizerNames.MUON:
+        optimizer_cls = optim.Muon
+        muon_kwargs = {
+            "betas": (args.optim_beta1, args.optim_beta2),
+            "weight_decay": (args.optim_weight_decay),
+        }
+        optimizer_kwargs.update(muon_kwargs)
     elif args.customized_optim == OptimizerNames.ADAMW_SCHEDULE_FREE:
         optimizer_cls = optim.AdamWScheduleFree
         adamw_schedule_free_kwargs = {
 
@@ -640,4 +647,4 @@ def switch_active_layers(self):
         else:
             trainer.create_model_card(**kwargs)
 
-        return model
+        return model
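
For reference, a minimal sketch of what this new branch amounts to at optimizer-construction time; the toy model and hyperparameter values are placeholders, and in the real pipeline they come from the parsed training arguments:

import torch.nn as nn
from lmflow.optim.muon import Muon

model = nn.Linear(16, 8)        # stand-in for the model being fine-tuned
optimizer_kwargs = {
    "lr": 1e-3,                 # assumed to be filled in from the training arguments
    "betas": (0.9, 0.999),      # args.optim_beta1, args.optim_beta2
    "weight_decay": 0.01,       # args.optim_weight_decay
}
optimizer = Muon(model.parameters(), **optimizer_kwargs)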
