@@ -234,6 +234,8 @@ def create_optimizer(
234
234
foreach : Optional [bool ] = None ,
235
235
weight_decay_exclude_1d : bool = True ,
236
236
layer_decay : Optional [float ] = None ,
237
+ layer_decay_min_scale : Optional [float ] = None ,
238
+ layer_decay_no_opt_scale : Optional [float ] = None ,
237
239
param_group_fn : Optional [Callable [[nn .Module ], ParamsT ]] = None ,
238
240
** kwargs : Any ,
239
241
) -> torch .optim .Optimizer :
@@ -248,6 +250,8 @@ def create_optimizer(
248
250
foreach: Enable/disable foreach operation
249
251
weight_decay_exclude_1d: Whether to skip weight decay for 1d params (biases and norm affine)
250
252
layer_decay: Layer-wise learning rate decay
253
+ layer_decay_min_scale: Minimum layer scale factor clamp value
254
+ layer_decay_no_opt_scale: Layer scale below which optimization is disabled
251
255
param_group_fn: Optional custom parameter grouping function
252
256
**kwargs: Additional optimizer-specific arguments
253
257
@@ -273,6 +277,8 @@ def create_optimizer(
273
277
layer_decay = layer_decay ,
274
278
no_weight_decay_list = no_weight_decay ,
275
279
weight_decay_exclude_1d = weight_decay_exclude_1d ,
280
+ min_scale = layer_decay_min_scale ,
281
+ no_opt_scale = layer_decay_no_opt_scale ,
276
282
)
277
283
weight_decay = 0.
278
284
elif weight_decay and weight_decay_exclude_1d :
@@ -1140,6 +1146,8 @@ def create_optimizer_v2(
1140
1146
foreach : Optional [bool ] = None ,
1141
1147
filter_bias_and_bn : bool = True ,
1142
1148
layer_decay : Optional [float ] = None ,
1149
+ layer_decay_min_scale : float = 0.0 ,
1150
+ layer_decay_no_opt_scale : Optional [float ] = None ,
1143
1151
param_group_fn : Optional [Callable [[nn .Module ], ParamsT ]] = None ,
1144
1152
** kwargs : Any ,
1145
1153
) -> torch .optim .Optimizer :
@@ -1215,31 +1223,36 @@ def create_optimizer_v2(
1215
1223
foreach = foreach ,
1216
1224
weight_decay_exclude_1d = filter_bias_and_bn ,
1217
1225
layer_decay = layer_decay ,
1226
+ layer_decay_min_scale = layer_decay_min_scale ,
1227
+ layer_decay_no_opt_scale = layer_decay_no_opt_scale ,
1218
1228
param_group_fn = param_group_fn ,
1219
1229
** kwargs
1220
1230
)
1221
1231
1222
1232
1223
1233
def optimizer_kwargs(cfg):
    """Convert argparse-style `cfg` object to kwargs for an optimizer factory."""
    kwargs = dict(
        opt=cfg.opt,
        lr=cfg.lr,
        weight_decay=cfg.weight_decay,
        momentum=cfg.momentum,
    )
    # Optional scalar settings: (cfg attribute name, target kwarg key).
    # Each is forwarded only when present on cfg and not None.
    optional_fields = (
        ('opt_eps', 'eps'),
        ('opt_betas', 'betas'),
        ('layer_decay', 'layer_decay'),
        ('layer_decay_min_scale', 'layer_decay_min_scale'),
        ('layer_decay_no_opt_scale', 'layer_decay_no_opt_scale'),
    )
    for attr, key in optional_fields:
        value = getattr(cfg, attr, None)
        if value is not None:
            kwargs[key] = value
    # Free-form extras merge after the scalars above (so explicit opt_args
    # entries take precedence over them), but before foreach.
    extra_args = getattr(cfg, 'opt_args', None)
    if extra_args is not None:
        kwargs.update(extra_args)
    foreach = getattr(cfg, 'opt_foreach', None)
    if foreach is not None:
        kwargs['foreach'] = foreach

    return kwargs
1244
1257
1245
1258
0 commit comments