From a71976448156da3d93b7aa0e91108896858c12dd Mon Sep 17 00:00:00 2001
From: samithuang <285365963@qq.com>
Date: Wed, 12 Apr 2023 09:50:20 +0000
Subject: [PATCH] fix conflict between param grouping and filter bn bias

---
 mindcv/optim/optim_factory.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/mindcv/optim/optim_factory.py b/mindcv/optim/optim_factory.py
index 7fe6bf282..b8322bd1c 100644
--- a/mindcv/optim/optim_factory.py
+++ b/mindcv/optim/optim_factory.py
@@ -13,6 +13,9 @@
 
 
 def init_group_params(params, weight_decay):
+    """
+    Filter bias and norm parameters (e.g., nn.BatchNorm and nn.LayerNorm) out of weight decay.
+    """
     decay_params = []
     no_decay_params = []
 
@@ -60,7 +63,9 @@ def create_optimizer(
         momentum: momentum if the optimizer supports. Default: 0.9.
         nesterov: Whether to use Nesterov Accelerated Gradient (NAG) algorithm to update the gradients. Default: False.
         filter_bias_and_bn: whether to filter batch norm parameters and bias from weight decay.
-            If True, weight decay will not apply on BN parameters and bias in Conv or Dense layers. Default: True.
+            If True and `params` is of type list[Parameter] (no param groups defined), weight decay
+            will not be applied to parameters whose names contain "beta", "gamma", or "bias", i.e.,
+            BatchNorm and LayerNorm parameters and the bias of Conv or Dense layers. Default: True.
         loss_scale: A floating point value for the loss scale, which must be larger than 0.0. Default: 1.0.
 
     Returns:
@@ -69,8 +74,19 @@
 
     opt = opt.lower()
 
+    # check whether a param grouping strategy is already encoded in `params`
+    customized_param_group = False
+    if isinstance(params[0], dict):
+        customized_param_group = True
+
     if weight_decay and filter_bias_and_bn:
-        params = init_group_params(params, weight_decay)
+        if not customized_param_group:
+            params = init_group_params(params, weight_decay)
+        else:
+            print(
+                "WARNING: Customized param grouping strategy detected in `params`. "
+                "filter_bias_and_bn (default=True) will be disabled."
+            )
 
     opt_args = dict(**kwargs)
     # if lr is not None:
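
Not part of the patch: below is a minimal runnable sketch of the name-based filtering that init_group_params performs, based on the parameter names ("beta", "gamma", "bias") described in the docstring above. MindSpore Parameter objects are stubbed with SimpleNamespace so the sketch runs standalone; init_group_params_sketch is a hypothetical name, and the real mindcv function may differ in detail (e.g., it may also emit an "order_params" group).

# illustrative sketch, not mindcv's actual implementation
from types import SimpleNamespace

def init_group_params_sketch(params, weight_decay):
    decay_params = []
    no_decay_params = []
    for p in params:
        # norm-layer params (beta/gamma) and Conv/Dense biases skip weight decay
        if any(key in p.name for key in ("beta", "gamma", "bias")):
            no_decay_params.append(p)
        else:
            decay_params.append(p)
    return [
        {"params": decay_params, "weight_decay": weight_decay},
        {"params": no_decay_params},  # no weight_decay key -> no decay applied
    ]

params = [SimpleNamespace(name=n) for n in
          ("conv1.weight", "conv1.bias", "bn1.gamma", "bn1.beta", "fc.weight")]
groups = init_group_params_sketch(params, weight_decay=1e-4)
print([p.name for p in groups[0]["params"]])  # ['conv1.weight', 'fc.weight']
print([p.name for p in groups[1]["params"]])  # ['conv1.bias', 'bn1.gamma', 'bn1.beta']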
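
Also illustrative only: the guard this patch adds to create_optimizer, shown as a standalone function. If the caller already passes param-group dicts, the automatic filtering is skipped so the two grouping strategies cannot clobber each other. resolve_params is a hypothetical helper reusing init_group_params_sketch from the sketch above; the patch itself inlines this logic in create_optimizer.

# illustrative sketch of the behavior after this patch
def resolve_params(params, weight_decay, filter_bias_and_bn=True):
    customized_param_group = isinstance(params[0], dict)
    if weight_decay and filter_bias_and_bn:
        if not customized_param_group:
            # flat list of Parameters: apply the name-based filtering
            return init_group_params_sketch(params, weight_decay)
        print(
            "WARNING: Customized param grouping strategy detected in `params`. "
            "filter_bias_and_bn (default=True) will be disabled."
        )
    return params

# a pre-grouped input now passes through untouched (with a warning)
groups = [{"params": [], "weight_decay": 0.0}]
assert resolve_params(groups, weight_decay=1e-4) is groups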