
Commit ea4f940

RyanMullins authored and rwightman committed
fix: mnv5 with conv_stem bias and GELU approx tanh
1 parent 446e8a8 · commit ea4f940

File tree

1 file changed: +9 -6 lines changed


timm/models/mobilenetv5.py

Lines changed: 9 additions & 6 deletions
@@ -28,11 +28,13 @@
 )
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model
 
 __all__ = ['MobileNetV5', 'MobileNetV5Encoder']
 
+_GELU = partial(nn.GELU, approximate='tanh')
+
 
 @register_notrace_module
 class MobileNetV5MultiScaleFusionAdapter(nn.Module):
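
The `_GELU` alias introduced above is a `functools.partial` over `nn.GELU` with the tanh approximation enabled, so every place that previously fell back to exact `nn.GELU` now picks up the approximate variant. A minimal standalone sketch (not part of the commit) of the numerical difference between the two activations:

from functools import partial
import torch
import torch.nn as nn

_GELU = partial(nn.GELU, approximate='tanh')   # same pattern as the hunk above

t = torch.linspace(-4.0, 4.0, steps=9)
exact = nn.GELU()(t)                           # exact (erf-based) GELU
approx = _GELU()(t)                            # tanh-approximate GELU
print((exact - approx).abs().max())            # small but nonzero difference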
@@ -68,7 +70,7 @@ def __init__(
         self.layer_scale_init_value = layer_scale_init_value
         self.noskip = noskip
 
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         self.ffn = UniversalInvertedResidual(
             in_chs=self.in_channels,
@@ -167,7 +169,7 @@ def __init__(
             global_pool: Type of pooling to use for global pooling features of the FC head.
         """
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
@@ -410,7 +412,7 @@ def __init__(
             block_args: BlockArgs,
             in_chans: int = 3,
             stem_size: int = 64,
-            stem_bias: bool = False,
+            stem_bias: bool = True,
             fix_stem: bool = False,
             pad_type: str = '',
             msfa_indices: Sequence[int] = (-2, -1),
@@ -426,7 +428,7 @@ def __init__(
             layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         se_layer = se_layer or SqueezeExcite
         self.num_classes = 0  # Exists to satisfy ._hub module APIs.
@@ -526,6 +528,7 @@ def forward_intermediates(
         feat_idx = 0  # stem is index 0
         x = self.conv_stem(x)
         if feat_idx in take_indices:
+            print("conv_stem is captured")
             intermediates.append(x)
         if feat_idx in self.msfa_indices:
             msfa_intermediates.append(x)
@@ -777,7 +780,7 @@ def _gen_mobilenet_v5(
        fix_stem=channel_multiplier < 1.0,
        round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
        norm_layer=RmsNorm2d,
-       act_layer=nn.GELU,
+       act_layer=_GELU,
        layer_scale_init_value=1e-5,
     )
     model_kwargs = dict(model_kwargs, **kwargs)
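
Taken together, the changes flip the encoder's `stem_bias` default to `True` and route every `act_layer` fallback through the tanh-approximate GELU. A rough standalone sketch of what those defaults imply for the stem path; the channel count and conv hyperparameters below are illustrative placeholders, not values read from the model config:

from functools import partial
import torch
import torch.nn as nn

_GELU = partial(nn.GELU, approximate='tanh')

act_layer = None
act_layer = act_layer or _GELU                 # what the changed fallback resolves to
conv_stem = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=True)  # stem_bias=True

x = torch.randn(1, 3, 224, 224)
y = act_layer()(conv_stem(x))
print(y.shape)                                 # torch.Size([1, 64, 112, 112])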
