Skip to content

Commit 9fee316

Browse files
committed
Enable fixed fanout calc in EfficientNet/MobileNetV3 weight init by default. Fix #84
1 parent 27b3680 commit 9fee316

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

timm/models/efficientnet_builder.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -359,15 +359,13 @@ def __call__(self, in_chs, model_block_args):
359359
return stages
360360

361361

362-
def _init_weight_goog(m, n='', fix_group_fanout=False):
362+
def _init_weight_goog(m, n='', fix_group_fanout=True):
363363
""" Weight initialization as per Tensorflow official implementations.
364364
365365
Args:
366366
m (nn.Module): module to init
367367
n (str): module name
368-
fix_group_fanout (bool): enable correct fanout calculation w/ group convs
369-
370-
FIXME change fix_group_fanout to default to True if experiments show better training results
368+
fix_group_fanout (bool): enable correct (matching Tensorflow TPU impl) fanout calculation w/ group convs
371369
372370
Handles layers in EfficientNet, EfficientNet-CondConv, MixNet, MnasNet, MobileNetV3, etc:
373371
* https://github.com/tensorflow/tpu/blob/master/models/official/mnasnet/mnasnet_model.py

0 commit comments

Comments
 (0)