Commit 27b3680

Revamp LR noise, move logic to scheduler base. Fixup PlateauLRScheduler and add it as an option.
1 parent 514b093 commit 27b3680

7 files changed: +114 -53 lines changed


timm/scheduler/cosine_lr.py

Lines changed: 8 additions & 1 deletion
@@ -29,8 +29,15 @@ def __init__(self,
                  warmup_prefix=False,
                  cycle_limit=0,
                  t_in_epochs=True,
+                 noise_range_t=None,
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=42,
                  initialize=True) -> None:
-        super().__init__(optimizer, param_group_field="lr", initialize=initialize)
+        super().__init__(
+            optimizer, param_group_field="lr",
+            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+            initialize=initialize)
 
         assert t_initial > 0
         assert lr_min >= 0
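
For orientation, a minimal construction sketch using the new noise arguments. It assumes the rest of the CosineLRScheduler signature (t_initial, lr_min, warmup_t, warmup_lr_init) matches this commit; the toy SGD optimizer and the chosen values are illustrative only.

import torch
from timm.scheduler.cosine_lr import CosineLRScheduler

optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
scheduler = CosineLRScheduler(
    optimizer,
    t_initial=100,            # cosine cycle length in epochs
    lr_min=1e-5,
    warmup_t=5,
    warmup_lr_init=1e-4,
    noise_range_t=(40, 90),   # perturb the LR only between epochs 40 and 90
    noise_pct=0.67,           # resample noise until |noise| < 0.67
    noise_std=1.0,
    noise_seed=42,
)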

timm/scheduler/plateau_lr.py

Lines changed: 21 additions & 29 deletions
@@ -8,33 +8,34 @@ class PlateauLRScheduler(Scheduler):
 
     def __init__(self,
                  optimizer,
-                 factor=0.1,
-                 patience=10,
-                 verbose=False,
+                 decay_rate=0.1,
+                 patience_t=10,
+                 verbose=True,
                  threshold=1e-4,
-                 cooldown_epochs=0,
-                 warmup_updates=0,
+                 cooldown_t=0,
+                 warmup_t=0,
                  warmup_lr_init=0,
                  lr_min=0,
+                 mode='min',
+                 initialize=True,
                  ):
-        super().__init__(optimizer, 'lr', initialize=False)
+        super().__init__(optimizer, 'lr', initialize=initialize)
 
         self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
-            self.optimizer.optimizer,
-            patience=patience,
-            factor=factor,
+            self.optimizer,
+            patience=patience_t,
+            factor=decay_rate,
             verbose=verbose,
             threshold=threshold,
-            cooldown=cooldown_epochs,
+            cooldown=cooldown_t,
+            mode=mode,
             min_lr=lr_min
         )
 
-        self.warmup_updates = warmup_updates
+        self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
-
-        if self.warmup_updates:
-            self.warmup_active = warmup_updates > 0  # this state updates with num_updates
-            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
+        if self.warmup_t:
+            self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
             super().update_groups(self.warmup_lr_init)
         else:
             self.warmup_steps = [1 for _ in self.base_values]

@@ -51,18 +52,9 @@ def load_state_dict(self, state_dict):
         self.lr_scheduler.last_epoch = state_dict['last_epoch']
 
     # override the base class step fn completely
-    def step(self, epoch, val_loss=None):
-        """Update the learning rate at the end of the given epoch."""
-        if val_loss is not None and not self.warmup_active:
-            self.lr_scheduler.step(val_loss, epoch)
-        else:
-            self.lr_scheduler.last_epoch = epoch
-
-    def get_update_values(self, num_updates: int):
-        if num_updates < self.warmup_updates:
-            lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
+    def step(self, epoch, metric=None):
+        if epoch <= self.warmup_t:
+            lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
+            super().update_groups(lrs)
         else:
-            self.warmup_active = False  # warmup cancelled by first update past warmup_update count
-            lrs = None  # no change on update after warmup stage
-            return lrs
-
+            self.lr_scheduler.step(metric, epoch)
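
A minimal usage sketch for the reworked PlateauLRScheduler, assuming this commit's class and a toy model; the training and validation code is elided and val_loss is a placeholder metric.

import torch
from timm.scheduler.plateau_lr import PlateauLRScheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scheduler = PlateauLRScheduler(
    optimizer,
    decay_rate=0.1,       # forwarded to ReduceLROnPlateau as factor
    patience_t=10,        # epochs without improvement before decaying
    warmup_t=3,           # linear warmup over the first 3 epochs
    warmup_lr_init=1e-4,
    lr_min=1e-5,
)

for epoch in range(1, 101):
    # ... train_one_epoch / evaluate to obtain val_loss ...
    val_loss = 1.0  # placeholder metric
    scheduler.step(epoch, metric=val_loss)  # warmup while epoch <= warmup_t, plateau logic afterwards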

timm/scheduler/scheduler.py

Lines changed: 32 additions & 0 deletions
@@ -25,6 +25,11 @@ class Scheduler:
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
                  param_group_field: str,
+                 noise_range_t=None,
+                 noise_type='normal',
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=None,
                  initialize: bool = True) -> None:
         self.optimizer = optimizer
         self.param_group_field = param_group_field

@@ -40,6 +45,11 @@ def __init__(self,
                     raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
         self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
         self.metric = None  # any point to having this for all?
+        self.noise_range_t = noise_range_t
+        self.noise_pct = noise_pct
+        self.noise_type = noise_type
+        self.noise_std = noise_std
+        self.noise_seed = noise_seed if noise_seed is not None else 42
         self.update_groups(self.base_values)
 
     def state_dict(self) -> Dict[str, Any]:

@@ -58,16 +68,38 @@ def step(self, epoch: int, metric: float = None) -> None:
         self.metric = metric
         values = self.get_epoch_values(epoch)
         if values is not None:
+            values = self._add_noise(values, epoch)
             self.update_groups(values)
 
     def step_update(self, num_updates: int, metric: float = None):
         self.metric = metric
         values = self.get_update_values(num_updates)
         if values is not None:
+            values = self._add_noise(values, num_updates)
             self.update_groups(values)
 
     def update_groups(self, values):
         if not isinstance(values, (list, tuple)):
             values = [values] * len(self.optimizer.param_groups)
         for param_group, value in zip(self.optimizer.param_groups, values):
             param_group[self.param_group_field] = value
+
+    def _add_noise(self, lrs, t):
+        if self.noise_range_t is not None:
+            if isinstance(self.noise_range_t, (list, tuple)):
+                apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
+            else:
+                apply_noise = t >= self.noise_range_t
+            if apply_noise:
+                g = torch.Generator()
+                g.manual_seed(self.noise_seed + t)
+                if self.noise_type == 'normal':
+                    while True:
+                        # resample if noise out of percent limit, brute force but shouldn't spin much
+                        noise = torch.randn(1, generator=g).item()
+                        if abs(noise) < self.noise_pct:
+                            break
+                else:
+                    noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
+                lrs = [v + v * noise for v in lrs]
+        return lrs
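
To make the new noise rule concrete, below is a standalone re-statement of _add_noise as a free function (same logic as the hunk above, defaults mirrored): inside the noise window each LR is scaled by (1 + noise), where noise is drawn from N(0, 1) and resampled until its magnitude is below noise_pct, or drawn uniformly from [-noise_pct, noise_pct] in the non-normal case. Note that noise_std is stored by the base class but not applied in the 'normal' branch at this commit.

import torch

def noisy_lrs(lrs, t, noise_range_t=(15, 90), noise_pct=0.67, noise_type='normal', seed=42):
    # Apply the bounded multiplicative noise only inside the [start, end) window.
    if isinstance(noise_range_t, (list, tuple)):
        apply_noise = noise_range_t[0] <= t < noise_range_t[1]
    else:
        apply_noise = t >= noise_range_t
    if not apply_noise:
        return lrs
    g = torch.Generator()
    g.manual_seed(seed + t)  # deterministic per-epoch draw
    if noise_type == 'normal':
        while True:
            noise = torch.randn(1, generator=g).item()
            if abs(noise) < noise_pct:
                break
    else:
        noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * noise_pct
    return [v + v * noise for v in lrs]

print(noisy_lrs([0.1], t=10))  # outside the window: unchanged
print(noisy_lrs([0.1], t=20))  # inside the window: scaled by up to roughly +/-67%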

timm/scheduler/scheduler_factory.py

Lines changed: 32 additions & 5 deletions
@@ -1,10 +1,21 @@
 from .cosine_lr import CosineLRScheduler
 from .tanh_lr import TanhLRScheduler
 from .step_lr import StepLRScheduler
+from .plateau_lr import PlateauLRScheduler
 
 
 def create_scheduler(args, optimizer):
     num_epochs = args.epochs
+
+    if args.lr_noise is not None:
+        if isinstance(args.lr_noise, (list, tuple)):
+            noise_range = [n * num_epochs for n in args.lr_noise]
+        else:
+            noise_range = args.lr_noise * num_epochs
+        print('Noise range:', noise_range)
+    else:
+        noise_range = None
+
     lr_scheduler = None
     #FIXME expose cycle parms of the scheduler config to arguments
     if args.sched == 'cosine':

@@ -18,6 +29,10 @@ def create_scheduler(args, optimizer):
             warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
+            noise_range_t=noise_range,
+            noise_pct=args.lr_noise_pct,
+            noise_std=args.lr_noise_std,
+            noise_seed=args.seed,
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'tanh':

@@ -30,21 +45,33 @@ def create_scheduler(args, optimizer):
             warmup_t=args.warmup_epochs,
             cycle_limit=1,
             t_in_epochs=True,
+            noise_range_t=noise_range,
+            noise_pct=args.lr_noise_pct,
+            noise_std=args.lr_noise_std,
+            noise_seed=args.seed,
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'step':
-        if isinstance(args.lr_noise, (list, tuple)):
-            noise_range = [n * num_epochs for n in args.lr_noise]
-        else:
-            noise_range = args.lr_noise * num_epochs
-        print(noise_range)
         lr_scheduler = StepLRScheduler(
             optimizer,
             decay_t=args.decay_epochs,
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
             noise_range_t=noise_range,
+            noise_pct=args.lr_noise_pct,
             noise_std=args.lr_noise_std,
+            noise_seed=args.seed,
+        )
+    elif args.sched == 'plateau':
+        lr_scheduler = PlateauLRScheduler(
+            optimizer,
+            decay_rate=args.decay_rate,
+            patience_t=args.patience_epochs,
+            lr_min=args.min_lr,
+            warmup_lr_init=args.warmup_lr,
+            warmup_t=args.warmup_epochs,
+            cooldown_t=args.cooldown_epochs,
         )
+
     return lr_scheduler, num_epochs
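
For context, a hypothetical driver for the new 'plateau' branch of create_scheduler; the SimpleNamespace mimics the parsed train.py arguments (field names taken from the calls above, values illustrative), and the optimizer is a toy SGD instance.

import torch
from types import SimpleNamespace
from timm.scheduler.scheduler_factory import create_scheduler

args = SimpleNamespace(
    epochs=200,
    sched='plateau',
    lr_noise=None,          # the plateau branch does not consume the noise settings
    decay_rate=0.1,
    patience_epochs=10,
    min_lr=1e-5,
    warmup_lr=1e-4,
    warmup_epochs=3,
    cooldown_epochs=10,
)
optimizer = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
lr_scheduler, num_epochs = create_scheduler(args, optimizer)
# num_epochs stays at args.epochs here; only the cosine/tanh branches extend it.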

timm/scheduler/step_lr.py

Lines changed: 8 additions & 16 deletions
@@ -10,23 +10,26 @@ class StepLRScheduler(Scheduler):
 
     def __init__(self,
                  optimizer: torch.optim.Optimizer,
-                 decay_t: int,
+                 decay_t: float,
                  decay_rate: float = 1.,
                  warmup_t=0,
                  warmup_lr_init=0,
+                 t_in_epochs=True,
                  noise_range_t=None,
+                 noise_pct=0.67,
                  noise_std=1.0,
-                 t_in_epochs=True,
+                 noise_seed=42,
                  initialize=True,
                  ) -> None:
-        super().__init__(optimizer, param_group_field="lr", initialize=initialize)
+        super().__init__(
+            optimizer, param_group_field="lr",
+            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+            initialize=initialize)
 
         self.decay_t = decay_t
         self.decay_rate = decay_rate
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
-        self.noise_range_t = noise_range_t
-        self.noise_std = noise_std
         self.t_in_epochs = t_in_epochs
         if self.warmup_t:
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]

@@ -39,17 +42,6 @@ def _get_lr(self, t):
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
             lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
-            if self.noise_range_t is not None:
-                if isinstance(self.noise_range_t, (list, tuple)):
-                    apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
-                else:
-                    apply_noise = t >= self.noise_range_t
-                if apply_noise:
-                    g = torch.Generator()
-                    g.manual_seed(t)
-                    lr_mult = torch.randn(1, generator=g).item() * self.noise_std + 1.
-                    lrs = [min(5 * v, max(v / 5, v * lr_mult)) for v in lrs]
-                    print(lrs)
         return lrs
 
     def get_epoch_values(self, epoch: int):
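
The decay applied by _get_lr above (ignoring warmup and noise) is lr(t) = base_lr * decay_rate ** (t // decay_t); a quick numeric check with illustrative values:

base_lr, decay_rate, decay_t = 0.1, 0.5, 30
for t in (0, 29, 30, 60, 90):
    print(t, base_lr * decay_rate ** (t // decay_t))
# 0.1 for epochs 0-29, 0.05 for 30-59, 0.025 for 60-89, 0.0125 from epoch 90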

timm/scheduler/tanh_lr.py

Lines changed: 8 additions & 1 deletion
@@ -28,8 +28,15 @@ def __init__(self,
                  warmup_prefix=False,
                  cycle_limit=0,
                  t_in_epochs=True,
+                 noise_range_t=None,
+                 noise_pct=0.67,
+                 noise_std=1.0,
+                 noise_seed=42,
                  initialize=True) -> None:
-        super().__init__(optimizer, param_group_field="lr", initialize=initialize)
+        super().__init__(
+            optimizer, param_group_field="lr",
+            noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
+            initialize=initialize)
 
         assert t_initial > 0
         assert lr_min >= 0

train.py

Lines changed: 5 additions & 1 deletion
@@ -107,8 +107,10 @@
                     help='learning rate (default: 0.01)')
 parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                     help='learning rate noise on/off epoch percentages')
+parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
+                    help='learning rate noise limit percent (default: 0.67)')
 parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
-                    help='learning rate nose std-dev (default: 1.0)')
+                    help='learning rate noise std-dev (default: 1.0)')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',

@@ -123,6 +125,8 @@
                     help='epochs to warmup LR, if scheduler supports')
 parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
                     help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
+parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
+                    help='patience epochs for Plateau LR scheduler (default: 10')
 parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
                     help='LR decay rate (default: 0.1)')
 # Augmentation parameters
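
A cut-down sketch of how the new flags parse; this is a minimal parser mirroring only the arguments shown in this diff, not the full train.py parser, and the sample command line is illustrative.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr-noise', type=float, nargs='+', default=None)
parser.add_argument('--lr-noise-pct', type=float, default=0.67)
parser.add_argument('--lr-noise-std', type=float, default=1.0)
parser.add_argument('--patience-epochs', type=int, default=10)

args = parser.parse_args('--lr-noise 0.42 0.9 --patience-epochs 10'.split())
print(args.lr_noise)        # [0.42, 0.9] -> noise window from 42% to 90% of total epochs
print(args.patience_epochs) # 10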
