 )
 from ._features import feature_take_indices
 from ._features_fx import register_notrace_module
-from ._manipulate import checkpoint_seq, checkpoint
+from ._manipulate import checkpoint_seq
 from ._registry import generate_default_cfgs, register_model

 __all__ = ['MobileNetV5', 'MobileNetV5Encoder']

+_GELU = partial(nn.GELU, approximate='tanh')
+

 @register_notrace_module
 class MobileNetV5MultiScaleFusionAdapter(nn.Module):
@@ -68,7 +70,7 @@ def __init__(
         self.layer_scale_init_value = layer_scale_init_value
         self.noskip = noskip

-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = norm_layer or RmsNorm2d
         self.ffn = UniversalInvertedResidual(
             in_chs=self.in_channels,
@@ -167,7 +169,7 @@ def __init__(
             global_pool: Type of pooling to use for global pooling features of the FC head.
         """
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         norm_act_layer = get_norm_act_layer(norm_layer, act_layer)
         se_layer = se_layer or SqueezeExcite
@@ -410,7 +412,7 @@ def __init__(
             block_args: BlockArgs,
             in_chans: int = 3,
             stem_size: int = 64,
-            stem_bias: bool = False,
+            stem_bias: bool = True,
             fix_stem: bool = False,
             pad_type: str = '',
             msfa_indices: Sequence[int] = (-2, -1),
@@ -426,7 +428,7 @@ def __init__(
             layer_scale_init_value: Optional[float] = None,
     ):
         super().__init__()
-        act_layer = act_layer or nn.GELU
+        act_layer = act_layer or _GELU
         norm_layer = get_norm_layer(norm_layer) or RmsNorm2d
         se_layer = se_layer or SqueezeExcite
         self.num_classes = 0  # Exists to satisfy ._hub module APIs.
@@ -777,7 +779,7 @@ def _gen_mobilenet_v5(
         fix_stem=channel_multiplier < 1.0,
         round_chs_fn=partial(round_channels, multiplier=channel_multiplier),
         norm_layer=RmsNorm2d,
-        act_layer=nn.GELU,
+        act_layer=_GELU,
         layer_scale_init_value=1e-5,
     )
     model_kwargs = dict(model_kwargs, **kwargs)
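
For context on the recurring change in this diff: `_GELU = partial(nn.GELU, approximate='tanh')` replaces the exact, erf-based GELU default with the tanh approximation everywhere an `act_layer` default is resolved. The two activations are close but not bit-identical, which is presumably the point when matching weights from a reference implementation that uses the tanh variant. A minimal sketch of the numerical difference, assuming only `torch` (the input values are arbitrary):

```python
# Sketch only, not part of the diff: compare the exact (erf) GELU with the
# tanh approximation that partial(nn.GELU, approximate='tanh') selects.
from functools import partial

import torch
import torch.nn as nn

_GELU = partial(nn.GELU, approximate='tanh')  # same alias as in the diff

x = torch.linspace(-3.0, 3.0, steps=7)
exact = nn.GELU()(x)    # erf-based GELU
approx = _GELU()(x)     # tanh-approximate GELU

print((exact - approx).abs().max())  # small but nonzero: not bit-identical
```

Using `partial` rather than a subclass keeps `_GELU` a drop-in replacement for `nn.GELU` wherever an activation layer is constructed via `act_layer(...)`, while pre-binding the `approximate='tanh'` keyword.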