
Commit 15406a9

Fix RmsNorm to resolve #2380, an issue noticed with aimv2 when comparing outputs. Still some work to do: need to look at AMP / fast mode behaviour and dispatch to torch when possible. Add SimpleNorm for 'LayerNorm w/o centering and bias'.
1 parent a648a04 commit 15406a9
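For context (my summary, not part of the commit): RmsNorm should scale by the root mean square of x, i.e. sqrt(mean(x^2)), whereas the previous code scaled by the variance, and the new SimpleNorm is 'LayerNorm w/o centering and bias', i.e. x / std(x) with only a learned scale. A minimal sketch of how the three formulas relate, using hypothetical ref_* helper names rather than the timm implementations:

```python
import torch

def ref_layer_norm(x, w, b, eps=1e-6):
    # LayerNorm: center, divide by std, then affine scale + bias
    mu = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mu) * torch.rsqrt(var + eps) * w + b

def ref_rms_norm(x, w, eps=1e-6):
    # RMSNorm: no centering, divide by the root mean square of x
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * w

def ref_simple_norm(x, w, eps=1e-6):
    # SimpleNorm per this commit's docstring: x / std(x), no centering, no bias
    return x * torch.rsqrt(x.var(dim=-1, keepdim=True) + eps) * w
```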

File tree

4 files changed, +152 −20 lines changed


timm/layers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -34,7 +34,7 @@
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d
+from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
 from .norm_act import BatchNormAct2d, GroupNormAct, GroupNorm1Act, LayerNormAct, LayerNormAct2d,\
     SyncBatchNormAct, convert_sync_batchnorm, FrozenBatchNormAct2d, freeze_batch_norm_2d, unfreeze_batch_norm_2d
 from .padding import get_padding, get_same_padding, pad_same

timm/layers/fast_norm.py

Lines changed: 48 additions & 2 deletions
@@ -108,22 +108,24 @@ def fast_layer_norm(
     return F.layer_norm(x, normalized_shape, weight, bias, eps)


+
 def rms_norm(
         x: torch.Tensor,
         normalized_shape: List[int],
         weight: Optional[torch.Tensor] = None,
         eps: float = 1e-5,
 ):
     norm_ndim = len(normalized_shape)
+    v = x.pow(2)
     if torch.jit.is_scripting():
         # ndim = len(x.shape)
         # dims = list(range(ndim - norm_ndim, ndim))  # this doesn't work on pytorch <= 1.13.x
         # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
         assert norm_ndim == 1
-        v = torch.var(x, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
+        v = torch.mean(v, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
     else:
         dims = tuple(range(-1, -norm_ndim - 1, -1))
-        v = torch.var(x, dim=dims, keepdim=True)
+        v = torch.mean(v, dim=dims, keepdim=True)
     x = x * torch.rsqrt(v + eps)
     if weight is not None:
         x = x * weight
@@ -148,3 +150,47 @@ def fast_rms_norm(

     # fallback
     return rms_norm(x, normalized_shape, weight, eps)
+
+
+def simple_norm(
+        x: torch.Tensor,
+        normalized_shape: List[int],
+        weight: Optional[torch.Tensor] = None,
+        eps: float = 1e-5,
+):
+    norm_ndim = len(normalized_shape)
+    if torch.jit.is_scripting():
+        # ndim = len(x.shape)
+        # dims = list(range(ndim - norm_ndim, ndim))  # this doesn't work on pytorch <= 1.13.x
+        # NOTE -ve dims cause torchscript to crash in some cases, out of options to work around
+        assert norm_ndim == 1
+        v = torch.var(x, dim=-1).unsqueeze(-1)  # ts crashes with -ve dim + keepdim=True
+    else:
+        dims = tuple(range(-1, -norm_ndim - 1, -1))
+        v = torch.var(x, dim=dims, keepdim=True)
+    x = x * torch.rsqrt(v + eps)
+    if weight is not None:
+        x = x * weight
+    return x
+
+
+def fast_simple_norm(
+        x: torch.Tensor,
+        normalized_shape: List[int],
+        weight: Optional[torch.Tensor] = None,
+        eps: float = 1e-5,
+) -> torch.Tensor:
+    if torch.jit.is_scripting():
+        # this must be by itself, cannot merge with has_apex_rmsnorm
+        return simple_norm(x, normalized_shape, weight, eps)
+
+    if is_autocast_enabled(x.device.type):
+        # normally native AMP casts LN inputs to float32
+        # apex LN does not, this is behaving like Apex
+        dt = get_autocast_dtype(x.device.type)
+        x, weight = x.to(dt), weight.to(dt)
+
+    with torch.amp.autocast(device_type=x.device.type, enabled=False):
+        x = simple_norm(x, normalized_shape, weight, eps)
+    return x
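A quick check of what the rms_norm change does, as I read it: the old code took torch.var (which subtracts the mean) where RMSNorm calls for the mean of squares, so outputs only matched a reference RMSNorm for zero-mean inputs. A hedged reproduction with my own helper names, not the timm functions:

```python
import torch

def rms_norm_old(x, eps=1e-5):
    # pre-fix behaviour, simplified: variance of x, i.e. mean is subtracted first
    # (the old timm code also used the default unbiased estimator)
    return x * torch.rsqrt(torch.var(x, dim=-1, unbiased=False, keepdim=True) + eps)

def rms_norm_fixed(x, eps=1e-5):
    # post-fix behaviour: mean of x^2, the actual RMS statistic
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)

x = torch.randn(2, 8, 64)
x_zero_mean = x - x.mean(dim=-1, keepdim=True)
x_shifted = x_zero_mean + 3.0  # non-zero mean input, where the mismatch shows up

print(torch.allclose(rms_norm_old(x_zero_mean), rms_norm_fixed(x_zero_mean), atol=1e-5))  # True
print(torch.allclose(rms_norm_old(x_shifted), rms_norm_fixed(x_shifted), atol=1e-5))      # False
```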

timm/layers/norm.py

Lines changed: 71 additions & 1 deletion
@@ -11,7 +11,7 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm
+from .fast_norm import is_fast_norm, fast_group_norm, fast_layer_norm, fast_rms_norm, fast_simple_norm


 class GroupNorm(nn.GroupNorm):
@@ -190,3 +190,73 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         x = fast_rms_norm(x, self.normalized_shape, self.weight, self.eps)
         x = x.permute(0, 3, 1, 2)
         return x
+
+
+class SimpleNorm(nn.Module):
+    """ SimpleNorm (x / std(x))
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        return x
+
+
+class SimpleNorm2d(nn.Module):
+    """ SimpleNorm for NCHW tensors
+    """
+    __constants__ = ['normalized_shape', 'eps', 'elementwise_affine']
+    normalized_shape: Tuple[int, ...]
+    eps: float
+    elementwise_affine: bool
+
+    def __init__(self, channels, eps=1e-6, affine=True, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        normalized_shape = channels
+        if isinstance(normalized_shape, numbers.Integral):
+            # mypy error: incompatible types in assignment
+            normalized_shape = (normalized_shape,)  # type: ignore[assignment]
+        self.normalized_shape = tuple(normalized_shape)  # type: ignore[arg-type]
+        self.eps = eps
+        self.elementwise_affine = affine
+        if self.elementwise_affine:
+            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
+        else:
+            self.register_parameter('weight', None)
+
+        self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        if self.elementwise_affine:
+            nn.init.ones_(self.weight)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.permute(0, 2, 3, 1)
+        x = fast_simple_norm(x, self.normalized_shape, self.weight, self.eps)
+        x = x.permute(0, 3, 1, 2)
+        return x
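A short usage sketch of the new modules (my example, assuming a timm build that includes this commit): SimpleNorm normalizes the last dim of channels-last tensors, SimpleNorm2d handles NCHW by permuting internally, and neither centers the input nor adds a bias.

```python
import torch
from timm.layers import SimpleNorm, SimpleNorm2d, RmsNorm

x_nlc = torch.randn(2, 196, 384)     # (batch, tokens, channels), channels-last
x_nchw = torch.randn(2, 64, 56, 56)  # (batch, channels, height, width)

norm_cl = SimpleNorm(384)    # normalizes over the last dim
norm_2d = SimpleNorm2d(64)   # permutes NCHW -> NHWC, normalizes, permutes back

print(norm_cl(x_nlc).shape)   # torch.Size([2, 196, 384])
print(norm_2d(x_nchw).shape)  # torch.Size([2, 64, 56, 56])

# SimpleNorm divides by std (mean-subtracted), RmsNorm by the root mean square,
# so the two only coincide when the normalized slice has (near) zero mean
rms = RmsNorm(384)
print(torch.allclose(norm_cl(x_nlc), rms(x_nlc), atol=1e-3))  # usually False for random input
```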

timm/models/convnext.py

Lines changed: 32 additions & 16 deletions
@@ -46,6 +46,7 @@
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import trunc_normal_, AvgPool2dSame, DropPath, Mlp, GlobalResponseNormMlp, \
     LayerNorm2d, LayerNorm, RmsNorm2d, RmsNorm, create_conv2d, get_act_layer, get_norm_layer, make_divisible, to_ntuple
+from timm.layers import SimpleNorm2d, SimpleNorm
 from timm.layers import NormMlpClassifierHead, ClassifierHead
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
@@ -233,6 +234,34 @@ def forward(self, x):
         x = self.blocks(x)
         return x

+# map of norm layers with NCHW (2D) and channels last variants
+_NORM_MAP = {
+    'layernorm': (LayerNorm2d, LayerNorm),
+    'layernorm2d': (LayerNorm2d, LayerNorm),
+    'simplenorm': (SimpleNorm2d, SimpleNorm),
+    'simplenorm2d': (SimpleNorm2d, SimpleNorm),
+    'rmsnorm': (RmsNorm2d, RmsNorm),
+    'rmsnorm2d': (RmsNorm2d, RmsNorm),
+}
+
+
+def _get_norm_layers(norm_layer: Union[Callable, str], conv_mlp: bool, norm_eps: float):
+    norm_layer = norm_layer or 'layernorm'
+    if norm_layer in _NORM_MAP:
+        norm_layer_cl = _NORM_MAP[norm_layer][0] if conv_mlp else _NORM_MAP[norm_layer][1]
+        norm_layer = _NORM_MAP[norm_layer][0]
+        if norm_eps is not None:
+            norm_layer = partial(norm_layer, eps=norm_eps)
+            norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+    else:
+        assert conv_mlp, \
+            'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
+        norm_layer = get_norm_layer(norm_layer)
+        norm_layer_cl = norm_layer
+        if norm_eps is not None:
+            norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+    return norm_layer, norm_layer_cl
+

 class ConvNeXt(nn.Module):
     r""" ConvNeXt
@@ -289,20 +318,7 @@ def __init__(
         super().__init__()
         assert output_stride in (8, 16, 32)
         kernel_sizes = to_ntuple(4)(kernel_sizes)
-        use_rms = isinstance(norm_layer, str) and norm_layer.startswith('rmsnorm')
-        if norm_layer is None or use_rms:
-            norm_layer = RmsNorm2d if use_rms else LayerNorm2d
-            norm_layer_cl = norm_layer if conv_mlp else (RmsNorm if use_rms else LayerNorm)
-            if norm_eps is not None:
-                norm_layer = partial(norm_layer, eps=norm_eps)
-                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
-        else:
-            assert conv_mlp,\
-                'If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input'
-            norm_layer = get_norm_layer(norm_layer)
-            norm_layer_cl = norm_layer
-            if norm_eps is not None:
-                norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+        norm_layer, norm_layer_cl = _get_norm_layers(norm_layer, conv_mlp, norm_eps)
         act_layer = get_act_layer(act_layer)

         self.num_classes = num_classes
@@ -975,7 +991,7 @@ def _cfgv2(url='', **kwargs):
 @register_model
 def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
-    model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d')
+    model_args = dict(depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm')
     model = _create_convnext('convnext_zepto_rms', pretrained=pretrained, **dict(model_args, **kwargs))
     return model

@@ -984,7 +1000,7 @@ def convnext_zepto_rms(pretrained=False, **kwargs) -> ConvNeXt:
 def convnext_zepto_rms_ols(pretrained=False, **kwargs) -> ConvNeXt:
     # timm femto variant (NOTE: still tweaking depths, will vary between 3-4M param, current is 3.7M
     model_args = dict(
-        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='rmsnorm2d', stem_type='overlap_act')
+        depths=(2, 2, 4, 2), dims=(32, 64, 128, 256), conv_mlp=True, norm_layer='simplenorm', stem_type='overlap_act')
     model = _create_convnext('convnext_zepto_rms_ols', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
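And a hedged end-to-end sketch (again assuming a timm install that includes this commit): the norm_layer string is now resolved through _NORM_MAP, so the updated convnext_zepto_rms config picks up SimpleNorm2d throughout.

```python
import timm
from timm.layers import SimpleNorm2d

# convnext_zepto_rms now defaults to norm_layer='simplenorm'; no pretrained weights assumed here
model = timm.create_model('convnext_zepto_rms', pretrained=False)

print(round(sum(p.numel() for p in model.parameters()) / 1e6, 2), 'M params')  # ~3.7M per the config comment
print(any(isinstance(m, SimpleNorm2d) for m in model.modules()))               # True: resolved via _NORM_MAP
```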
