Skip to content

Commit 56e2ac3

Browse files
authored
Merge pull request #94 from rwightman/lr_noise
Learning rate noise, MobileNetV3 weights, and activate MobileNetV3/EfficientNet weight init change
2 parents c60069c + c16f25c commit 56e2ac3

File tree

11 files changed

+145
-43
lines changed

11 files changed

+145
-43
lines changed

README.md

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22

33
## What's New
44

5+
### Feb 29, 2020
6+
* New MobileNet-V3 Large weights trained from stratch with this code to 75.77% top-1
7+
* IMPORTANT CHANGE - default weight init changed for all MobilenetV3 / EfficientNet / related models
8+
* overall results similar to a bit better training from scratch on a few smaller models tried
9+
* performance early in training seems consistently improved but less difference by end
10+
* set `fix_group_fanout=False` in `_init_weight_goog` fn if you need to reproducte past behaviour
11+
* Experimental LR noise feature added applies a random perturbation to LR each epoch in specified range of training
12+
513
### Feb 18, 2020
614
* Big refactor of model layers and addition of several attention mechanisms. Several additions motivated by 'Compounding the Performance Improvements...' (https://arxiv.org/abs/2001.06268):
715
* Move layer/module impl into `layers` subfolder/module of `models` and organize in a more granular fashion
@@ -187,7 +195,8 @@ I've leveraged the training scripts in this repository to train a few of the mod
187195
| skresnet34 | 76.912 (23.088) | 93.322 (6.678) | 22.2M | bicubic | 224 |
188196
| resnet26d | 76.68 (23.32) | 93.166 (6.834) | 16M | bicubic | 224 |
189197
| mixnet_s | 75.988 (24.012) | 92.794 (7.206) | 4.13M | bicubic | 224 |
190-
| mobilenetv3_100 | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
198+
| mobilenetv3_large_100 | 75.766 (24.234) | 92.542 (7.458) | 5.5M | bicubic | 224 |
199+
| mobilenetv3_rw | 75.634 (24.366) | 92.708 (7.292) | 5.5M | bicubic | 224 |
191200
| mnasnet_a1 | 75.448 (24.552) | 92.604 (7.396) | 3.89M | bicubic | 224 |
192201
| resnet26 | 75.292 (24.708) | 92.57 (7.43) | 16M | bicubic | 224 |
193202
| fbnetc_100 | 75.124 (24.876) | 92.386 (7.614) | 5.6M | bilinear | 224 |
@@ -361,6 +370,11 @@ Trained by [Andrew Lavin](https://github.com/andravin) with 8 V100 cards. Model
361370

362371
`./distributed_train.sh 8 /imagenet --model efficientnet_es -b 128 --sched step --epochs 450 --decay-epochs 2.4 --decay-rate .97 --opt rmsproptf --opt-eps .001 -j 8 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064`
363372

373+
### MobileNetV3-Large-100 - 75.766 top-1, 92,542 top-5
374+
375+
`./distributed_train.sh 2 /imagenet/ --model mobilenetv3_large_100 -b 512 --sched step --epochs 600 --decay-epochs 2.4 --decay-rate .973 --opt rmsproptf --opt-eps .001 -j 7 --warmup-lr 1e-6 --weight-decay 1e-5 --drop 0.2 --drop-connect 0.2 --model-ema --model-ema-decay 0.9999 --aa rand-m9-mstd0.5 --remode pixel --reprob 0.2 --amp --lr .064 --lr-noise 0.42 0.9`
376+
377+
364378
**TODO dig up some more**
365379

366380

sotabench.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def _entry(model_name, paper_model_name, paper_arxiv_id, batch_size=BATCH_SIZE,
9393
_entry('semnasnet_100', 'MnasNet-A1', '1807.11626'),
9494
_entry('spnasnet_100', 'Single-Path NAS', '1904.02877',
9595
model_desc='Trained in PyTorch with SGD, cosine LR decay'),
96-
_entry('mobilenetv3_rw', 'MobileNet V3-Large 1.0', '1905.02244',
96+
_entry('mobilenetv3_large_100', 'MobileNet V3-Large 1.0', '1905.02244',
9797
model_desc='Trained in PyTorch with RMSProp, exponential LR decay, and hyper-params matching '
9898
'paper as closely as possible.'),
9999

timm/models/efficientnet_builder.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,15 +359,13 @@ def __call__(self, in_chs, model_block_args):
359359
return stages
360360

361361

362-
def _init_weight_goog(m, n='', fix_group_fanout=False):
362+
def _init_weight_goog(m, n='', fix_group_fanout=True):
363363
""" Weight initialization as per Tensorflow official implementations.
364364
365365
Args:
366366
m (nn.Module): module to init
367367
n (str): module name
368-
fix_group_fanout (bool): enable correct fanout calculation w/ group convs
369-
370-
FIXME change fix_group_fanout to default to True if experiments show better training results
368+
fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
371369
372370
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
373371
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py

timm/models/mobilenetv3.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,9 @@ def _cfg(url='', **kwargs):
3131

3232
default_cfgs = {
3333
'mobilenetv3_large_075': _cfg(url=''),
34-
'mobilenetv3_large_100': _cfg(url=''),
34+
'mobilenetv3_large_100': _cfg(
35+
interpolation='bicubic',
36+
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/mobilenetv3_large_100_ra-f55367f5.pth'),
3537
'mobilenetv3_small_075': _cfg(url=''),
3638
'mobilenetv3_small_100': _cfg(url=''),
3739
'mobilenetv3_rw': _cfg(

timm/scheduler/cosine_lr.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,15 @@ def __init__(self,
2929
warmup_prefix=False,
3030
cycle_limit=0,
3131
t_in_epochs=True,
32+
noise_range_t=None,
33+
noise_pct=0.67,
34+
noise_std=1.0,
35+
noise_seed=42,
3236
initialize=True) -> None:
33-
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
37+
super().__init__(
38+
optimizer, param_group_field="lr",
39+
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
40+
initialize=initialize)
3441

3542
assert t_initial > 0
3643
assert lr_min >= 0

timm/scheduler/plateau_lr.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -8,33 +8,34 @@ class PlateauLRScheduler(Scheduler):
88

99
def __init__(self,
1010
optimizer,
11-
factor=0.1,
12-
patience=10,
13-
verbose=False,
11+
decay_rate=0.1,
12+
patience_t=10,
13+
verbose=True,
1414
threshold=1e-4,
15-
cooldown_epochs=0,
16-
warmup_updates=0,
15+
cooldown_t=0,
16+
warmup_t=0,
1717
warmup_lr_init=0,
1818
lr_min=0,
19+
mode='min',
20+
initialize=True,
1921
):
20-
super().__init__(optimizer, 'lr', initialize=False)
22+
super().__init__(optimizer, 'lr', initialize=initialize)
2123

2224
self.lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
23-
self.optimizer.optimizer,
24-
patience=patience,
25-
factor=factor,
25+
self.optimizer,
26+
patience=patience_t,
27+
factor=decay_rate,
2628
verbose=verbose,
2729
threshold=threshold,
28-
cooldown=cooldown_epochs,
30+
cooldown=cooldown_t,
31+
mode=mode,
2932
min_lr=lr_min
3033
)
3134

32-
self.warmup_updates = warmup_updates
35+
self.warmup_t = warmup_t
3336
self.warmup_lr_init = warmup_lr_init
34-
35-
if self.warmup_updates:
36-
self.warmup_active = warmup_updates > 0 # this state updates with num_updates
37-
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_updates for v in self.base_values]
37+
if self.warmup_t:
38+
self.warmup_steps = [(v - warmup_lr_init) / self.warmup_t for v in self.base_values]
3839
super().update_groups(self.warmup_lr_init)
3940
else:
4041
self.warmup_steps = [1 for _ in self.base_values]
@@ -51,18 +52,9 @@ def load_state_dict(self, state_dict):
5152
self.lr_scheduler.last_epoch = state_dict['last_epoch']
5253

5354
# override the base class step fn completely
54-
def step(self, epoch, val_loss=None):
55-
"""Update the learning rate at the end of the given epoch."""
56-
if val_loss is not None and not self.warmup_active:
57-
self.lr_scheduler.step(val_loss, epoch)
58-
else:
59-
self.lr_scheduler.last_epoch = epoch
60-
61-
def get_update_values(self, num_updates: int):
62-
if num_updates < self.warmup_updates:
63-
lrs = [self.warmup_lr_init + num_updates * s for s in self.warmup_steps]
55+
def step(self, epoch, metric=None):
56+
if epoch <= self.warmup_t:
57+
lrs = [self.warmup_lr_init + epoch * s for s in self.warmup_steps]
58+
super().update_groups(lrs)
6459
else:
65-
self.warmup_active = False # warmup cancelled by first update past warmup_update count
66-
lrs = None # no change on update after warmup stage
67-
return lrs
68-
60+
self.lr_scheduler.step(metric, epoch)

timm/scheduler/scheduler.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,11 @@ class Scheduler:
2525
def __init__(self,
2626
optimizer: torch.optim.Optimizer,
2727
param_group_field: str,
28+
noise_range_t=None,
29+
noise_type='normal',
30+
noise_pct=0.67,
31+
noise_std=1.0,
32+
noise_seed=None,
2833
initialize: bool = True) -> None:
2934
self.optimizer = optimizer
3035
self.param_group_field = param_group_field
@@ -40,6 +45,11 @@ def __init__(self,
4045
raise KeyError(f"{self._initial_param_group_field} missing from param_groups[{i}]")
4146
self.base_values = [group[self._initial_param_group_field] for group in self.optimizer.param_groups]
4247
self.metric = None # any point to having this for all?
48+
self.noise_range_t = noise_range_t
49+
self.noise_pct = noise_pct
50+
self.noise_type = noise_type
51+
self.noise_std = noise_std
52+
self.noise_seed = noise_seed if noise_seed is not None else 42
4353
self.update_groups(self.base_values)
4454

4555
def state_dict(self) -> Dict[str, Any]:
@@ -58,16 +68,38 @@ def step(self, epoch: int, metric: float = None) -> None:
5868
self.metric = metric
5969
values = self.get_epoch_values(epoch)
6070
if values is not None:
71+
values = self._add_noise(values, epoch)
6172
self.update_groups(values)
6273

6374
def step_update(self, num_updates: int, metric: float = None):
6475
self.metric = metric
6576
values = self.get_update_values(num_updates)
6677
if values is not None:
78+
values = self._add_noise(values, num_updates)
6779
self.update_groups(values)
6880

6981
def update_groups(self, values):
7082
if not isinstance(values, (list, tuple)):
7183
values = [values] * len(self.optimizer.param_groups)
7284
for param_group, value in zip(self.optimizer.param_groups, values):
7385
param_group[self.param_group_field] = value
86+
87+
def _add_noise(self, lrs, t):
88+
if self.noise_range_t is not None:
89+
if isinstance(self.noise_range_t, (list, tuple)):
90+
apply_noise = self.noise_range_t[0] <= t < self.noise_range_t[1]
91+
else:
92+
apply_noise = t >= self.noise_range_t
93+
if apply_noise:
94+
g = torch.Generator()
95+
g.manual_seed(self.noise_seed + t)
96+
if self.noise_type == 'normal':
97+
while True:
98+
# resample if noise out of percent limit, brute force but shouldn't spin much
99+
noise = torch.randn(1, generator=g).item()
100+
if abs(noise) < self.noise_pct:
101+
break
102+
else:
103+
noise = 2 * (torch.rand(1, generator=g).item() - 0.5) * self.noise_pct
104+
lrs = [v + v * noise for v in lrs]
105+
return lrs

timm/scheduler/scheduler_factory.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,22 @@
11
from .cosine_lr import CosineLRScheduler
22
from .tanh_lr import TanhLRScheduler
33
from .step_lr import StepLRScheduler
4+
from .plateau_lr import PlateauLRScheduler
45

56

67
def create_scheduler(args, optimizer):
78
num_epochs = args.epochs
9+
10+
if args.lr_noise is not None:
11+
if isinstance(args.lr_noise, (list, tuple)):
12+
noise_range = [n * num_epochs for n in args.lr_noise]
13+
if len(noise_range) == 1:
14+
noise_range = noise_range[0]
15+
else:
16+
noise_range = args.lr_noise * num_epochs
17+
else:
18+
noise_range = None
19+
820
lr_scheduler = None
921
#FIXME expose cycle parms of the scheduler config to arguments
1022
if args.sched == 'cosine':
@@ -18,6 +30,10 @@ def create_scheduler(args, optimizer):
1830
warmup_t=args.warmup_epochs,
1931
cycle_limit=1,
2032
t_in_epochs=True,
33+
noise_range_t=noise_range,
34+
noise_pct=args.lr_noise_pct,
35+
noise_std=args.lr_noise_std,
36+
noise_seed=args.seed,
2137
)
2238
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
2339
elif args.sched == 'tanh':
@@ -30,6 +46,10 @@ def create_scheduler(args, optimizer):
3046
warmup_t=args.warmup_epochs,
3147
cycle_limit=1,
3248
t_in_epochs=True,
49+
noise_range_t=noise_range,
50+
noise_pct=args.lr_noise_pct,
51+
noise_std=args.lr_noise_std,
52+
noise_seed=args.seed,
3353
)
3454
num_epochs = lr_scheduler.get_cycle_length() + args.cooldown_epochs
3555
elif args.sched == 'step':
@@ -39,5 +59,20 @@ def create_scheduler(args, optimizer):
3959
decay_rate=args.decay_rate,
4060
warmup_lr_init=args.warmup_lr,
4161
warmup_t=args.warmup_epochs,
62+
noise_range_t=noise_range,
63+
noise_pct=args.lr_noise_pct,
64+
noise_std=args.lr_noise_std,
65+
noise_seed=args.seed,
4266
)
67+
elif args.sched == 'plateau':
68+
lr_scheduler = PlateauLRScheduler(
69+
optimizer,
70+
decay_rate=args.decay_rate,
71+
patience_t=args.patience_epochs,
72+
lr_min=args.min_lr,
73+
warmup_lr_init=args.warmup_lr,
74+
warmup_t=args.warmup_epochs,
75+
cooldown_t=args.cooldown_epochs,
76+
)
77+
4378
return lr_scheduler, num_epochs

timm/scheduler/step_lr.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,21 @@ class StepLRScheduler(Scheduler):
1010

1111
def __init__(self,
1212
optimizer: torch.optim.Optimizer,
13-
decay_t: int,
13+
decay_t: float,
1414
decay_rate: float = 1.,
1515
warmup_t=0,
1616
warmup_lr_init=0,
1717
t_in_epochs=True,
18-
initialize=True) -> None:
19-
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
18+
noise_range_t=None,
19+
noise_pct=0.67,
20+
noise_std=1.0,
21+
noise_seed=42,
22+
initialize=True,
23+
) -> None:
24+
super().__init__(
25+
optimizer, param_group_field="lr",
26+
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
27+
initialize=initialize)
2028

2129
self.decay_t = decay_t
2230
self.decay_rate = decay_rate
@@ -33,8 +41,7 @@ def _get_lr(self, t):
3341
if t < self.warmup_t:
3442
lrs = [self.warmup_lr_init + t * s for s in self.warmup_steps]
3543
else:
36-
lrs = [v * (self.decay_rate ** (t // self.decay_t))
37-
for v in self.base_values]
44+
lrs = [v * (self.decay_rate ** (t // self.decay_t)) for v in self.base_values]
3845
return lrs
3946

4047
def get_epoch_values(self, epoch: int):

timm/scheduler/tanh_lr.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,15 @@ def __init__(self,
2828
warmup_prefix=False,
2929
cycle_limit=0,
3030
t_in_epochs=True,
31+
noise_range_t=None,
32+
noise_pct=0.67,
33+
noise_std=1.0,
34+
noise_seed=42,
3135
initialize=True) -> None:
32-
super().__init__(optimizer, param_group_field="lr", initialize=initialize)
36+
super().__init__(
37+
optimizer, param_group_field="lr",
38+
noise_range_t=noise_range_t, noise_pct=noise_pct, noise_std=noise_std, noise_seed=noise_seed,
39+
initialize=initialize)
3340

3441
assert t_initial > 0
3542
assert lr_min >= 0

0 commit comments

Comments
 (0)