Commit a024ab3

Replace radam & nadam impl with torch.optim ver, rename legacy adamw, nadam, radam impl in timm. Update optim factory & tests.
1 parent 7b54eab commit a024ab3

6 files changed: 106 additions & 30 deletions
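
In practical terms, optimizer names resolved through the factory now map to the torch.optim classes while the old timm code stays importable under new names. A minimal sketch of the new behaviour, assuming this commit and a torch build that ships NAdam/RAdam (the tiny Linear module is only a stand-in):

import torch
import torch.nn as nn
from timm.optim import create_optimizer_v2, get_optimizer_class

model = nn.Linear(10, 2)  # stand-in model

# 'nadam' and 'radam' now build the torch.optim implementations
opt = create_optimizer_v2(model, opt='nadam', lr=1e-3)
assert isinstance(opt, torch.optim.NAdam)

# the old pure-Python timm implementations remain available under *legacy names
legacy_opt = create_optimizer_v2(model, opt='radamlegacy', lr=1e-3)
print(get_optimizer_class('adamwlegacy'))  # timm's renamed AdamWLegacy reference impl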

tests/test_optim.py

Lines changed: 27 additions & 4 deletions
@@ -376,10 +376,17 @@ def test_adam(optimizer):
 
 @pytest.mark.parametrize('optimizer', ['adopt', 'adoptw'])
 def test_adopt(optimizer):
-    # FIXME rosenbrock is not passing for ADOPT
-    # _test_rosenbrock(
-    #     lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
-    # )
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=3e-3)
+    )
+    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT
+
+
+@pytest.mark.parametrize('optimizer', ['adan', 'adanw'])
+def test_adan(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
     _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT
 
 
@@ -432,6 +439,14 @@ def test_lamb(optimizer):
     _test_model(optimizer, dict(lr=1e-3))
 
 
+@pytest.mark.parametrize('optimizer', ['laprop'])
+def test_laprop(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-2)
+    )
+    _test_model(optimizer, dict(lr=1e-2))
+
+
 @pytest.mark.parametrize('optimizer', ['lars', 'larc', 'nlars', 'nlarc'])
 def test_lars(optimizer):
     _test_rosenbrock(

@@ -448,6 +463,14 @@ def test_madgrad(optimizer):
     _test_model(optimizer, dict(lr=1e-2))
 
 
+@pytest.mark.parametrize('optimizer', ['mars'])
+def test_mars(optimizer):
+    _test_rosenbrock(
+        lambda params: create_optimizer_v2(params, optimizer, lr=1e-3)
+    )
+    _test_model(optimizer, dict(lr=5e-2), after_step=1)  # note no convergence in first step for ADOPT
+
+
 @pytest.mark.parametrize('optimizer', ['novograd'])
 def test_novograd(optimizer):
     _test_rosenbrock(
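
Outside the test suite, the newly covered optimizers can be smoke-tested through the same factory entry point. A rough sketch (the learning rates here are illustrative, not the tuned values used in the tests above):

import torch
from timm.optim import create_optimizer_v2

for name in ('adopt', 'adan', 'laprop', 'mars'):
    w = torch.nn.Parameter(torch.randn(8))
    opt = create_optimizer_v2([w], opt=name, lr=1e-2)
    start = (w ** 2).sum().item()
    for _ in range(20):
        opt.zero_grad()
        loss = (w ** 2).sum()
        loss.backward()
        opt.step()
    # expect the toy loss to drop; as noted above, ADOPT may not move on the very first step
    assert (w ** 2).sum().item() < start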

timm/optim/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -3,22 +3,27 @@
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
-from .adamw import AdamW
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW
 
+# bring torch optim into timm.optim namespace for consistency
+from torch.optim import Adadelta, Adagrad, Adamax, Adam, NAdam, RAdam, RMSprop, SGD
+
 from ._optim_factory import list_optimizers, get_optimizer_class, get_optimizer_info, OptimInfo, OptimizerRegistry, \
     create_optimizer_v2, create_optimizer, optimizer_kwargs
 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay, auto_group_layers
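
The re-export block above means several names in timm.optim now refer to the torch classes directly, while the renamed timm implementations sit alongside them; a quick sanity check, as a sketch:

import torch
from timm.optim import NAdam, RAdam, NAdamLegacy, RAdamLegacy, AdamWLegacy

assert NAdam is torch.optim.NAdam
assert RAdam is torch.optim.RAdam
assert NAdamLegacy is not torch.optim.NAdam  # the legacy classes are timm's own reference impls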

timm/optim/_optim_factory.py

Lines changed: 53 additions & 13 deletions
@@ -19,17 +19,20 @@
 from .adafactor_bv import AdafactorBigVision
 from .adahessian import Adahessian
 from .adamp import AdamP
+from .adamw import AdamWLegacy
 from .adan import Adan
 from .adopt import Adopt
 from .lamb import Lamb
+from .laprop import LaProp
 from .lars import Lars
 from .lion import Lion
 from .lookahead import Lookahead
 from .madgrad import MADGRAD
-from .nadam import Nadam
+from .mars import Mars
+from .nadam import NAdamLegacy
 from .nadamw import NAdamW
 from .nvnovograd import NvNovoGrad
-from .radam import RAdam
+from .radam import RAdamLegacy
 from .rmsprop_tf import RMSpropTF
 from .sgdp import SGDP
 from .sgdw import SGDW

@@ -384,13 +387,19 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adam',
             opt_class=optim.Adam,
-            description='torch.optim Adam (Adaptive Moment Estimation)',
+            description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
             opt_class=optim.AdamW,
-            description='torch.optim Adam with decoupled weight decay regularization',
+            description='torch.optim.AdamW, Adam with decoupled weight decay',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='adamwlegacy',
+            opt_class=AdamWLegacy,
+            description='legacy impl of AdamW that pre-dates inclusion to torch.optim',
             has_betas=True
         ),
         OptimInfo(

@@ -402,26 +411,45 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='nadam',
-            opt_class=Nadam,
-            description='Adam with Nesterov momentum',
+            opt_class=torch.optim.NAdam,
+            description='torch.optim.NAdam, Adam with Nesterov momentum',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='nadamlegacy',
+            opt_class=NAdamLegacy,
+            description='legacy impl of NAdam that pre-dates inclusion in torch.optim',
             has_betas=True
         ),
         OptimInfo(
             name='nadamw',
             opt_class=NAdamW,
-            description='Adam with Nesterov momentum and decoupled weight decay',
+            description='Adam with Nesterov momentum and decoupled weight decay, mlcommons/algorithmic-efficiency impl',
             has_betas=True
         ),
         OptimInfo(
             name='radam',
-            opt_class=RAdam,
-            description='Rectified Adam with variance adaptation',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdam, Rectified Adam with variance adaptation',
+            has_betas=True
+        ),
+        OptimInfo(
+            name='radamlegacy',
+            opt_class=RAdamLegacy,
+            description='legacy impl of RAdam that predates inclusion in torch.optim',
             has_betas=True
         ),
+        OptimInfo(
+            name='radamw',
+            opt_class=torch.optim.RAdam,
+            description='torch.optim.RAdamW, Rectified Adam with variance adaptation and decoupled weight decay',
+            has_betas=True,
+            defaults={'decoupled_weight_decay': True}
+        ),
         OptimInfo(
             name='adamax',
             opt_class=optim.Adamax,
-            description='torch.optim Adamax, Adam with infinity norm for more stable updates',
+            description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
         OptimInfo(

@@ -518,12 +546,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='adadelta',
             opt_class=optim.Adadelta,
-            description='torch.optim Adadelta, Adapts learning rates based on running windows of gradients'
+            description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
             opt_class=optim.Adagrad,
-            description='torch.optim Adagrad, Adapts learning rates using cumulative squared gradients',
+            description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
         OptimInfo(

@@ -549,6 +577,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_betas=True,
             second_order=True,
         ),
+        OptimInfo(
+            name='laprop',
+            opt_class=LaProp,
+            description='Separating Momentum and Adaptivity in Adam',
+            has_betas=True,
+        ),
         OptimInfo(
             name='lion',
             opt_class=Lion,

@@ -569,6 +603,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
             has_momentum=True,
             defaults={'decoupled_decay': True}
         ),
+        OptimInfo(
+            name='mars',
+            opt_class=Mars,
+            description='Unleashing the Power of Variance Reduction for Training Large Models',
+            has_betas=True,
+        ),
         OptimInfo(
             name='novograd',
             opt_class=NvNovoGrad,

@@ -578,7 +618,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         OptimInfo(
             name='rmsprop',
             opt_class=optim.RMSprop,
-            description='torch.optim RMSprop, Root Mean Square Propagation',
+            description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
         ),
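
The registry additions can be inspected through the factory helpers; a sketch assuming this commit (and, for 'radamw', a torch version whose RAdam accepts decoupled_weight_decay):

import torch.nn as nn
from timm.optim import list_optimizers, get_optimizer_info, create_optimizer_v2

names = list_optimizers()
for n in ('adamwlegacy', 'nadamlegacy', 'radamlegacy', 'radamw', 'laprop', 'mars'):
    assert n in names

print(get_optimizer_info('radamw').description)

# 'radamw' builds torch.optim.RAdam with decoupled_weight_decay=True from the registry defaults
opt = create_optimizer_v2(nn.Linear(4, 4), opt='radamw', lr=1e-3, weight_decay=1e-2)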

timm/optim/adamw.py

Lines changed: 6 additions & 5 deletions
@@ -1,17 +1,18 @@
 """ AdamW Optimizer
 Impl copied from PyTorch master
 
-NOTE: Builtin optim.AdamW is used by the factory, this impl only serves as a Python based reference, will be removed
-someday
+NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
 """
 import math
 import torch
 from torch.optim.optimizer import Optimizer
 
 
-class AdamW(Optimizer):
+class AdamWLegacy(Optimizer):
     r"""Implements AdamW algorithm.
 
+    NOTE: This impl has been deprecated in favour of torch.optim.AdamW and remains as a reference
+
     The original Adam algorithm was proposed in `Adam: A Method for Stochastic Optimization`_.
     The AdamW variant was proposed in `Decoupled Weight Decay Regularization`_.
 
@@ -61,10 +62,10 @@ def __init__(
             weight_decay=weight_decay,
             amsgrad=amsgrad,
         )
-        super(AdamW, self).__init__(params, defaults)
+        super(AdamWLegacy, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(AdamW, self).__setstate__(state)
+        super(AdamWLegacy, self).__setstate__(state)
         for group in self.param_groups:
             group.setdefault('amsgrad', False)
 
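
For context on the "decoupled weight decay" wording used in the descriptions above, the distinguishing step is that the decay is applied directly to the weights, separately from the Adam update; a schematic sketch with stand-in tensors, not the legacy code path:

import torch

p = torch.randn(4)                               # stand-in parameter tensor
exp_avg, denom = torch.zeros(4), torch.ones(4)   # stand-ins for Adam state
lr, weight_decay, step_size = 1e-3, 1e-2, 1e-3

p.mul_(1 - lr * weight_decay)                    # decoupled decay: shrink weights directly
p.addcdiv_(exp_avg, denom, value=-step_size)     # then the usual Adam-style update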

timm/optim/nadam.py

Lines changed: 4 additions & 2 deletions
@@ -4,9 +4,11 @@
 from torch.optim.optimizer import Optimizer
 
 
-class Nadam(Optimizer):
+class NAdamLegacy(Optimizer):
     """Implements Nadam algorithm (a variant of Adam based on Nesterov momentum).
 
+    NOTE: This impl has been deprecated in favour of torch.optim.NAdam and remains as a reference
+
     It has been proposed in `Incorporating Nesterov Momentum into Adam`__.
 
     Arguments:

@@ -45,7 +47,7 @@ def __init__(
             weight_decay=weight_decay,
             schedule_decay=schedule_decay,
         )
-        super(Nadam, self).__init__(params, defaults)
+        super(NAdamLegacy, self).__init__(params, defaults)
 
     @torch.no_grad()
     def step(self, closure=None):

timm/optim/radam.py

Lines changed: 8 additions & 3 deletions
@@ -1,14 +1,19 @@
 """RAdam Optimizer.
 Implementation lifted from: https://github.com/LiyuanLucasLiu/RAdam
 Paper: `On the Variance of the Adaptive Learning Rate and Beyond` - https://arxiv.org/abs/1908.03265
+
+NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
 """
 import math
 import torch
 from torch.optim.optimizer import Optimizer
 
 
-class RAdam(Optimizer):
+class RAdamLegacy(Optimizer):
+    """ PyTorch RAdam optimizer
 
+    NOTE: This impl has been deprecated in favour of torch.optim.RAdam and remains as a reference
+    """
     def __init__(
         self,
         params,

@@ -24,10 +29,10 @@ def __init__(
             weight_decay=weight_decay,
             buffer=[[None, None, None] for _ in range(10)]
         )
-        super(RAdam, self).__init__(params, defaults)
+        super(RAdamLegacy, self).__init__(params, defaults)
 
     def __setstate__(self, state):
-        super(RAdam, self).__setstate__(state)
+        super(RAdamLegacy, self).__setstate__(state)
 
     @torch.no_grad()
     def step(self, closure=None):
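
Code that constructed the old classes directly keeps working by switching to the new names; a minimal sketch (argument values are illustrative):

import torch.nn as nn
from timm.optim import AdamWLegacy, NAdamLegacy, RAdamLegacy

params = list(nn.Linear(4, 4).parameters())
opt = RAdamLegacy(params, lr=1e-3, betas=(0.9, 0.999), weight_decay=0.0)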
