Skip to content

Commit d0eb59e

Browse files
committed
Remove unused default_init for EfficientNets, experimenting with fanout calc for #84
1 parent cade829 commit d0eb59e

File tree

1 file changed

+12
-16
lines changed

1 file changed

+12
-16
lines changed

timm/models/efficientnet_builder.py

Lines changed: 12 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -358,22 +358,33 @@ def __call__(self, in_chs, model_block_args):
358358
return stages
359359

360360

361-
def _init_weight_goog(m, n=''):
361+
def _init_weight_goog(m, n='', fix_group_fanout=False):
362362
""" Weight initialization as per Tensorflow official implementations.
363363
364+
Args:
365+
m (nn.Module): module to init
366+
n (str): module name
367+
fix_group_fanout (bool): enable correct fanout calculation w/ group convs
368+
369+
FIXME change fix_group_fanout to default to True if experiments show better training results
370+
364371
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
365372
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
366373
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
367374
"""
368375
if isinstance(m, CondConv2d):
369376
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
377+
if fix_group_fanout:
378+
fan_out //= m.groups
370379
init_weight_fn = get_condconv_initializer(
371380
lambda w: w.data.normal_(0, math.sqrt(2.0 / fan_out)), m.num_experts, m.weight_shape)
372381
init_weight_fn(m.weight)
373382
if m.bias is not None:
374383
m.bias.data.zero_()
375384
elif isinstance(m, nn.Conv2d):
376385
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
386+
if fix_group_fanout:
387+
fan_out //= m.groups
377388
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
378389
if m.bias is not None:
379390
m.bias.data.zero_()
@@ -390,21 +401,6 @@ def _init_weight_goog(m, n=''):
390401
m.bias.data.zero_()
391402

392403

393-
def _init_weight_default(m, n=''):
394-
""" Basic ResNet (Kaiming) style weight init"""
395-
if isinstance(m, CondConv2d):
396-
init_fn = get_condconv_initializer(partial(
397-
nn.init.kaiming_normal_, mode='fan_out', nonlinearity='relu'), m.num_experts, m.weight_shape)
398-
init_fn(m.weight)
399-
elif isinstance(m, nn.Conv2d):
400-
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
401-
elif isinstance(m, nn.BatchNorm2d):
402-
m.weight.data.fill_(1.0)
403-
m.bias.data.zero_()
404-
elif isinstance(m, nn.Linear):
405-
nn.init.kaiming_uniform_(m.weight, mode='fan_in', nonlinearity='linear')
406-
407-
408404
def efficientnet_init_weights(model: nn.Module, init_fn=None):
409405
init_fn = init_fn or _init_weight_goog
410406
for n, m in model.named_modules():

0 commit comments

Comments (0)