Skip to content

Commit e93e571

Browse files
hellbellrwightman
authored andcommitted
Add adamp and 'sgdp' optimizers.
Update requirements.txt Update optim_factory.py Add `adamp` optimizer Update __init__.py copy files of adamp & sgdp Create adamp.py Update __init__.py Create sgdp.py Update optim_factory.py Update optim_factory.py Update requirements.txt Update adamp.py Update sgdp.py Update sgdp.py Update adamp.py
1 parent 0915bed commit e93e571

File tree

4 files changed

+214
-1
lines changed

4 files changed

+214
-1
lines changed

timm/optim/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,4 +5,6 @@
55
from .novograd import NovoGrad
66
from .nvnovograd import NvNovoGrad
77
from .lookahead import Lookahead
8+
from .adamp import AdamP
9+
from .sgdp import SGDP
810
from .optim_factory import create_optimizer

timm/optim/adamp.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
"""
2+
AdamP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/adamp.py
3+
4+
Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5+
Code: https://github.com/clovaai/AdamP
6+
7+
Copyright (c) 2020-present NAVER Corp.
8+
MIT license
9+
"""
10+
11+
import torch
12+
import torch.nn as nn
13+
from torch.optim.optimizer import Optimizer, required
14+
import math
15+
16+
class AdamP(Optimizer):
17+
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
18+
weight_decay=0, delta=0.1, wd_ratio=0.1, nesterov=False):
19+
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay,
20+
delta=delta, wd_ratio=wd_ratio, nesterov=nesterov)
21+
super(AdamP, self).__init__(params, defaults)
22+
23+
def _channel_view(self, x):
24+
return x.view(x.size(0), -1)
25+
26+
def _layer_view(self, x):
27+
return x.view(1, -1)
28+
29+
def _cosine_similarity(self, x, y, eps, view_func):
30+
x = view_func(x)
31+
y = view_func(y)
32+
33+
x_norm = x.norm(dim=1).add_(eps)
34+
y_norm = y.norm(dim=1).add_(eps)
35+
dot = (x * y).sum(dim=1)
36+
37+
return dot.abs() / x_norm / y_norm
38+
39+
def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
40+
wd = 1
41+
expand_size = [-1] + [1] * (len(p.shape) - 1)
42+
for view_func in [self._channel_view, self._layer_view]:
43+
44+
cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
45+
46+
if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
47+
p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
48+
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
49+
wd = wd_ratio
50+
51+
return perturb, wd
52+
53+
return perturb, wd
54+
55+
def step(self, closure=None):
56+
loss = None
57+
if closure is not None:
58+
loss = closure()
59+
60+
for group in self.param_groups:
61+
for p in group['params']:
62+
if p.grad is None:
63+
continue
64+
65+
grad = p.grad.data
66+
beta1, beta2 = group['betas']
67+
nesterov = group['nesterov']
68+
69+
state = self.state[p]
70+
71+
# State initialization
72+
if len(state) == 0:
73+
state['step'] = 0
74+
state['exp_avg'] = torch.zeros_like(p.data)
75+
state['exp_avg_sq'] = torch.zeros_like(p.data)
76+
77+
# Adam
78+
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
79+
80+
state['step'] += 1
81+
bias_correction1 = 1 - beta1 ** state['step']
82+
bias_correction2 = 1 - beta2 ** state['step']
83+
84+
exp_avg.mul_(beta1).add_(1 - beta1, grad)
85+
exp_avg_sq.mul_(beta2).addcmul_(1 - beta2, grad, grad)
86+
87+
denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(group['eps'])
88+
step_size = group['lr'] / bias_correction1
89+
90+
if nesterov:
91+
perturb = (beta1 * exp_avg + (1 - beta1) * grad) / denom
92+
else:
93+
perturb = exp_avg / denom
94+
95+
# Projection
96+
wd_ratio = 1
97+
if len(p.shape) > 1:
98+
perturb, wd_ratio = self._projection(p, grad, perturb, group['delta'], group['wd_ratio'], group['eps'])
99+
100+
# Weight decay
101+
if group['weight_decay'] > 0:
102+
p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio)
103+
104+
# Step
105+
p.data.add_(-step_size, perturb)
106+
107+
return loss

