8
8
"""
9
9
import types
10
10
import functools
11
+ from typing import Optional
11
12
12
13
from .evo_norm import *
13
14
from .filter_response_norm import FilterResponseNormAct2d , FilterResponseNormTlu2d
14
- from .norm_act import BatchNormAct2d , GroupNormAct , LayerNormAct , LayerNormAct2d , RmsNormAct , RmsNormAct2d
15
+ from .norm_act import (
16
+ BatchNormAct2d ,
17
+ GroupNormAct ,
18
+ GroupNorm1Act ,
19
+ LayerNormAct ,
20
+ LayerNormActFp32 ,
21
+ LayerNormAct2d ,
22
+ LayerNormAct2dFp32 ,
23
+ RmsNormAct ,
24
+ RmsNormActFp32 ,
25
+ RmsNormAct2d ,
26
+ RmsNormAct2dFp32 ,
27
+ )
15
28
from .inplace_abn import InplaceAbn
29
+ from .typing import LayerType
16
30
17
31
_NORM_ACT_MAP = dict (
18
32
batchnorm = BatchNormAct2d ,
19
33
batchnorm2d = BatchNormAct2d ,
20
34
groupnorm = GroupNormAct ,
21
- groupnorm1 = functools . partial ( GroupNormAct , num_groups = 1 ) ,
35
+ groupnorm1 = GroupNorm1Act ,
22
36
layernorm = LayerNormAct ,
23
37
layernorm2d = LayerNormAct2d ,
38
+ layernormfp32 = LayerNormActFp32 ,
39
+ layernorm2dfp32 = LayerNormAct2dFp32 ,
24
40
evonormb0 = EvoNorm2dB0 ,
25
41
evonormb1 = EvoNorm2dB1 ,
26
42
evonormb2 = EvoNorm2dB2 ,
36
52
iabn = InplaceAbn ,
37
53
rmsnorm = RmsNormAct ,
38
54
rmsnorm2d = RmsNormAct2d ,
55
+ rmsnormfp32 = RmsNormActFp32 ,
56
+ rmsnorm2dfp32 = RmsNormAct2dFp32 ,
39
57
)
40
58
_NORM_ACT_TYPES = {m for n , m in _NORM_ACT_MAP .items ()}
59
+ # Reverse map from base norm layer names to norm+act layer classes
60
+ _NORM_TO_NORM_ACT_MAP = dict (
61
+ batchnorm = BatchNormAct2d ,
62
+ batchnorm2d = BatchNormAct2d ,
63
+ groupnorm = GroupNormAct ,
64
+ groupnorm1 = GroupNorm1Act ,
65
+ layernorm = LayerNormAct ,
66
+ layernorm2d = LayerNormAct2d ,
67
+ layernormfp32 = LayerNormActFp32 ,
68
+ layernorm2dfp32 = LayerNormAct2dFp32 ,
69
+ rmsnorm = RmsNormAct ,
70
+ rmsnorm2d = RmsNormAct2d ,
71
+ rmsnormfp32 = RmsNormActFp32 ,
72
+ rmsnorm2dfp32 = RmsNormAct2dFp32 ,
73
+ )
41
74
# has act_layer arg to define act type
42
75
_NORM_ACT_REQUIRES_ARG = {
43
76
BatchNormAct2d ,
44
77
GroupNormAct ,
78
+ GroupNorm1Act ,
45
79
LayerNormAct ,
46
80
LayerNormAct2d ,
81
+ LayerNormActFp32 ,
82
+ LayerNormAct2dFp32 ,
47
83
FilterResponseNormAct2d ,
48
84
InplaceAbn ,
49
85
RmsNormAct ,
50
86
RmsNormAct2d ,
87
+ RmsNormActFp32 ,
88
+ RmsNormAct2dFp32 ,
51
89
}
52
90
53
91
54
- def create_norm_act_layer (layer_name , num_features , act_layer = None , apply_act = True , jit = False , ** kwargs ):
92
+ def create_norm_act_layer (
93
+ layer_name : LayerType ,
94
+ num_features : int ,
95
+ act_layer : Optional [LayerType ] = None ,
96
+ apply_act : bool = True ,
97
+ jit : bool = False ,
98
+ ** kwargs ,
99
+ ):
55
100
layer = get_norm_act_layer (layer_name , act_layer = act_layer )
56
101
layer_instance = layer (num_features , apply_act = apply_act , ** kwargs )
57
102
if jit :
58
103
layer_instance = torch .jit .script (layer_instance )
59
104
return layer_instance
60
105
61
106
62
- def get_norm_act_layer (norm_layer , act_layer = None ):
107
+ def get_norm_act_layer (
108
+ norm_layer : LayerType ,
109
+ act_layer : Optional [LayerType ] = None ,
110
+ ):
63
111
if norm_layer is None :
64
112
return None
65
113
assert isinstance (norm_layer , (type , str , types .FunctionType , functools .partial ))
@@ -82,26 +130,16 @@ def get_norm_act_layer(norm_layer, act_layer=None):
82
130
# if function type, must be a lambda/fn that creates a norm_act layer
83
131
norm_act_layer = norm_layer
84
132
else :
133
+ # Use reverse map to find the corresponding norm+act layer
85
134
type_name = norm_layer .__name__ .lower ()
86
- if type_name .startswith ('batchnorm' ):
87
- norm_act_layer = BatchNormAct2d
88
- elif type_name .startswith ('groupnorm' ):
89
- norm_act_layer = GroupNormAct
90
- elif type_name .startswith ('groupnorm1' ):
91
- norm_act_layer = functools .partial (GroupNormAct , num_groups = 1 )
92
- elif type_name .startswith ('layernorm2d' ):
93
- norm_act_layer = LayerNormAct2d
94
- elif type_name .startswith ('layernorm' ):
95
- norm_act_layer = LayerNormAct
96
- elif type_name .startswith ('rmsnorm2d' ):
97
- norm_act_layer = RmsNormAct2d
98
- else :
99
- assert False , f"No equivalent norm_act layer for { type_name } "
135
+ norm_act_layer = _NORM_TO_NORM_ACT_MAP .get (type_name , None )
136
+ assert norm_act_layer is not None , f"No equivalent norm_act layer for { type_name } "
100
137
101
138
if norm_act_layer in _NORM_ACT_REQUIRES_ARG :
102
139
# pass `act_layer` through for backwards compat where `act_layer=None` implies no activation.
103
140
# In the future, may force use of `apply_act` with `act_layer` arg bound to relevant NormAct types
104
141
norm_act_kwargs .setdefault ('act_layer' , act_layer )
105
142
if norm_act_kwargs :
106
143
norm_act_layer = functools .partial (norm_act_layer , ** norm_act_kwargs ) # bind/rebind args
144
+
107
145
return norm_act_layer
0 commit comments