Commit e5aea35

Update ADOPT to include clipping for stability; separate weight decay so no param decay is applied if the update is not taken on the first step
1 parent 444c506 commit e5aea35
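
In effect, the commit changes the update rule in two ways: the normalized gradient is clamped to ±(step − 1)^clip_exp before it enters the first-moment average, and decoupled weight decay is applied only once a real update is taken (step 1 only seeds the second moment). Below is a minimal sketch of the resulting per-parameter step, distilled from the diff that follows — illustrative only, not the optimizer's full code path (eager tensor ops, with the foreach, capturable, and complex-number handling omitted):

```python
# Minimal sketch based on this diff: one ADOPT update for a single dense parameter tensor.
import torch

def adopt_step_sketch(
        param, grad, exp_avg, exp_avg_sq, step,
        lr=1e-3, betas=(0.9, 0.9999), eps=1e-6,
        clip_exp=0.333, weight_decay=0.0, decoupled=False,
):
    beta1, beta2 = betas
    if weight_decay != 0 and not decoupled:
        grad = grad.add(param, alpha=weight_decay)    # L2-style wd folded into the gradient

    if step == 1:
        # first step only seeds the second moment; no param update, no decoupled decay
        exp_avg_sq.addcmul_(grad, grad)
        return

    if weight_decay != 0 and decoupled:
        param.add_(param, alpha=-lr * weight_decay)   # decay only when an update is taken

    denom = torch.clamp(exp_avg_sq.sqrt(), eps)
    normed_grad = grad.div(denom)
    if clip_exp is not None:
        clip_val = (step - 1) ** clip_exp             # bound grows as training progresses
        normed_grad.clamp_(-clip_val, clip_val)

    exp_avg.lerp_(normed_grad, 1 - beta1)             # EMA of the clipped, normalized grad
    param.add_(exp_avg, alpha=-lr)
    exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
```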


timm/optim/adopt.py

Lines changed: 49 additions & 64 deletions
@@ -13,7 +13,7 @@

 """

-from typing import cast, List, Optional, Tuple, Union
+from typing import cast, Callable, List, Optional, Tuple, Union

 import torch
 from torch import Tensor
@@ -64,6 +64,7 @@ def __init__(
         lr: Union[float, Tensor] = 1e-3,
         betas: Tuple[float, float] = (0.9, 0.9999),
         eps: float = 1e-6,
+        clip_exp: Optional[float] = 0.333,
         weight_decay: float = 0.0,
         decoupled: bool = False,
         *,
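
The new `clip_exp` argument defaults to 0.333 and can be set to None to disable clipping. A hedged construction sketch (the import path is assumed from the file name above, and the toy model is purely illustrative):

```python
import torch
from timm.optim.adopt import ADOPT  # import path assumed from the file shown above

model = torch.nn.Linear(10, 2)
optimizer = ADOPT(
    model.parameters(),
    lr=1e-3,
    betas=(0.9, 0.9999),
    clip_exp=0.333,     # new: exponent of the step-dependent clip bound; None disables clipping
    weight_decay=1e-2,
    decoupled=True,     # decoupled wd is now skipped on the first step, where no update is taken
)
```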
@@ -95,6 +96,7 @@ def __init__(
             betas=betas,
             eps=eps,
             weight_decay=weight_decay,
+            clip_exp=clip_exp,
             decoupled=decoupled,
             maximize=maximize,
             foreach=foreach,
@@ -111,6 +113,7 @@ def __setstate__(self, state):
             group.setdefault("foreach", None)
             group.setdefault("capturable", False)
             group.setdefault("differentiable", False)
+            group.setdefault("clip_exp", None)
             for p in group["params"]:
                 p_state = self.state.get(p, [])
                 if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
@@ -141,9 +144,7 @@ def _init_group(
             has_complex |= torch.is_complex(p)
             params_with_grad.append(p)
             if p.grad.is_sparse:
-                raise RuntimeError(
-                    "ADOPT does not support sparse gradients"
-                )
+                raise RuntimeError("ADOPT does not support sparse gradients")
             grads.append(p.grad)

             state = self.state[p]
@@ -153,36 +154,24 @@ def _init_group(
                 # Deliberately host `step` on CPU if both capturable and fused are off.
                 # This is because kernel launches are costly on CUDA and XLA.
                 state["step"] = (
-                    torch.zeros(
-                        (),
-                        dtype=_get_scalar_dtype(),
-                        device=p.grad.device,
-                    )
+                    torch.zeros((), dtype=_get_scalar_dtype(), device=p.grad.device)
                     if group["capturable"]
                     else torch.tensor(0.0, dtype=_get_scalar_dtype())
                 )
                 # Exponential moving average of gradient values
-                state["exp_avg"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)
                 # Exponential moving average of squared gradient values
-                state["exp_avg_sq"] = torch.zeros_like(
-                    p.grad, memory_format=torch.preserve_format
-                )
+                state["exp_avg_sq"] = torch.zeros_like(p.grad, memory_format=torch.preserve_format)

             exp_avgs.append(state["exp_avg"])
             exp_avg_sqs.append(state["exp_avg_sq"])

             if group["differentiable"] and state["step"].requires_grad:
-                raise RuntimeError(
-                    "`requires_grad` is not supported for `step` in differentiable mode"
-                )
+                raise RuntimeError("`requires_grad` is not supported for `step` in differentiable mode")

             # Foreach without capturable does not support a tensor lr
             if group["foreach"] and torch.is_tensor(group["lr"]) and not group["capturable"]:
-                raise RuntimeError(
-                    "lr as a Tensor is not supported for capturable=False and foreach=True"
-                )
+                raise RuntimeError("lr as a Tensor is not supported for capturable=False and foreach=True")

             state_steps.append(state["step"])
         return has_complex
@@ -231,6 +220,7 @@ def step(self, closure=None):
                 beta2=beta2,
                 lr=group["lr"],
                 weight_decay=group["weight_decay"],
+                clip_exp=group["clip_exp"],
                 decoupled=group["decoupled"],
                 eps=group["eps"],
                 maximize=group["maximize"],
@@ -258,6 +248,7 @@ def _single_tensor_adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -282,20 +273,12 @@ def _single_tensor_adopt(
         if capturable and not _is_compiling():
             from torch.optim.optimizer import _get_capturable_supported_devices
             capturable_supported_devices = _get_capturable_supported_devices()
-            assert (
-                param.device.type == step_t.device.type
-                and param.device.type in capturable_supported_devices
-            ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
+            assert param.device.type == step_t.device.type and param.device.type in capturable_supported_devices,\
+                f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

         # update step
         step_t += 1

-        if weight_decay != 0:
-            if decoupled:
-                param.add_(param, alpha=-lr * weight_decay)
-            else:
-                grad = grad.add(param, alpha=weight_decay)
-
         if torch.is_complex(param):
             grad = torch.view_as_real(grad)
             if exp_avg is not None:
@@ -304,17 +287,25 @@ def _single_tensor_adopt(
                 exp_avg_sq = torch.view_as_real(exp_avg_sq)
             param = torch.view_as_real(param)

+        if weight_decay != 0 and not decoupled:
+            grad = grad.add(param, alpha=weight_decay)
+
         step = step_t if capturable or differentiable else _get_value(step_t)
         if step == 1:
             exp_avg_sq.addcmul_(grad, grad.conj())
             continue

+        if weight_decay != 0 and decoupled:
+            param.add_(param, alpha=-lr * weight_decay)
+
         denom = torch.clamp(exp_avg_sq.sqrt(), eps)
-        if step == 2:
-            exp_avg.addcdiv_(grad, denom)
-        else:
-            exp_avg.mul_(beta1).addcdiv_(grad, denom, value=1 - beta1)
+        normed_grad = grad.div(denom)
+
+        if clip_exp is not None:
+            clip_val = (step - 1) ** clip_exp
+            normed_grad.clamp_(-clip_val, clip_val)

+        exp_avg.lerp_(normed_grad, 1 - beta1)
         param.add_(exp_avg, alpha=-lr)

         exp_avg_sq.mul_(beta2).addcmul_(grad, grad.conj(), value=1 - beta2)
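
With the default `clip_exp=0.333`, the bound `(step - 1) ** clip_exp` confines the normalized gradient to roughly ±1 on the first real update (step 2) and relaxes slowly as training proceeds. A quick illustration of how the bound grows (values approximate):

```python
clip_exp = 0.333
for step in (2, 10, 100, 1000, 10000):
    print(step, (step - 1) ** clip_exp)
# 2     -> 1.00   (first real update: normalized grad clipped to [-1, 1])
# 10    -> ~2.08
# 100   -> ~4.62
# 1000  -> ~9.97
# 10000 -> ~21.5
```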
@@ -334,6 +325,7 @@ def _multi_tensor_adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -355,8 +347,7 @@ def _multi_tensor_adopt(
             supports_xla=False
         )
         assert all(
-            p.device.type == step.device.type
-            and p.device.type in capturable_supported_devices
+            p.device.type == step.device.type and p.device.type in capturable_supported_devices
             for p, step in zip(params, state_steps)
         ), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -382,9 +373,7 @@ def _multi_tensor_adopt(

         # Handle complex parameters
         if has_complex:
-            _view_as_real(
-                device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
-            )
+            _view_as_real(device_params, device_grads, device_exp_avgs, device_exp_avg_sqs)

         if maximize:
             device_grads = torch._foreach_neg(device_grads)  # type: ignore[assignment]
@@ -394,44 +383,38 @@ def _multi_tensor_adopt(
         # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
         # wrapped it once now. The alpha is required to assure we go to the right overload.
         if not _is_compiling() and device_state_steps[0].is_cpu:
-            torch._foreach_add_(
-                device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
-            )
+            torch._foreach_add_(device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0)
         else:
             torch._foreach_add_(device_state_steps, 1)

-        if weight_decay != 0:
-            if decoupled:
-                torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+        if weight_decay != 0 and not decoupled:
+            # Re-use the intermediate memory (device_grads) already allocated for maximize
+            if maximize:
+                torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
             else:
-                # Re-use the intermediate memory (device_grads) already allocated for maximize
-                if maximize:
-                    torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
-                else:
-                    device_grads = torch._foreach_add(  # type: ignore[assignment]
-                        device_grads, device_params, alpha=weight_decay
-                    )
+                device_grads = torch._foreach_add(device_grads, device_params, alpha=weight_decay)

         if device_state_steps[0] == 1:
             torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads)
             continue

+        if weight_decay != 0 and decoupled:
+            torch._foreach_add_(device_params, device_params, alpha=-lr * weight_decay)
+
         exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
-        exp_avg_sq_sqrt = torch._foreach_maximum(exp_avg_sq_sqrt, eps)
+        torch._foreach_maximum_(exp_avg_sq_sqrt, eps)
+        normed_grad = torch._foreach_div(device_grads, exp_avg_sq_sqrt)

-        if device_state_steps[0] == 2:
-            torch._foreach_addcdiv_(device_exp_avgs, device_grads, exp_avg_sq_sqrt)
-        else:
-            torch._foreach_mul_(device_exp_avgs, beta1)
-            torch._foreach_addcdiv_(
-                device_exp_avgs, device_grads, exp_avg_sq_sqrt, value=1 - beta1
-            )
+        if clip_exp is not None:
+            clip_val = (device_state_steps[0] - 1) ** clip_exp
+            torch._foreach_maximum_(normed_grad, -clip_val)
+            torch._foreach_minimum_(normed_grad, clip_val)

+        torch._foreach_lerp_(device_exp_avgs, normed_grad, 1 - beta1)
         torch._foreach_add_(device_params, device_exp_avgs, alpha=-lr)
+
         torch._foreach_mul_(device_exp_avg_sqs, beta2)
-        torch._foreach_addcmul_(
-            device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2
-        )
+        torch._foreach_addcmul_(device_exp_avg_sqs, device_grads, device_grads, value=1 - beta2)


 #@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adopt)  # FIXME internal context mgr, can't use
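
The multi-tensor path applies the same clamp with a `_foreach_maximum_` / `_foreach_minimum_` pair over the tensor list rather than a single clamp call. A small equivalence sketch against `Tensor.clamp` (illustrative only, not code from the commit):

```python
import torch

# the foreach pair used above is elementwise-equivalent to clamping each tensor individually
xs = [torch.randn(4), torch.randn(3, 3)]
clip_val = (5 - 1) ** 0.333                      # e.g. the bound at step 5

expected = [x.clamp(-clip_val, clip_val) for x in xs]
clipped = [x.clone() for x in xs]
torch._foreach_maximum_(clipped, -clip_val)      # lower bound
torch._foreach_minimum_(clipped, clip_val)       # upper bound
assert all(torch.equal(a, b) for a, b in zip(expected, clipped))
```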
@@ -454,6 +437,7 @@ def adopt(
     beta2: float,
     lr: Union[float, Tensor],
     weight_decay: float,
+    clip_exp: Optional[float],
     decoupled: bool,
     eps: float,
     maximize: bool,
@@ -490,6 +474,7 @@ def adopt(
         beta2=beta2,
         lr=lr,
         weight_decay=weight_decay,
+        clip_exp=clip_exp,
         decoupled=decoupled,
         eps=eps,
         maximize=maximize,
