
Commit 7b54eab

Add MARS and LaProp impl, simplified from originals
1 parent e5aea35 commit 7b54eab

2 files changed: +288 additions, 0 deletions

timm/optim/laprop.py

Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
""" PyTorch impl of LaProp optimizer

Code simplified from https://github.com/Z-T-WANG/LaProp-Optimizer, MIT License

Paper: LaProp: Separating Momentum and Adaptivity in Adam, https://arxiv.org/abs/2002.04839

@article{ziyin2020laprop,
  title={LaProp: a Better Way to Combine Momentum with Adaptive Gradient},
  author={Ziyin, Liu and Wang, Zhikang T and Ueda, Masahito},
  journal={arXiv preprint arXiv:2002.04839},
  year={2020}
}

"""
from torch.optim import Optimizer
import torch


class LaProp(Optimizer):
    """ LaProp Optimizer

    Paper: LaProp: Separating Momentum and Adaptivity in Adam, https://arxiv.org/abs/2002.04839
    """
    def __init__(
            self,
            params,
            lr=4e-4,
            betas=(0.9, 0.999),
            eps=1e-15,
            weight_decay=0,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
        )
        super(LaProp, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('LaProp does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Exponential moving average of learning rates
                    state['exp_avg_lr_1'] = 0.
                    state['exp_avg_lr_2'] = 0.
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']

                state['step'] += 1
                one_minus_beta2 = 1 - beta2
                one_minus_beta1 = 1 - beta1

                # Decay the first and second moment running average coefficient
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=one_minus_beta2)

                state['exp_avg_lr_1'] = state['exp_avg_lr_1'] * beta1 + one_minus_beta1 * group['lr']
                state['exp_avg_lr_2'] = state['exp_avg_lr_2'] * beta2 + one_minus_beta2

                # equals 1 - beta1 ** state['step'] when lr is held constant
                bias_correction1 = state['exp_avg_lr_1'] / group['lr'] if group['lr'] != 0. else 1.
                bias_correction2 = state['exp_avg_lr_2']
                step_size = 1 / bias_correction1

                denom = exp_avg_sq.div(bias_correction2).sqrt_().add_(group['eps'])
                step_of_this_grad = grad / denom
                exp_avg.mul_(beta1).add_(step_of_this_grad, alpha=group['lr'] * one_minus_beta1)

                p.add_(exp_avg, alpha=-step_size)
                if group['weight_decay'] != 0:
                    p.add_(p, alpha=-group['weight_decay'])

        return loss
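
For reference, a minimal usage sketch (not part of the commit): it assumes the module is importable as timm.optim.laprop, matching the file path above, and uses a toy model purely for illustration. The key difference from Adam is that LaProp divides each raw gradient by the second-moment estimate before folding it into the momentum buffer, rather than normalizing the accumulated momentum afterwards.

import torch
import torch.nn as nn
from timm.optim.laprop import LaProp  # import path assumed from timm/optim/laprop.py

model = nn.Linear(10, 2)
optimizer = LaProp(model.parameters(), lr=4e-4, betas=(0.9, 0.999), eps=1e-15)

x = torch.randn(8, 10)
y = torch.randint(0, 2, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()       # each gradient is RMS-normalized, then accumulated into momentum
optimizer.zero_grad()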

timm/optim/mars.py

Lines changed: 180 additions & 0 deletions
@@ -0,0 +1,180 @@
""" PyTorch MARS Optimizer

Code simplified from https://github.com/AGI-Arena/MARS

Paper: MARS: Unleashing the Power of Variance Reduction for Training Large Models - https://arxiv.org/abs/2411.10438

@article{yuan2024mars,
  title={MARS: Unleashing the Power of Variance Reduction for Training Large Models},
  author={Yuan, Huizhuo and Liu, Yifeng and Wu, Shuang and Zhou, Xun and Gu, Quanquan},
  journal={arXiv preprint arXiv:2411.10438},
  year={2024}
}
"""
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import math

import torch
from torch.optim.optimizer import Optimizer


def mars_single_tensor(
        p,
        grad,
        exp_avg,
        exp_avg_sq,
        lr,
        weight_decay,
        beta1,
        beta2,
        last_grad,
        eps,
        step,
        gamma,
        mars_type,
        is_grad_2d,
        optimize_1d,
        lr_1d_factor,
        betas_1d,
):
    # optimize_1d: if True, apply MARS to 1D params as well; otherwise 1D params fall back to AdamW
    if optimize_1d or is_grad_2d:
        one_minus_beta1 = 1. - beta1
        c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
        c_t_norm = torch.norm(c_t)
        if c_t_norm > 1.:
            c_t = c_t / c_t_norm
        exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
        if mars_type == "adamw":
            exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
            bias_correction1 = 1.0 - beta1 ** step
            bias_correction2 = 1.0 - beta2 ** step
            denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
            update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
        elif mars_type == "lion":
            update = p * weight_decay + exp_avg.sign()
        else:
            assert False
        p.add_(update, alpha=-lr)
    else:
        beta1_1d, beta2_1d = betas_1d
        exp_avg.mul_(beta1_1d).add_(grad, alpha=1. - beta1_1d)
        exp_avg_sq.mul_(beta2_1d).addcmul_(grad, grad, value=1. - beta2_1d)
        bias_correction1 = 1.0 - beta1_1d ** step
        bias_correction2 = 1.0 - beta2_1d ** step
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
        update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
        p.add_(update, alpha=-(lr * lr_1d_factor))
    return exp_avg, exp_avg_sq


class Mars(Optimizer):
    """ MARS Optimizer

    Paper: MARS: Unleashing the Power of Variance Reduction for Training Large Models
        https://arxiv.org/abs/2411.10438

    """
    def __init__(
            self,
            params,
            lr=3e-3,
            betas=(0.9, 0.99),
            eps=1e-8,
            weight_decay=0.,
            gamma=0.025,
            mars_type="adamw",
            optimize_1d=False,
            lr_1d_factor=1.0,
            betas_1d=None,
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
        assert mars_type in ["adamw", "lion"], "MARS type not supported"

        defaults = dict(
            lr=lr,
            betas=betas,
            eps=eps,
            weight_decay=weight_decay,
            mars_type=mars_type,
            gamma=gamma,
            optimize_1d=optimize_1d,
            lr_1d_factor=lr_1d_factor,
            betas_1d=betas_1d or betas,
        )
        super(Mars, self).__init__(params, defaults)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                if grad.is_sparse:
                    raise RuntimeError('MARS does not support sparse gradients')

                state = self.state[p]

                # State initialization
                if len(state) <= 1:
                    state['step'] = 0
                    # Exponential moving average of gradient values
                    state['exp_avg'] = torch.zeros_like(p)
                    # Last gradient
                    state['last_grad'] = torch.zeros_like(p)
                    # Exponential moving average of squared gradient values
                    state['exp_avg_sq'] = torch.zeros_like(p)

                state['step'] += 1
                step = state['step']
                exp_avg = state['exp_avg']
                exp_avg_sq = state['exp_avg_sq']
                last_grad = state['last_grad']
                lr = group['lr']
                wd = group['weight_decay']
                beta1, beta2 = group['betas']
                is_grad_2d = grad.ndim >= 2

                mars_single_tensor(
                    p,
                    grad,
                    exp_avg,
                    exp_avg_sq,
                    lr,
                    wd,
                    beta1,
                    beta2,
                    last_grad,
                    group['eps'],
                    step,
                    group['gamma'],
                    mars_type=group['mars_type'],
                    is_grad_2d=is_grad_2d,
                    optimize_1d=group['optimize_1d'],
                    lr_1d_factor=group['lr_1d_factor'],
                    betas_1d=group['betas_1d'],
                )

                state['last_grad'] = grad

        return loss
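
For reference, a minimal usage sketch (not part of the commit): it assumes the module is importable as timm.optim.mars, matching the file path above, and uses a toy model purely for illustration. MARS forms a variance-reduced gradient c_t = g_t + gamma * (beta1 / (1 - beta1)) * (g_t - g_{t-1}), clips it to unit norm, and feeds it through an AdamW- or Lion-style update; 1D parameters such as biases fall back to a plain AdamW update unless optimize_1d=True.

import torch
import torch.nn as nn
from timm.optim.mars import Mars  # import path assumed from timm/optim/mars.py

model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
# 2D weights get the MARS variance-reduced update; 1D params (biases) use the AdamW
# fallback with betas_1d (defaults to betas) unless optimize_1d=True is passed.
optimizer = Mars(model.parameters(), lr=3e-3, gamma=0.025, mars_type="adamw")

x = torch.randn(8, 16)
y = torch.randint(0, 4, (8,))
loss = nn.functional.cross_entropy(model(x), y)
loss.backward()
optimizer.step()
optimizer.zero_grad()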
