Commit 446e8a8

Add flag to enable float32 computation for normalization (norm + affine) (#2536)
* Add flag to enable float32 computation for normalization (norm + affine) in several timm norm & norm+act layers
* Change of plan, make separate 'Fp32' norm and norm+act instances to reduce chance of regressions
1 parent 168137b commit 446e8a8
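
The new layer code in timm/layers/norm.py and norm_act.py is not included in this excerpt, so the following is only a rough, hypothetical sketch of the pattern the 'Fp32' variants imply (upcast input and affine parameters to float32, normalize, cast back), not the actual timm implementation:

# Hypothetical sketch only; the real timm classes may differ.
import torch
import torch.nn as nn
import torch.nn.functional as F


class LayerNormFp32Sketch(nn.LayerNorm):
    """LayerNorm that computes the norm and affine in float32, then restores the input dtype."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(x.dtype)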

File tree

6 files changed: +594 -110 lines


timm/layers/__init__.py

Lines changed: 22 additions & 1 deletion
@@ -69,13 +69,34 @@
 from .mixed_conv2d import MixedConv2d
 from .mlp import Mlp, GluMlp, GatedMlp, SwiGLU, SwiGLUPacked, ConvMlp, GlobalResponseNormMlp
 from .non_local_attn import NonLocalAttn, BatNonLocalAttn
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
+from .norm import (
+    GroupNorm,
+    GroupNorm1,
+    LayerNorm,
+    LayerNorm2d,
+    LayerNormFp32,
+    LayerNorm2dFp32,
+    RmsNorm,
+    RmsNorm2d,
+    RmsNormFp32,
+    RmsNorm2dFp32,
+    SimpleNorm,
+    SimpleNorm2d,
+    SimpleNormFp32,
+    SimpleNorm2dFp32,
+)
 from .norm_act import (
     BatchNormAct2d,
     GroupNormAct,
     GroupNorm1Act,
     LayerNormAct,
     LayerNormAct2d,
+    LayerNormActFp32,
+    LayerNormAct2dFp32,
+    RmsNormAct,
+    RmsNormAct2d,
+    RmsNormActFp32,
+    RmsNormAct2dFp32,
     SyncBatchNormAct,
     convert_sync_batchnorm,
     FrozenBatchNormAct2d,
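
With these exports added, the Fp32 variants are importable directly from timm.layers. A quick check, assuming a timm build that contains this commit and that the Fp32 classes share their non-Fp32 counterparts' constructor signatures:

from timm.layers import LayerNorm2dFp32, RmsNormFp32, SimpleNorm2dFp32

# Assumes LayerNorm2dFp32 is constructed like LayerNorm2d (num_channels first).
norm = LayerNorm2dFp32(64)
print(type(norm).__name__)  # 'LayerNorm2dFp32'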

timm/layers/create_act.py

Lines changed: 9 additions & 4 deletions
@@ -1,11 +1,12 @@
 """ Activation Factory
 Hacked together by / Copyright 2020 Ross Wightman
 """
-from typing import Union, Callable, Type
+from typing import Callable, Optional, Type, Union
 
 from .activations import *
 from .activations_me import *
 from .config import is_exportable, is_scriptable
+from .typing import LayerType
 
 # PyTorch has an optimized, native 'silu' (aka 'swish') operator as of PyTorch 1.7.
 # Also hardsigmoid, hardswish, and soon mish. This code will use native version if present.
@@ -88,7 +89,7 @@
     a.setdefault('hardswish', a.get('hard_swish'))
 
 
-def get_act_fn(name: Union[Callable, str] = 'relu'):
+def get_act_fn(name: Optional[LayerType] = 'relu'):
     """ Activation Function Factory
     Fetching activation fns by name with this function allows export or torch script friendly
     functions to be returned dynamically based on current config.
@@ -106,7 +107,7 @@ def get_act_fn(name: Union[Callable, str] = 'relu'):
     return _ACT_FN_DEFAULT[name]
 
 
-def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
+def get_act_layer(name: Optional[LayerType] = 'relu'):
     """ Activation Layer Factory
     Fetching activation layers by name with this function allows export or torch script friendly
     functions to be returned dynamically based on current config.
@@ -125,7 +126,11 @@ def get_act_layer(name: Union[Type[nn.Module], str] = 'relu'):
     return _ACT_LAYER_DEFAULT[name]
 
 
-def create_act_layer(name: Union[Type[nn.Module], str], inplace=None, **kwargs):
+def create_act_layer(
+        name: Optional[LayerType],
+        inplace: Optional[bool] = None,
+        **kwargs
+):
     act_layer = get_act_layer(name)
     if act_layer is None:
         return None
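
The widened Optional[LayerType] annotations reflect that these factories accept a string name, a module class or callable, or None. A brief illustration of the behavior shown above, assuming the usual timm.layers exports:

import torch.nn as nn
from timm.layers import get_act_layer, create_act_layer

act_cls = get_act_layer('gelu')               # resolve a string name to an nn.Module class
act = create_act_layer('relu', inplace=True)  # instantiate, forwarding inplace/kwargs
passthrough = get_act_layer(nn.SiLU)          # non-string inputs are returned as-is
no_act = create_act_layer(None)               # None resolves to None (no activation)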

timm/layers/create_norm.py

Lines changed: 22 additions & 1 deletion
@@ -10,7 +10,22 @@
 
 import torch.nn as nn
 
-from .norm import GroupNorm, GroupNorm1, LayerNorm, LayerNorm2d, RmsNorm, RmsNorm2d, SimpleNorm, SimpleNorm2d
+from .norm import (
+    GroupNorm,
+    GroupNorm1,
+    LayerNorm,
+    LayerNorm2d,
+    LayerNormFp32,
+    LayerNorm2dFp32,
+    RmsNorm,
+    RmsNorm2d,
+    RmsNormFp32,
+    RmsNorm2dFp32,
+    SimpleNorm,
+    SimpleNorm2d,
+    SimpleNormFp32,
+    SimpleNorm2dFp32,
+)
 from torchvision.ops.misc import FrozenBatchNorm2d
 
 _NORM_MAP = dict(
@@ -21,10 +36,16 @@
     groupnorm1=GroupNorm1,
     layernorm=LayerNorm,
     layernorm2d=LayerNorm2d,
+    layernormfp32=LayerNormFp32,
+    layernorm2dfp32=LayerNorm2dFp32,
     rmsnorm=RmsNorm,
     rmsnorm2d=RmsNorm2d,
+    rmsnormfp32=RmsNormFp32,
+    rmsnorm2dfp32=RmsNorm2dFp32,
     simplenorm=SimpleNorm,
     simplenorm2d=SimpleNorm2d,
+    simplenormfp32=SimpleNormFp32,
+    simplenorm2dfp32=SimpleNorm2dFp32,
     frozenbatchnorm2d=FrozenBatchNorm2d,
 )
 _NORM_TYPES = {m for n, m in _NORM_MAP.items()}
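
Because the fp32 keys are now registered in _NORM_MAP, the new layers can be reached by name through this module's factory helpers. A sketch assuming the existing create_norm_layer factory in create_norm.py (not shown in this diff):

import torch
from timm.layers import create_norm_layer

# 'rmsnorm2dfp32' resolves to RmsNorm2dFp32 through the map entries added above.
norm = create_norm_layer('rmsnorm2dfp32', 128)
y = norm(torch.randn(2, 128, 7, 7))
print(tuple(y.shape))  # (2, 128, 7, 7)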

timm/layers/create_norm_act.py

Lines changed: 56 additions & 18 deletions
@@ -8,19 +8,35 @@
 """
 import types
 import functools
+from typing import Optional
 
 from .evo_norm import *
 from .filter_response_norm import FilterResponseNormAct2d, FilterResponseNormTlu2d
-from .norm_act import BatchNormAct2d, GroupNormAct, LayerNormAct, LayerNormAct2d, RmsNormAct, RmsNormAct2d
+from .norm_act import (
+    BatchNormAct2d,
+    GroupNormAct,
+    GroupNorm1Act,
+    LayerNormAct,
+    LayerNormActFp32,
+    LayerNormAct2d,
+    LayerNormAct2dFp32,
+    RmsNormAct,
+    RmsNormActFp32,
+    RmsNormAct2d,
+    RmsNormAct2dFp32,
+)
 from .inplace_abn import InplaceAbn
+from .typing import LayerType
 
 _NORM_ACT_MAP = dict(
     batchnorm=BatchNormAct2d,
     batchnorm2d=BatchNormAct2d,
     groupnorm=GroupNormAct,
-    groupnorm1=functools.partial(GroupNormAct, num_groups=1),
+    groupnorm1=GroupNorm1Act,
     layernorm=LayerNormAct,
     layernorm2d=LayerNormAct2d,
+    layernormfp32=LayerNormActFp32,
+    layernorm2dfp32=LayerNormAct2dFp32,
     evonormb0=EvoNorm2dB0,
     evonormb1=EvoNorm2dB1,
     evonormb2=EvoNorm2dB2,
@@ -36,30 +52,62 @@
     iabn=InplaceAbn,
     rmsnorm=RmsNormAct,
     rmsnorm2d=RmsNormAct2d,
+    rmsnormfp32=RmsNormActFp32,
+    rmsnorm2dfp32=RmsNormAct2dFp32,
 )
 _NORM_ACT_TYPES = {m for n, m in _NORM_ACT_MAP.items()}
+# Reverse map from base norm layer names to norm+act layer classes
+_NORM_TO_NORM_ACT_MAP = dict(
+    batchnorm=BatchNormAct2d,
+    batchnorm2d=BatchNormAct2d,
+    groupnorm=GroupNormAct,
+    groupnorm1=GroupNorm1Act,
+    layernorm=LayerNormAct,
+    layernorm2d=LayerNormAct2d,
+    layernormfp32=LayerNormActFp32,
+    layernorm2dfp32=LayerNormAct2dFp32,
+    rmsnorm=RmsNormAct,
+    rmsnorm2d=RmsNormAct2d,
+    rmsnormfp32=RmsNormActFp32,
+    rmsnorm2dfp32=RmsNormAct2dFp32,
+)
 # has act_layer arg to define act type
 _NORM_ACT_REQUIRES_ARG = {
     BatchNormAct2d,
     GroupNormAct,
+    GroupNorm1Act,
     LayerNormAct,
     LayerNormAct2d,
+    LayerNormActFp32,
+    LayerNormAct2dFp32,
     FilterResponseNormAct2d,
     InplaceAbn,
     RmsNormAct,
     RmsNormAct2d,
+    RmsNormActFp32,
+    RmsNormAct2dFp32,
 }
 
 
-def create_norm_act_layer(layer_name, num_features, act_layer=None, apply_act=True, jit=False, **kwargs):
+def create_norm_act_layer(
+        layer_name: LayerType,
+        num_features: int,
+        act_layer: Optional[LayerType] = None,
+        apply_act: bool = True,
+        jit: bool = False,
+        **kwargs,
+):
     layer = get_norm_act_layer(layer_name, act_layer=act_layer)
     layer_instance = layer(num_features, apply_act=apply_act, **kwargs)
     if jit:
         layer_instance = torch.jit.script(layer_instance)
     return layer_instance
 
 
-def get_norm_act_layer(norm_layer, act_layer=None):
+def get_norm_act_layer(
+        norm_layer: LayerType,
+        act_layer: Optional[LayerType] = None,
+):
     if norm_layer is None:
         return None
     assert isinstance(norm_layer, (type, str, types.FunctionType, functools.partial))
@@ -82,26 +130,16 @@ def get_norm_act_layer(norm_layer, act_layer=None):
         # if function type, must be a lambda/fn that creates a norm_act layer
         norm_act_layer = norm_layer
     else:
+        # Use reverse map to find the corresponding norm+act layer
         type_name = norm_layer.__name__.lower()
-        if type_name.startswith('batchnorm'):
-            norm_act_layer = BatchNormAct2d
-        elif type_name.startswith('groupnorm'):
-            norm_act_layer = GroupNormAct
-        elif type_name.startswith('groupnorm1'):
-            norm_act_layer = functools.partial(GroupNormAct, num_groups=1)
-        elif type_name.startswith('layernorm2d'):
-            norm_act_layer = LayerNormAct2d
-        elif type_name.startswith('layernorm'):
-            norm_act_layer = LayerNormAct
-        elif type_name.startswith('rmsnorm2d'):
-            norm_act_layer = RmsNormAct2d
-        else:
-            assert False, f"No equivalent norm_act layer for {type_name}"
+        norm_act_layer = _NORM_TO_NORM_ACT_MAP.get(type_name, None)
+        assert norm_act_layer is not None, f"No equivalent norm_act layer for {type_name}"
 
     if norm_act_layer in _NORM_ACT_REQUIRES_ARG:
         # pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
         # In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
         norm_act_kwargs.setdefault('act_layer', act_layer)
     if norm_act_kwargs:
         norm_act_layer = functools.partial(norm_act_layer, **norm_act_kwargs)  # bind/rebind args
+
     return norm_act_layer