Commit 1f69a52 (2 parents: 61be921 + 3828676)

Merge pull request #2527 from huggingface/mobilenetv5

MobileNetV5

File tree: 8 files changed, +982 / -13 lines

tests/test_models.py
Lines changed: 7 additions & 1 deletion

@@ -210,6 +210,7 @@ def test_model_backward(model_name, batch_size):
         pytest.skip("Fixed input size model > limit.")

     model = create_model(model_name, pretrained=False, num_classes=42)
+    encoder_only = model.num_classes == 0  # FIXME better approach?
     num_params = sum([x.numel() for x in model.parameters()])
     model.train()

@@ -224,7 +225,12 @@ def test_model_backward(model_name, batch_size):
         assert x.grad is not None, f'No gradient for {n}'
     num_grad = sum([x.grad.numel() for x in model.parameters() if x.grad is not None])

-    assert outputs.shape[-1] == 42
+    if encoder_only:
+        output_fmt = getattr(model, 'output_fmt', 'NCHW')
+        feat_axis = get_channel_dim(output_fmt)
+        assert outputs.shape[feat_axis] == model.num_features, f'unpooled feature dim {outputs.shape[feat_axis]} != model.num_features {model.num_features}'
+    else:
+        assert outputs.shape[-1] == 42
     assert num_params == num_grad, 'Some parameters are missing gradients'
     assert not torch.isnan(outputs).any(), 'Output included NaNs'

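The new branch only triggers when a model is built as a feature encoder (num_classes == 0); the channel axis of the unpooled output then depends on the model's output_fmt. A minimal standalone sketch of the same check, assuming get_channel_dim is importable from timm.layers and using 'resnet18' purely as a stand-in model (neither is part of this PR):

# Standalone version of the encoder-only check added above.
# Assumptions: get_channel_dim is exported by timm.layers; 'resnet18' is only a stand-in model.
import torch
import timm
from timm.layers import get_channel_dim

model = timm.create_model('resnet18', pretrained=False, num_classes=0)   # encoder-only
x = torch.randn(2, 3, 224, 224)
feats = model.forward_features(x)                                        # unpooled features
feat_axis = get_channel_dim(getattr(model, 'output_fmt', 'NCHW'))        # 1 for NCHW, 3 for NHWC
assert feats.shape[feat_axis] == model.num_features
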
timm/layers/create_norm_act.py
Lines changed: 14 additions & 2 deletions

@@ -11,7 +11,7 @@

 from .evo_norm import *
 from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
-from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d
+from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, RmsNormAct, RmsNormAct2d
 from .inplace_abn import InplaceAbn

 _NORM_ACT_MAP = dict(
@@ -34,11 +34,21 @@
     frntlu=FilterResponseNormTlu2d,
     inplaceabn=InplaceAbn,
     iabn=InplaceAbn,
+    rmsnorm=RmsNormAct,
+    rmsnorm2d=RmsNormAct2d,
 )
 _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
 # has act_layer arg to define act type
 _NORM_ACT_REQUIRES_ARG = {
-    BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, FilterResponseNormAct2d, InplaceAbn}
+    BatchNormAct2d,
+    GroupNormAct,
+    LayerNormAct,
+    LayerNormAct2d,
+    FilterResponseNormAct2d,
+    InplaceAbn,
+    RmsNormAct,
+    RmsNormAct2d,
+}


 def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
@@ -83,6 +93,8 @@ def get_norm_act_layer(norm_layer, act_layer=None):
             norm_act_layer = LayerNormAct2d
         elif type_name.startswith('layernorm'):
            norm_act_layer = LayerNormAct
+        elif type_name.startswith('rmsnorm2d'):
+            norm_act_layer = RmsNormAct2d
         else:
             assert False, f"No equivalent norm_act layer for {type_name}"

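With the two new registry entries, RMSNorm variants resolve through the same norm+act factory path as the existing layers. A usage sketch, assuming the get_norm_act_layer / create_norm_act_layer helpers exported by timm.layers:

# Usage sketch of the factory with the newly registered RMSNorm entries.
import torch
import torch.nn as nn
from timm.layers import get_norm_act_layer, create_norm_act_layer

# resolve by the new string key; act_layer is bound because RmsNormAct2d is in _NORM_ACT_REQUIRES_ARG
norm_act_cls = get_norm_act_layer('rmsnorm2d', act_layer=nn.GELU)
layer = norm_act_cls(64)                          # num_channels=64
y = layer(torch.randn(2, 64, 16, 16))             # NCHW in, NCHW out

# or build an instance straight from the registry
layer2 = create_norm_act_layer('rmsnorm2d', 64, act_layer=nn.ReLU)
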
timm/layers/fast_norm.py
Lines changed: 46 additions & 1 deletion

@@ -148,7 +148,7 @@ def fast_rms_norm(
         return fused_rms_norm_affine(x, weight, normalized_shape, eps)

     if is_autocast_enabled(x.device.type):
-        # normally native AMP casts LN inputs to float32
+        # normally native AMP casts LN inputs to float32 and leaves the output as float32
         # apex LN does not, this is behaving like Apex
         dt = get_autocast_dtype(x.device.type)
         x, weight = x.to(dt), weight.to(dt)
@@ -162,6 +162,51 @@ def fast_rms_norm(
     return x


+def rms_norm2d(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+):
+    assert len(normalized_shape) == 1
+    v = x.pow(2)
+    v = torch.mean(v, dim=1, keepdim=True)
+    x = x * torch.rsqrt(v + eps)
+    if weight is not None:
+        x = x * weight.reshape(1, -1, 1, 1)
+    return x
+
+
+def fast_rms_norm2d(
+    x: torch.Tensor,
+    normalized_shape: List[int],
+    weight: Optional[torch.Tensor] = None,
+    eps: float = 1e-5,
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # this must be by itself, cannot merge with has_apex_rmsnorm
+        return rms_norm2d(x, normalized_shape, weight, eps)
+
+    if has_apex_rmsnorm:
+        x = x.permute(0, 2, 3, 1)
+        if weight is None:
+            x = fused_rms_norm(x, normalized_shape, eps)
+        else:
+            x = fused_rms_norm_affine(x, weight, normalized_shape, eps)
+        x = x.permute(0, 3, 1, 2)
+
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts norm inputs to float32 and leaves the output as float32
+        # apex does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        x = rms_norm2d(x, normalized_shape, weight, eps)
+
+    return x
+
+
 def simple_norm(
     x: torch.Tensor,
     normalized_shape: List[int],

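rms_norm2d normalizes over the channel dim of an NCHW tensor directly, avoiding the permute round-trip a channels-last RMS norm would need. A small equivalence check (my own sketch, not part of the commit) against a permute-based reference:

# Equivalence sketch (not part of the commit).
import torch
from timm.layers.fast_norm import rms_norm2d  # added by this commit

def rms_norm2d_ref(x, weight=None, eps=1e-5):
    # reference path: NCHW -> NHWC, RMS-normalize over the last dim, back to NCHW
    y = x.permute(0, 2, 3, 1)
    y = y * torch.rsqrt(y.pow(2).mean(dim=-1, keepdim=True) + eps)
    if weight is not None:
        y = y * weight
    return y.permute(0, 3, 1, 2)

x = torch.randn(2, 32, 8, 8)
w = torch.randn(32)
torch.testing.assert_close(rms_norm2d(x, [32], w), rms_norm2d_ref(x, w))
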
timm/layers/norm.py
Lines changed: 11 additions & 6 deletions

@@ -11,7 +11,10 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm, simple_norm
+from .fast_norm import (
+    is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d,
+    fast_simple_norm, simple_norm
+)

 try:
     from torch.nn.functional import rms_norm
@@ -173,7 +176,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:


 class RmsNorm2d(nn.Module):
-    """ RmsNorm w/ fast (apex) norm if available
+    """ RmsNorm2D for NCHW tensors, w/ fast apex or cast norm if available
+
+    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
+    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
+    like https://github.com/pytorch/pytorch/pull/150576 lands.
     """
     __constants__ = ['normalized_shape', 'eps', 'elementwise_affine', '_fast_norm']
     normalized_shape: Tuple[int, ...]
@@ -205,14 +212,12 @@ def reset_parameters(self) -> None:
         nn.init.ones_(self.weight)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x = x.permute(0, 2, 3, 1)
         # NOTE fast norm fallback needs our rms norm impl, so both paths through here.
         # Since there is no built-in PyTorch impl, always use APEX RmsNorm if is installed.
         if self._fast_norm:
-            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+            x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
         else:
-            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
-            x = x.permute(0, 3, 1, 2)
+            x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
         return x

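The NOTE's claim (an eager dim=1 reduction beats permute + F.rms_norm) can be spot-checked with a rough torch.utils.benchmark sketch; timings vary by hardware and PyTorch version, and F.rms_norm needs a recent PyTorch (roughly 2.4+), which is why the module also keeps a local fallback:

# Rough benchmark sketch for the NOTE above; numbers depend on hardware and PyTorch version.
import torch
import torch.nn.functional as F
from torch.utils import benchmark

device = 'cuda' if torch.cuda.is_available() else 'cpu'
x = torch.randn(8, 256, 56, 56, device=device)
w = torch.randn(256, device=device)

def eager_dim1(x, w, eps=1e-6):
    # what rms_norm2d does: reduce over dim=1, no permutes
    v = x.pow(2).mean(dim=1, keepdim=True)
    return x * torch.rsqrt(v + eps) * w.reshape(1, -1, 1, 1)

def permute_rms(x, w, eps=1e-6):
    # the alternative the NOTE compares against (F.rms_norm needs torch >= 2.4)
    y = x.permute(0, 2, 3, 1)
    y = F.rms_norm(y, (y.shape[-1],), w, eps)
    return y.permute(0, 3, 1, 2)

for name, fn in [('eager dim=1', eager_dim1), ('permute + F.rms_norm', permute_rms)]:
    t = benchmark.Timer(stmt='fn(x, w)', globals={'fn': fn, 'x': x, 'w': w})
    print(name, t.timeit(100))
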
timm/layers/norm_act.py
Lines changed: 73 additions & 1 deletion

@@ -20,9 +20,15 @@
 from torchvision.ops.misc import FrozenBatchNorm2d

 from .create_act import create_act_layer
-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, rms_norm2d, fast_rms_norm2d
+from .norm import RmsNorm, RmsNorm2d
 from .trace_utils import _assert

+try:
+    from torch.nn.functional import rms_norm
+except ImportError:
+    from .fast_norm import rms_norm
+

 def _create_act(act_layer, act_kwargs=None, inplace=False, apply_act=True):
     act_kwargs = act_kwargs or {}
@@ -460,3 +466,69 @@ def forward(self, x):
         x = self.drop(x)
         x = self.act(x)
         return x
+
+
+class RmsNormAct(RmsNorm):
+    """ RMSNorm + Activation for '2D' NCHW tensors
+
+    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
+    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
+    like https://github.com/pytorch/pytorch/pull/150576 lands.
+    """
+    def __init__(
+            self,
+            num_channels,
+            eps=1e-6,
+            affine=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            act_kwargs=None,
+            inplace=True,
+            drop_layer=None,
+    ):
+        super().__init__(channels=num_channels, eps=eps, affine=affine)
+        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+        self._fast_norm = is_fast_norm()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._fast_norm:
+            x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm(x, self.normalized_shape, self.weight, self.eps)
+        x = self.drop(x)
+        x = self.act(x)
+        return x
+
+
+class RmsNormAct2d(RmsNorm2d):
+    """ RMSNorm + Activation for '2D' NCHW tensors
+
+    NOTE: It's currently (2025-05-10) faster to use an eager 2d kernel that does reduction
+    on dim=1 than to permute and use internal PyTorch F.rms_norm, this may change if something
+    like https://github.com/pytorch/pytorch/pull/150576 lands.
+    """
+    def __init__(
+            self,
+            num_channels,
+            eps=1e-6,
+            affine=True,
+            apply_act=True,
+            act_layer=nn.ReLU,
+            act_kwargs=None,
+            inplace=True,
+            drop_layer=None,
+    ):
+        super().__init__(channels=num_channels, eps=eps, affine=affine)
+        self.drop = drop_layer() if drop_layer is not None else nn.Identity()
+        self.act = _create_act(act_layer, act_kwargs=act_kwargs, inplace=inplace, apply_act=apply_act)
+        self._fast_norm = is_fast_norm()
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self._fast_norm:
+            x = fast_rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
+        else:
+            x = rms_norm2d(x, self.normalized_shape, self.weight, self.eps)
+        x = self.drop(x)
+        x = self.act(x)
+        return x

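A usage sketch (not from the PR) showing RmsNormAct2d in the role it plays when selected through the norm_act factory: a fused norm + activation after a conv, operating on NCHW tensors:

# Usage sketch (not from the PR): RmsNormAct2d as the conv -> norm -> act step.
import torch
import torch.nn as nn
from timm.layers.norm_act import RmsNormAct2d

block = nn.Sequential(
    nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1, bias=False),
    RmsNormAct2d(64, act_layer=nn.GELU),      # RMSNorm over channels, then GELU
)
y = block(torch.randn(2, 32, 32, 32))         # -> (2, 64, 16, 16), NCHW throughout
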
timm/models/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -40,6 +40,7 @@
 from .metaformer import *
 from .mlp_mixer import *
 from .mobilenetv3 import *
+from .mobilenetv5 import *
 from .mobilevit import *
 from .mvitv2 import *
 from .naflexvit import *
@@ -129,6 +130,7 @@
     load_model_config_from_hf as load_model_config_from_hf,
     load_state_dict_from_hf as load_state_dict_from_hf,
     push_to_hf_hub as push_to_hf_hub,
+    save_for_hf as save_for_hf,
 )
 from ._manipulate import (
     model_parameters as model_parameters,

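With mobilenetv5 imported at package level, its registered configs become discoverable through the usual factory, and save_for_hf is re-exported alongside the other hub helpers. A discovery sketch; the actual MobileNetV5 model names aren't shown in this diff, so they are listed rather than hard-coded:

# Discovery sketch: list the configs registered by timm/models/mobilenetv5.py
# rather than hard-coding a name that isn't shown in this diff.
import timm

names = timm.list_models('mobilenetv5*')
print(names)
encoder = timm.create_model(names[0], pretrained=False, num_classes=0)   # encoder-only, as in the test change

# save_for_hf (re-exported here) writes model weights + config in HF Hub layout, e.g.:
# from timm.models import save_for_hf; save_for_hf(encoder, 'mobilenetv5_export')
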
timm/models/_efficientnet_blocks.py
Lines changed: 3 additions & 2 deletions

@@ -517,12 +517,13 @@ def __init__(
                 value_dim=value_dim,
                 query_strides=query_strides,
                 kv_stride=kv_stride,
+                dw_kernel_size=dw_kernel_size,
                 dilation=dilation,
                 padding=pad_type,
-                dw_kernel_size=dw_kernel_size,
                 attn_drop=attn_drop,
                 proj_drop=proj_drop,
-                #bias=use_bias, # why not here if used w/ mhsa?
+                norm_layer=norm_layer,
+                # use_bias=use_bias, # why not here if used w/ mhsa?
             )
         else:
             self.attn = Attention2d(