timm/optim/optim_factory.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import torch
22
from torch import optim as optim
3-
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead
3+
from timm.optim import Nadam, RMSpropTF, AdamW, RAdam, NovoGrad, NvNovoGrad, Lookahead, AdamP, SGDP
44
try:
55
from apex.optimizers import FusedNovoGrad, FusedAdam, FusedLAMB, FusedSGD
66
has_apex = True
@@ -60,6 +60,14 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
6060
elif opt_lower == 'radam':
6161
optimizer = RAdam(
6262
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)
63+
elif opt_lower == 'adamp':
64+
optimizer = AdamP(
65+
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps,
66+
delta=0.1, wd_ratio=0.01, nesterov=True)
67+
elif opt_lower == 'sgdp':
68+
optimizer = SGDP(
69+
parameters, lr=args.lr, momentum=args.momentum, weight_decay=weight_decay,
70+
eps=args.opt_eps, nesterov=True)
6371
elif opt_lower == 'adadelta':
6472
optimizer = optim.Adadelta(
6573
parameters, lr=args.lr, weight_decay=weight_decay, eps=args.opt_eps)

timm/optim/sgdp.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""
2+
SGDP Optimizer Implementation copied from https://github.com/clovaai/AdamP/blob/master/adamp/sgdp.py
3+
4+
Paper: `Slowing Down the Weight Norm Increase in Momentum-based Optimizers` - https://arxiv.org/abs/2006.08217
5+
Code: https://github.com/clovaai/AdamP
6+
7+
Copyright (c) 2020-present NAVER Corp.
8+
MIT license
9+
"""
10+
11+
import torch
12+
import torch.nn as nn
13+
from torch.optim.optimizer import Optimizer, required
14+
import math
15+
16+
class SGDP(Optimizer):
17+
def __init__(self, params, lr=required, momentum=0, dampening=0,
18+
weight_decay=0, nesterov=False, eps=1e-8, delta=0.1, wd_ratio=0.1):
19+
defaults = dict(lr=lr, momentum=momentum, dampening=dampening, weight_decay=weight_decay,
20+
nesterov=nesterov, eps=eps, delta=delta, wd_ratio=wd_ratio)
21+
super(SGDP, self).__init__(params, defaults)
22+
23+
def _channel_view(self, x):
24+
return x.view(x.size(0), -1)
25+
26+
def _layer_view(self, x):
27+
return x.view(1, -1)
28+
29+
def _cosine_similarity(self, x, y, eps, view_func):
30+
x = view_func(x)
31+
y = view_func(y)
32+
33+
x_norm = x.norm(dim=1).add_(eps)
34+
y_norm = y.norm(dim=1).add_(eps)
35+
dot = (x * y).sum(dim=1)
36+
37+
return dot.abs() / x_norm / y_norm
38+
39+
def _projection(self, p, grad, perturb, delta, wd_ratio, eps):
40+
wd = 1
41+
expand_size = [-1] + [1] * (len(p.shape) - 1)
42+
for view_func in [self._channel_view, self._layer_view]:
43+
44+
cosine_sim = self._cosine_similarity(grad, p.data, eps, view_func)
45+
46+
if cosine_sim.max() < delta / math.sqrt(view_func(p.data).size(1)):
47+
p_n = p.data / view_func(p.data).norm(dim=1).view(expand_size).add_(eps)
48+
perturb -= p_n * view_func(p_n * perturb).sum(dim=1).view(expand_size)
49+
wd = wd_ratio
50+
51+
return perturb, wd
52+
53+
return perturb, wd
54+
55+
def step(self, closure=None):
56+
loss = None
57+
if closure is not None:
58+
loss = closure()
59+
60+
for group in self.param_groups:
61+
weight_decay = group['weight_decay']
62+
momentum = group['momentum']
63+
dampening = group['dampening']
64+
nesterov = group['nesterov']
65+
66+
for p in group['params']:
67+
if p.grad is None:
68+
continue
69+
grad = p.grad.data
70+
state = self.state[p]
71+
72+
# State initialization
73+
if len(state) == 0:
74+
state['momentum'] = torch.zeros_like(p.data)
75+
76+
# SGD
77+
buf = state['momentum']
78+
buf.mul_(momentum).add_(1 - dampening, grad)
79+
if nesterov:
80+
d_p = grad + momentum * buf
81+
else:
82+
d_p = buf
83+
84+
# Projection
85+
wd_ratio = 1
86+
if len(p.shape) > 1:
87+
d_p, wd_ratio = self._projection(p, grad, d_p, group['delta'], group['wd_ratio'], group['eps'])
88+
89+
# Weight decay
90+
if weight_decay != 0:
91+
p.data.mul_(1 - group['lr'] * group['weight_decay'] * wd_ratio / (1-momentum))
92+
93+
# Step
94+
p.data.add_(-group['lr'], d_p)
95+
96+
return loss

0 commit comments

Comments
 (0)