@@ -358,22 +358,33 @@ def __call__(self, in_chs, model_block_args):
358
358
return stages
359
359
360
360
361
- def _init_weight_goog (m , n = '' ):
361
+ def _init_weight_goog (m , n = '' , fix_group_fanout = False ):
362
362
""" Weight initialization as per Tensorflow official implementations.
363
363
364
+ Args:
365
+ m (nn.Module): module to init
366
+ n (str): module name
367
+ fix_group_fanout (bool): enable correct fanout calculation w/ group convs
368
+
369
+ FIXME change fix_group_fanout to default to True if experiments show better training results
370
+
364
371
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
365
372
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py
366
373
* https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/efficientnet_model.py
367
374
"""
368
375
if isinstance (m , CondConv2d ):
369
376
fan_out = m .kernel_size [0 ] * m .kernel_size [1 ] * m .out_channels
377
+ if fix_group_fanout :
378
+ fan_out //= m .groups
370
379
init_weight_fn = get_condconv_initializer (
371
380
lambda w : w .data .normal_ (0 , math .sqrt (2.0 / fan_out )), m .num_experts , m .weight_shape )
372
381
init_weight_fn (m .weight )
373
382
if m .bias is not None :
374
383
m .bias .data .zero_ ()
375
384
elif isinstance (m , nn .Conv2d ):
376
385
fan_out = m .kernel_size [0 ] * m .kernel_size [1 ] * m .out_channels
386
+ if fix_group_fanout :
387
+ fan_out //= m .groups
377
388
m .weight .data .normal_ (0 , math .sqrt (2.0 / fan_out ))
378
389
if m .bias is not None :
379
390
m .bias .data .zero_ ()
@@ -390,21 +401,6 @@ def _init_weight_goog(m, n=''):
390
401
m .bias .data .zero_ ()
391
402
392
403
393
- def _init_weight_default (m , n = '' ):
394
- """ Basic ResNet (Kaiming) style weight init"""
395
- if isinstance (m , CondConv2d ):
396
- init_fn = get_condconv_initializer (partial (
397
- nn .init .kaiming_normal_ , mode = 'fan_out' , nonlinearity = 'relu' ), m .num_experts , m .weight_shape )
398
- init_fn (m .weight )
399
- elif isinstance (m , nn .Conv2d ):
400
- nn .init .kaiming_normal_ (m .weight , mode = 'fan_out' , nonlinearity = 'relu' )
401
- elif isinstance (m , nn .BatchNorm2d ):
402
- m .weight .data .fill_ (1.0 )
403
- m .bias .data .zero_ ()
404
- elif isinstance (m , nn .Linear ):
405
- nn .init .kaiming_uniform_ (m .weight , mode = 'fan_in' , nonlinearity = 'linear' )
406
-
407
-
408
404
def efficientnet_init_weights (model : nn .Module , init_fn = None ):
409
405
init_fn = init_fn or _init_weight_goog
410
406
for n , m in model .named_modules ():
0 commit comments