Skip to content

Commit 168137b

Browse files
authored
Merge pull request #2538 from huggingface/mnv5_bias_str_norm
Add stem_bias option to MNV5. Resolve the norm layer so can pass string.
2 parents 72b9752 + e1c158e commit 168137b

File tree

1 file changed

+24
-6
lines changed

1 file changed

+24
-6
lines changed

timm/models/mobilenetv5.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,25 @@
77

88
from timm.data import IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD
99
from timm.layers import (
10-
SelectAdaptivePool2d, Linear, LayerType, PadType, RmsNorm2d, ConvNormAct, create_conv2d, get_norm_act_layer,
11-
to_2tuple
10+
SelectAdaptivePool2d,
11+
Linear,
12+
LayerType,
13+
RmsNorm2d,
14+
ConvNormAct,
15+
create_conv2d,
16+
get_norm_layer,
17+
get_norm_act_layer,
18+
to_2tuple,
1219
)
1320
from ._builder import build_model_with_cfg
1421
from ._efficientnet_blocks import SqueezeExcite, UniversalInvertedResidual
15-
from ._efficientnet_builder import BlockArgs, EfficientNetBuilder, decode_arch_def, efficientnet_init_weights, \
16-
round_channels, resolve_act_layer
22+
from ._efficientnet_builder import (
23+
BlockArgs,
24+
EfficientNetBuilder,
25+
decode_arch_def,
26+
efficientnet_init_weights,
27+
round_channels,
28+
)
1729
from ._features import feature_take_indices
1830
from ._features_fx import register_notrace_module
1931
from ._manipulate import checkpoint_seq, checkpoint
@@ -115,6 +127,7 @@ def __init__(
115127
num_classes: int = 1000,
116128
in_chans: int = 3,
117129
stem_size: int = 16,
130+
stem_bias: bool = False,
118131
fix_stem: bool = False,
119132
num_features: int = 2048,
120133
pad_type: str = '',
@@ -155,7 +168,7 @@ def __init__(
155168
"""
156169
super().__init__()
157170
act_layer = act_layer or nn.GELU
158-
norm_layer = norm_layer or RmsNorm2d
171+
norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
159172
norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
160173
se_layer = se_layer or SqueezeExcite
161174
self.num_classes = num_classes
@@ -173,6 +186,7 @@ def __init__(
173186
kernel_size=3,
174187
stride=2,
175188
padding=pad_type,
189+
bias=stem_bias,
176190
norm_layer=norm_layer,
177191
act_layer=act_layer,
178192
)
@@ -396,6 +410,7 @@ def __init__(
396410
block_args: BlockArgs,
397411
in_chans: int = 3,
398412
stem_size: int = 64,
413+
stem_bias: bool = False,
399414
fix_stem: bool = False,
400415
pad_type: str = '',
401416
msfa_indices: Sequence[int] = (-2, -1),
@@ -412,7 +427,7 @@ def __init__(
412427
):
413428
super().__init__()
414429
act_layer = act_layer or nn.GELU
415-
norm_layer = norm_layer or RmsNorm2d
430+
norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
416431
se_layer = se_layer or SqueezeExcite
417432
self.num_classes = 0 # Exists to satisfy ._hub module APIs.
418433
self.drop_rate = drop_rate
@@ -427,6 +442,7 @@ def __init__(
427442
kernel_size=3,
428443
stride=2,
429444
padding=pad_type,
445+
bias=stem_bias,
430446
norm_layer=norm_layer,
431447
act_layer=act_layer,
432448
)
@@ -786,12 +802,14 @@ def _cfg(url: str = '', **kwargs):
786802
# encoder-only configs
787803
'mobilenetv5_300m_enc': _cfg(
788804
#hf_hub_id='timm/',
805+
mean=(0., 0., 0.), std=(1., 1., 1.),
789806
input_size=(3, 768, 768),
790807
num_classes=0),
791808

792809
# WIP classification configs for testing
793810
'mobilenetv5_300m': _cfg(
794811
# hf_hub_id='timm/',
812+
mean=(0., 0., 0.), std=(1., 1., 1.),
795813
input_size=(3, 768, 768),
796814
num_classes=0),
797815
'mobilenetv5_base.untrained': _cfg(

0 commit comments

Comments
 (0)