Commit 17a47c0

Add SGDW optimizer
1 parent 2597ce2 commit 17a47c0
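SGDW is SGD with decoupled weight decay in the sense of Loshchilov & Hutter's "Decoupled Weight Decay Regularization": the decay multiplies the weights directly at each step instead of being folded into the gradient, so it never enters the momentum buffer. A rough sketch of one step (illustrative values, no dampening or nesterov, mirroring the single-tensor path in sgdw.py below):

import torch

param = torch.tensor([1.0])
grad = torch.tensor([0.5])
buf = torch.tensor([0.2])            # existing momentum buffer
lr, weight_decay, momentum = 0.1, 1e-2, 0.9

param.mul_(1. - lr * weight_decay)   # decoupled decay acts on the weights only
buf.mul_(momentum).add_(grad)        # the momentum buffer sees the raw gradient
param.add_(buf, alpha=-lr)           # descent step

With momentum disabled the coupled and decoupled formulations coincide; the difference appears once a momentum buffer is involved, since decoupled decay never accumulates into it.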

File tree

timm/optim/optim_factory.py
timm/optim/sgdw.py

2 files changed: +263 -0 lines

timm/optim/optim_factory.py

Lines changed: 8 additions & 0 deletions
@@ -27,6 +27,7 @@
 from .radam import RAdam
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
+from .sgdw import SGDW


 _logger = logging.getLogger(__name__)
@@ -288,6 +289,13 @@ def create_optimizer_v2(
         optimizer = optim.SGD(parameters, momentum=momentum, nesterov=False, **opt_args)
     elif opt_lower == 'sgdp':
         optimizer = SGDP(parameters, momentum=momentum, nesterov=True, **opt_args)
+    elif opt_lower == 'sgdw' or opt_lower == 'nesterovw':
+        # NOTE 'sgd' refers to SGD + nesterov momentum for legacy / backwards compat reasons
+        opt_args.pop('eps', None)
+        optimizer = SGDW(parameters, momentum=momentum, nesterov=True, **opt_args)
+    elif opt_lower == 'momentumw':
+        opt_args.pop('eps', None)
+        optimizer = SGDW(parameters, momentum=momentum, nesterov=False, **opt_args)

     # adaptive
     elif opt_lower == 'adam':
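For reference, a minimal sketch of selecting the new opt strings through timm's optimizer factory (the model name and hyper-parameter values are illustrative, not part of this commit):

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet18')
# 'sgdw' / 'nesterovw' select SGDW with nesterov momentum; 'momentumw' selects SGDW without it
optimizer = create_optimizer_v2(model, opt='sgdw', lr=0.1, momentum=0.9, weight_decay=1e-4)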

timm/optim/sgdw.py

Lines changed: 255 additions & 0 deletions
@@ -0,0 +1,255 @@
import torch
from torch import Tensor
from torch.optim.optimizer import Optimizer, _use_grad_for_differentiable, _default_to_fused_or_foreach
from typing import List, Optional

__all__ = ['SGDW', 'sgdw']


class SGDW(Optimizer):
    def __init__(
            self,
            params,
            lr=1e-3,
            momentum=0,
            dampening=0,
            weight_decay=0,
            nesterov=False,
            *,
            maximize: bool = False,
            foreach: Optional[bool] = None,
            differentiable: bool = False,
    ):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if momentum < 0.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if weight_decay < 0.0:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")

        defaults = dict(
            lr=lr, momentum=momentum, dampening=dampening,
            weight_decay=weight_decay, nesterov=nesterov,
            maximize=maximize, foreach=foreach,
            differentiable=differentiable)
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        super().__init__(params, defaults)

    def __setstate__(self, state):
        super().__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            group.setdefault('maximize', False)
            group.setdefault('foreach', None)
            group.setdefault('differentiable', False)

    def _init_group(self, group, params_with_grad, d_p_list, momentum_buffer_list):
        has_sparse_grad = False

        for p in group['params']:
            if p.grad is not None:
                params_with_grad.append(p)
                d_p_list.append(p.grad)
                if p.grad.is_sparse:
                    has_sparse_grad = True

                state = self.state[p]
                if 'momentum_buffer' not in state:
                    momentum_buffer_list.append(None)
                else:
                    momentum_buffer_list.append(state['momentum_buffer'])

        return has_sparse_grad

    @_use_grad_for_differentiable
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure (Callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []

            has_sparse_grad = self._init_group(group, params_with_grad, d_p_list, momentum_buffer_list)

            sgdw(
                params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=group['weight_decay'],
                momentum=group['momentum'],
                lr=group['lr'],
                dampening=group['dampening'],
                nesterov=group['nesterov'],
                maximize=group['maximize'],
                has_sparse_grad=has_sparse_grad,
                foreach=group['foreach'],
            )

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state['momentum_buffer'] = momentum_buffer

        return loss


def sgdw(
        params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
        # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
        has_sparse_grad: bool = None,
        foreach: Optional[bool] = None,
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool
):
    r"""Functional API that performs SGD algorithm computation.

    See :class:`~torch.optim.SGD` for details.
    """

    if foreach is None:
        # why must we be explicit about an if statement for torch.jit.is_scripting here?
        # because JIT can't handle Optionals nor fancy conditionals when scripting
        if not torch.jit.is_scripting():
            _, foreach = _default_to_fused_or_foreach(params, differentiable=False, use_fused=False)
        else:
            foreach = False

    if foreach and torch.jit.is_scripting():
        raise RuntimeError('torch.jit.script not supported with foreach optimizers')

    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_sgdw
    else:
        func = _single_tensor_sgdw

    func(
        params,
        d_p_list,
        momentum_buffer_list,
        weight_decay=weight_decay,
        momentum=momentum,
        lr=lr,
        dampening=dampening,
        nesterov=nesterov,
        has_sparse_grad=has_sparse_grad,
        maximize=maximize,
    )


def _single_tensor_sgdw(
        params: List[Tensor],
        d_p_list: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool,
        has_sparse_grad: bool
):
    for i, param in enumerate(params):
        d_p = d_p_list[i] if not maximize else -d_p_list[i]

        param.mul_(1. - lr * weight_decay)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        param.add_(d_p, alpha=-lr)


def _multi_tensor_sgdw(
        params: List[Tensor],
        grads: List[Tensor],
        momentum_buffer_list: List[Optional[Tensor]],
        *,
        weight_decay: float,
        momentum: float,
        lr: float,
        dampening: float,
        nesterov: bool,
        maximize: bool,
        has_sparse_grad: bool
):
    if len(params) == 0:
        return

    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, momentum_buffer_list], with_indices=True)
    for ((device_params, device_grads, device_momentum_buffer_list), indices) in grouped_tensors.values():
        device_has_sparse_grad = has_sparse_grad and any(grad.is_sparse for grad in device_grads)

        if maximize:
            device_grads = torch._foreach_neg(device_grads)
        torch._foreach_mul_(device_params, 1. - lr * weight_decay)

        if momentum != 0:
            bufs = []

            all_states_with_momentum_buffer = True
            for i in range(len(device_momentum_buffer_list)):
                if device_momentum_buffer_list[i] is None:
                    all_states_with_momentum_buffer = False
                    break
                else:
                    bufs.append(device_momentum_buffer_list[i])

            if all_states_with_momentum_buffer:
                torch._foreach_mul_(bufs, momentum)
                torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
            else:
                bufs = []
                for i in range(len(device_momentum_buffer_list)):
                    if device_momentum_buffer_list[i] is None:
                        buf = device_momentum_buffer_list[i] = momentum_buffer_list[indices[i]] = \
                            torch.clone(device_grads[i]).detach()
                    else:
                        buf = device_momentum_buffer_list[i]
                        buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)

                    bufs.append(buf)

            if nesterov:
                torch._foreach_add_(device_grads, bufs, alpha=momentum)
            else:
                device_grads = bufs

        if not device_has_sparse_grad:
            torch._foreach_add_(device_params, device_grads, alpha=-lr)
        else:
            # foreach APIs don't support sparse
            for i in range(len(device_params)):
                device_params[i].add_(device_grads[i], alpha=-lr)
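
A minimal usage sketch of the new class outside the factory (the model, data and hyper-parameters are placeholders, not from this commit):

import torch
from timm.optim.sgdw import SGDW

model = torch.nn.Linear(10, 2)
opt = SGDW(model.parameters(), lr=0.1, momentum=0.9, nesterov=True, weight_decay=1e-4)

x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = torch.nn.functional.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()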
