Commit 514b093

Experimenting with per-epoch learning rate noise w/ step scheduler

1 parent d77f45a
File tree

3 files changed: +29 -3 lines changed

timm/scheduler/scheduler_factory.py

Lines changed: 7 additions & 0 deletions

```diff
@@ -33,11 +33,18 @@ def create_scheduler(args, optimizer):
         )
         num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
     elif args.sched == 'step':
+        if isinstance(args.lr_noise, (list, tuple)):
+            noise_range = [n * num_epochs for n in args.lr_noise]
+        else:
+            noise_range = args.lr_noise * num_epochs
+        print(noise_range)
         lr_scheduler = StepLRScheduler(
             optimizer,
             decay_t=args.decay_epochs,
             decay_rate=args.decay_rate,
             warmup_lr_init=args.warmup_lr,
             warmup_t=args.warmup_epochs,
+            noise_range_t=noise_range,
+            noise_std=args.lr_noise_std,
         )
     return lr_scheduler, num_epochs
```
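
The factory interprets the --lr-noise values as fractions of the total schedule length and converts them to absolute epoch indices before constructing the scheduler. A minimal sketch of that conversion, with an illustrative 100-epoch run (num_epochs and the flag values here are assumptions, not from the commit):

```python
# Sketch of the conversion in create_scheduler above; values are illustrative.
num_epochs = 100
lr_noise = [0.4, 0.9]  # as if passed via --lr-noise 0.4 0.9

if isinstance(lr_noise, (list, tuple)):
    # Two values: noise is active on the half-open epoch window [start, end).
    noise_range = [n * num_epochs for n in lr_noise]  # -> [40.0, 90.0]
else:
    # One value: noise kicks in from this epoch onward.
    noise_range = lr_noise * num_epochs

print(noise_range)  # [40.0, 90.0]
```

Note that this experimental path assumes --lr-noise is actually set whenever the step scheduler is selected: with the flag's default of None, the else branch would raise a TypeError.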

timm/scheduler/step_lr.py

Lines changed: 18 additions & 3 deletions

```diff
@@ -14,14 +14,19 @@ def __init__(self,
                  decay_rate: float = 1.,
                  warmup_t=0,
                  warmup_lr_init=0,
+                 noise_range_t=None,
+                 noise_std=1.0,
                  t_in_epochs=True,
-                 initialize=True) -> None:
+                 initialize=True,
+                 ) -> None:
         super().__init__(optimizer, param_group_field="lr", initialize=initialize)

         self.decay_t = decay_t
         self.decay_rate = decay_rate
         self.warmup_t = warmup_t
         self.warmup_lr_init = warmup_lr_init
+        self.noise_range_t = noise_range_t
+        self.noise_std = noise_std
         self.t_in_epochs = t_in_epochs
         if self.warmup_t:
             self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
@@ -33,8 +38,18 @@ def _get_lr(self, t):
         if t < self.warmup_t:
             lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
         else:
-            lrs = [v * (self.decay_rate ** (t // self.decay_t))
-                   for v in self.base_values]
+            lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
+            if self.noise_range_t is not None:
+                if isinstance(self.noise_range_t, (list, tuple)):
+                    apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
+                else:
+                    apply_noise = t >= self.noise_range_t
+                if apply_noise:
+                    g = torch.Generator()
+                    g.manual_seed(t)
+                    lr_mult = torch.randn(1, generator=g).item() * self.noise_std + 1.
+                    lrs = [min(5 * v, max(v / 5, v * lr_mult)) for v in lrs]
+                    print(lrs)
         return lrs

     def get_epoch_values(self, epoch: int):
```
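
The perturbation in _get_lr is multiplicative Gaussian noise: the generator is re-seeded with the epoch index, so a given epoch always draws the same multiplier (deterministic across restarts), and the result is clamped to within a factor of 5 of the scheduled value. A standalone sketch of that logic, where noisy_lr is a hypothetical helper for illustration rather than part of the scheduler API:

```python
import torch

def noisy_lr(base_lr: float, t: int, noise_std: float = 1.0) -> float:
    # Re-seed a fresh generator with the epoch index, so the draw for
    # epoch t is reproducible across runs and resumes.
    g = torch.Generator()
    g.manual_seed(t)
    # Multiplier centered on 1.0 with the configured std-dev...
    lr_mult = torch.randn(1, generator=g).item() * noise_std + 1.
    # ...clamped so the noised LR stays within [base_lr / 5, 5 * base_lr].
    return min(5 * base_lr, max(base_lr / 5, base_lr * lr_mult))

# Example: perturb a scheduled LR of 0.01 over a few epochs.
for epoch in (40, 41, 42):
    print(epoch, noisy_lr(0.01, epoch, noise_std=0.1))
```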

train.py

Lines changed: 4 additions & 0 deletions

```diff
@@ -105,6 +105,10 @@
                     help='LR scheduler (default: "step"')
 parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
                     help='learning rate (default: 0.01)')
+parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
+                    help='learning rate noise on/off epoch percentages')
+parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
+                    help='learning rate noise std-dev (default: 1.0)')
 parser.add_argument('--warmup-lr', type=float, default=0.0001, metavar='LR',
                     help='warmup learning rate (default: 0.0001)')
 parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
```
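
Read together, the two flags take fractions of the run: a single --lr-noise value turns noise on from that point onward, two values bound an on/off window, and --lr-noise-std scales the multiplier. A quick check of how argparse handles them (the argv values are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
                    help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
                    help='learning rate noise std-dev (default: 1.0)')

# Equivalent to: train.py --lr-noise 0.4 0.9 --lr-noise-std 0.05
args = parser.parse_args(['--lr-noise', '0.4', '0.9', '--lr-noise-std', '0.05'])
print(args.lr_noise)      # [0.4, 0.9] -> noise between 40% and 90% of the epochs
print(args.lr_noise_std)  # 0.05
```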
