7
7
8
8
from timm .data import IMAGENET_INCEPTION_MEAN , IMAGENET_INCEPTION_STD
9
9
from timm .layers import (
10
- SelectAdaptivePool2d , Linear , LayerType , PadType , RmsNorm2d , ConvNormAct , create_conv2d , get_norm_act_layer ,
11
- to_2tuple
10
+ SelectAdaptivePool2d ,
11
+ Linear ,
12
+ LayerType ,
13
+ RmsNorm2d ,
14
+ ConvNormAct ,
15
+ create_conv2d ,
16
+ get_norm_layer ,
17
+ get_norm_act_layer ,
18
+ to_2tuple ,
12
19
)
13
20
from ._builder import build_model_with_cfg
14
21
from ._efficientnet_blocks import SqueezeExcite , UniversalInvertedResidual
15
- from ._efficientnet_builder import BlockArgs , EfficientNetBuilder , decode_arch_def , efficientnet_init_weights , \
16
- round_channels , resolve_act_layer
22
+ from ._efficientnet_builder import (
23
+ BlockArgs ,
24
+ EfficientNetBuilder ,
25
+ decode_arch_def ,
26
+ efficientnet_init_weights ,
27
+ round_channels ,
28
+ )
17
29
from ._features import feature_take_indices
18
30
from ._features_fx import register_notrace_module
19
31
from ._manipulate import checkpoint_seq , checkpoint
@@ -115,6 +127,7 @@ def __init__(
115
127
num_classes : int = 1000 ,
116
128
in_chans : int = 3 ,
117
129
stem_size : int = 16 ,
130
+ stem_bias : bool = False ,
118
131
fix_stem : bool = False ,
119
132
num_features : int = 2048 ,
120
133
pad_type : str = '' ,
@@ -155,7 +168,7 @@ def __init__(
155
168
"""
156
169
super ().__init__ ()
157
170
act_layer = act_layer or nn .GELU
158
- norm_layer = norm_layer or RmsNorm2d
171
+ norm_layer = get_norm_layer ( norm_layer ) or RmsNorm2d
159
172
norm_act_layer = get_norm_act_layer (norm_layer , act_layer )
160
173
se_layer = se_layer or SqueezeExcite
161
174
self .num_classes = num_classes
@@ -173,6 +186,7 @@ def __init__(
173
186
kernel_size = 3 ,
174
187
stride = 2 ,
175
188
padding = pad_type ,
189
+ bias = stem_bias ,
176
190
norm_layer = norm_layer ,
177
191
act_layer = act_layer ,
178
192
)
@@ -396,6 +410,7 @@ def __init__(
396
410
block_args : BlockArgs ,
397
411
in_chans : int = 3 ,
398
412
stem_size : int = 64 ,
413
+ stem_bias : bool = False ,
399
414
fix_stem : bool = False ,
400
415
pad_type : str = '' ,
401
416
msfa_indices : Sequence [int ] = (- 2 , - 1 ),
@@ -412,7 +427,7 @@ def __init__(
412
427
):
413
428
super ().__init__ ()
414
429
act_layer = act_layer or nn .GELU
415
- norm_layer = norm_layer or RmsNorm2d
430
+ norm_layer = get_norm_layer ( norm_layer ) or RmsNorm2d
416
431
se_layer = se_layer or SqueezeExcite
417
432
self .num_classes = 0 # Exists to satisfy ._hub module APIs.
418
433
self .drop_rate = drop_rate
@@ -427,6 +442,7 @@ def __init__(
427
442
kernel_size = 3 ,
428
443
stride = 2 ,
429
444
padding = pad_type ,
445
+ bias = stem_bias ,
430
446
norm_layer = norm_layer ,
431
447
act_layer = act_layer ,
432
448
)
@@ -786,12 +802,14 @@ def _cfg(url: str = '', **kwargs):
786
802
# encoder-only configs
787
803
'mobilenetv5_300m_enc' : _cfg (
788
804
#hf_hub_id='timm/',
805
+ mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ),
789
806
input_size = (3 , 768 , 768 ),
790
807
num_classes = 0 ),
791
808
792
809
# WIP classification configs for testing
793
810
'mobilenetv5_300m' : _cfg (
794
811
# hf_hub_id='timm/',
812
+ mean = (0. , 0. , 0. ), std = (1. , 1. , 1. ),
795
813
input_size = (3 , 768 , 768 ),
796
814
num_classes = 0 ),
797
815
'mobilenetv5_base.untrained' : _cfg (
0 commit comments