Commit 72b9752

Merge pull request #2537 from huggingface/no_opt_layer_decay
Add a min layer-decay scale clamp, and a no-optimization threshold to exclude groups from optimization
2 parents 6239313 + 83709ae commit 72b9752

3 files changed: 45 additions & 22 deletions

timm/optim/_optim_factory.py

Lines changed: 32 additions & 19 deletions

@@ -234,6 +234,8 @@ def create_optimizer(
         foreach: Optional[bool] = None,
         weight_decay_exclude_1d: bool = True,
         layer_decay: Optional[float] = None,
+        layer_decay_min_scale: Optional[float] = None,
+        layer_decay_no_opt_scale: Optional[float] = None,
         param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
 ) -> torch.optim.Optimizer:

@@ -248,6 +250,8 @@ def create_optimizer(
         foreach: Enable/disable foreach operation
         weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
         layer_decay: Layer-wise learning rate decay
+        layer_decay_min_scale: Minimum layer scale factor clamp value
+        layer_decay_no_opt_scale: Layer scale below which optimization is disabled
         param_group_fn: Optional custom parameter grouping function
         **kwargs: Additional optimizer-specific arguments
 

@@ -273,6 +277,8 @@ def create_optimizer(
             layer_decay=layer_decay,
             no_weight_decay_list=no_weight_decay,
             weight_decay_exclude_1d=weight_decay_exclude_1d,
+            min_scale=layer_decay_min_scale,
+            no_opt_scale=layer_decay_no_opt_scale,
         )
         weight_decay = 0.
     elif weight_decay and weight_decay_exclude_1d:

@@ -1140,6 +1146,8 @@ def create_optimizer_v2(
         foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
+        layer_decay_min_scale: float = 0.0,
+        layer_decay_no_opt_scale: Optional[float] = None,
         param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
 ) -> torch.optim.Optimizer:

@@ -1215,31 +1223,36 @@ def create_optimizer_v2(
         foreach=foreach,
         weight_decay_exclude_1d=filter_bias_and_bn,
         layer_decay=layer_decay,
+        layer_decay_min_scale=layer_decay_min_scale,
+        layer_decay_no_opt_scale=layer_decay_no_opt_scale,
         param_group_fn=param_group_fn,
         **kwargs
     )
 
 
 def optimizer_kwargs(cfg):
-    """ cfg/argparse to kwargs helper
-    Convert optimizer args in argparse args or cfg like object to keyword args for updated create fn.
-    """
-    kwargs = dict(
-        opt=cfg.opt,
-        lr=cfg.lr,
-        weight_decay=cfg.weight_decay,
-        momentum=cfg.momentum,
-    )
-    if getattr(cfg, 'opt_eps', None) is not None:
-        kwargs['eps'] = cfg.opt_eps
-    if getattr(cfg, 'opt_betas', None) is not None:
-        kwargs['betas'] = cfg.opt_betas
-    if getattr(cfg, 'layer_decay', None) is not None:
-        kwargs['layer_decay'] = cfg.layer_decay
-    if getattr(cfg, 'opt_args', None) is not None:
-        kwargs.update(cfg.opt_args)
-    if getattr(cfg, 'opt_foreach', None) is not None:
-        kwargs['foreach'] = cfg.opt_foreach
+    """Convert argparse-style `cfg` object to kwargs for an optimizer factory."""
+    kwargs = {
+        'opt': cfg.opt,
+        'lr': cfg.lr,
+        'weight_decay': cfg.weight_decay,
+        'momentum': cfg.momentum,
+    }
+    if (eps := getattr(cfg, 'opt_eps', None)) is not None:
+        kwargs['eps'] = eps
+    if (betas := getattr(cfg, 'opt_betas', None)) is not None:
+        kwargs['betas'] = betas
+    if (layer_decay := getattr(cfg, 'layer_decay', None)) is not None:
+        kwargs['layer_decay'] = layer_decay
+    if (ld_min := getattr(cfg, 'layer_decay_min_scale', None)) is not None:
+        kwargs['layer_decay_min_scale'] = ld_min
+    if (ld_no_opt := getattr(cfg, 'layer_decay_no_opt_scale', None)) is not None:
+        kwargs['layer_decay_no_opt_scale'] = ld_no_opt
+    if (opt_args := getattr(cfg, 'opt_args', None)) is not None:
+        kwargs.update(opt_args)
+    if (foreach := getattr(cfg, 'opt_foreach', None)) is not None:
+        kwargs['foreach'] = foreach
+
     return kwargs
 
 
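The two new create_optimizer_v2 arguments pass straight through the factory. The following is a minimal usage sketch, not part of the commit: the model name and hyper-parameter values are illustrative placeholders, and only arguments visible in the diff above are assumed.

import timm
from timm.optim import create_optimizer_v2

# Any model with a recognizable layer/block structure works; ViT is just an example.
model = timm.create_model('vit_base_patch16_224', pretrained=False)

optimizer = create_optimizer_v2(
    model,
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05,
    layer_decay=0.75,               # per-layer lr multiplier shrinks toward the earliest layers
    layer_decay_min_scale=0.1,      # new: clamp, no parameter group's lr multiplier drops below 0.1
    layer_decay_no_opt_scale=None,  # new: e.g. 0.05 would freeze groups whose multiplier is below it
)
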
timm/optim/_param_groups.py

Lines changed: 9 additions & 3 deletions

@@ -73,6 +73,8 @@ def param_groups_layer_decay(
         weight_decay_exclude_1d: bool = True,
         layer_decay: float = .75,
         end_layer_decay: Optional[float] = None,
+        min_scale: float = 0.,
+        no_opt_scale: Optional[float] = None,
         verbose: bool = False,
 ):
     """

@@ -91,7 +93,7 @@ def param_groups_layer_decay(
     layer_map = auto_group_layers(model)
     num_layers = max(layer_map.values()) + 1
     layer_max = num_layers - 1
-    layer_scales = list(layer_decay ** (layer_max - i) for i in range(num_layers))
+    layer_scales = list(max(min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers))
 
     for name, param in model.named_parameters():
         if not param.requires_grad:

@@ -106,10 +108,14 @@
             this_decay = weight_decay
 
         layer_id = layer_map.get(name, layer_max)
-        group_name = "layer_%d_%s" % (layer_id, g_decay)
+        this_scale = layer_scales[layer_id]
+        if no_opt_scale and this_scale < no_opt_scale:
+            # if the calculated layer scale is below this, exclude from optimization
+            param.requires_grad = False
+            continue
 
+        group_name = "layer_%d_%s" % (layer_id, g_decay)
         if group_name not in param_groups:
-            this_scale = layer_scales[layer_id]
             param_group_names[group_name] = {
                 "lr_scale": this_scale,
                 "weight_decay": this_decay,

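To make the two new knobs concrete, here is a standalone numerical sketch (values are assumptions, not from the commit) of the computation above: min_scale clamps the per-layer lr multiplier from below, while no_opt_scale freezes any parameter whose (clamped) multiplier still falls under the threshold.

layer_decay = 0.65
min_scale = 0.05       # corresponds to layer_decay_min_scale / --layer-decay-min-scale
no_opt_scale = 0.10    # corresponds to layer_decay_no_opt_scale / --layer-decay-no-opt-scale
num_layers = 12
layer_max = num_layers - 1

# Same expression as the diff: decayed scale per layer, clamped from below by min_scale.
layer_scales = [max(min_scale, layer_decay ** (layer_max - i)) for i in range(num_layers)]
# ~ [0.05, 0.05, 0.05, 0.05, 0.05, 0.075, 0.116, 0.179, 0.275, 0.423, 0.65, 1.0]

# Layers whose clamped scale is still below no_opt_scale get requires_grad = False and are skipped.
frozen = [i for i, s in enumerate(layer_scales) if no_opt_scale and s < no_opt_scale]
# -> [0, 1, 2, 3, 4, 5]: the six earliest layer groups drop out of optimization entirely.
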
train.py

Lines changed: 4 additions & 0 deletions

@@ -206,6 +206,10 @@
                    help='Gradient clipping mode. One of ("norm", "value", "agc")')
 group.add_argument('--layer-decay', type=float, default=None,
                    help='layer-wise learning rate decay (default: None)')
+group.add_argument('--layer-decay-min-scale', type=float, default=0,
+                   help='layer-wise lr decay minimum scale clamp (default: 0)')
+group.add_argument('--layer-decay-no-opt-scale', type=float, default=None,
+                   help='layer-wise lr decay no optimization scale (default: None)')
 group.add_argument('--opt-kwargs', nargs='*', default={}, action=utils.ParseKwargs)
 
 # Learning rate schedule parameters
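
On the training-script side, the new flags flow through optimizer_kwargs() into the factory. A minimal sketch of that hand-off, assuming the usual timm.optim exports and using a SimpleNamespace as a stand-in for the parsed argparse namespace (all values are placeholders):

from types import SimpleNamespace

from timm.optim import optimizer_kwargs

# Stand-in for the args namespace produced by the parser above.
args = SimpleNamespace(
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05,
    momentum=0.9,
    layer_decay=0.75,
    layer_decay_min_scale=0.0,      # --layer-decay-min-scale
    layer_decay_no_opt_scale=0.05,  # --layer-decay-no-opt-scale
)

kwargs = optimizer_kwargs(cfg=args)
# kwargs now carries 'layer_decay_min_scale' and 'layer_decay_no_opt_scale' alongside the usual
# keys, so the equivalent of train.py's call is roughly: create_optimizer_v2(model, **kwargs)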
