from .base_modules import Activation, Norm

-__all__ = ["Mlp", "MlpBlock"]
+__all__ = ["Mlp", "ConvMlp", "MlpBlock"]


class Mlp(nn.Module):
@@ -17,7 +17,8 @@ def __init__(
        dropout: float = 0.0,
        bias: bool = False,
        out_channels: int = None,
-        **act_kwargs
+        act_kwargs: Dict[str, Any] = None,
+        **kwargs,
    ) -> None:
        """MLP token mixer.

@@ -32,7 +33,7 @@ def __init__(
        in_channels : int
            Number of input features.
        mlp_ratio : int, default=2
-            Scaling factor to get the number hidden features from the `in_features`.
+            Scaling factor to get the number of hidden features from `in_channels`.
        activation : str, default="star_relu"
            The name of the activation function.
        dropout : float, default=0.0
@@ -41,10 +42,11 @@ def __init__(
            Flag whether to use bias terms in the nn.Linear modules.
        out_channels : int, optional
            Number of out channels. If None `out_channels = in_channels`
-        **act_kwargs:
+        act_kwargs : Dict[str, Any], optional
            Arbitrary key-word arguments for the activation function.
        """
        super().__init__()
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
        self.out_channels = in_channels if out_channels is None else out_channels
        hidden_channels = int(mlp_ratio * in_channels)

@@ -65,13 +67,73 @@ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        return x


+class ConvMlp(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        mlp_ratio: int = 2,
+        activation: str = "star_relu",
+        dropout: float = 0.0,
+        bias: bool = False,
+        out_channels: int = None,
+        act_kwargs: Dict[str, Any] = None,
+        **kwargs,
+    ) -> None:
+        """MLP layer implemented with depthwise separable convolutions.
+
+        Input shape: (B, in_channels, H, W).
+        Output shape: (B, out_channels, H, W).
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input features.
+        mlp_ratio : int, default=2
+            Scaling factor to get the number of hidden features from `in_channels`.
+        activation : str, default="star_relu"
+            The name of the activation function.
+        dropout : float, default=0.0
+            Dropout ratio.
+        bias : bool, default=False
+            Flag whether to use bias terms in the nn.Conv2d modules.
+        out_channels : int, optional
+            Number of out channels. If None `out_channels = in_channels`
+        act_kwargs : Dict[str, Any], optional
+            Arbitrary key-word arguments for the activation function.
+        """
+        super().__init__()
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.hidden_channels = int(mlp_ratio * in_channels)
+        self.fc1 = nn.Conv2d(in_channels, self.hidden_channels, 1, bias=bias)
+        self.dwconv = nn.Conv2d(
+            self.hidden_channels, self.hidden_channels, 3, 1, 1, bias=bias, groups=self.hidden_channels
+        )
+        self.act = Activation(activation, **act_kwargs)
+        self.fc2 = nn.Conv2d(self.hidden_channels, self.out_channels, 1, bias=bias)
+        self.drop = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of conv-mlp."""
+        x = self.fc1(x)
+
+        x = self.dwconv(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+
+        return x
+
+
class MlpBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
+        mlp_type: str = "linear",
        mlp_ratio: int = 2,
        activation: str = "star_relu",
-        activation_kwargs: Dict[str, Any] = None,
+        act_kwargs: Dict[str, Any] = None,
        dropout: float = 0.0,
        bias: bool = False,
        normalization: str = "ln",
@@ -85,10 +147,15 @@ def __init__(
        ----------
        in_channels : int
            Number of input features.
+        mlp_type : str, default="linear"
+            Flag for either an nn.Linear- or nn.Conv2d-based mlp-layer.
+            One of: "conv", "linear".
        mlp_ratio : int, default=2
-            Scaling factor to get the number hidden features from the `in_features`.
+            Scaling factor to get the number of hidden features from `in_channels`.
        activation : str, default="star_relu"
            The name of the activation function.
+        act_kwargs : Dict[str, Any], optional
+            Keyword arguments for the activation module.
        dropout : float, default=0.0
            Dropout ratio.
        bias : bool, default=False
@@ -101,14 +168,24 @@ def __init__(
            is None.
        """
        super().__init__()
+        allowed = ("conv", "linear")
+        if mlp_type not in allowed:
+            raise ValueError(
+                f"Illegal `mlp_type` given. Got: {mlp_type}. Allowed: {allowed}."
+            )
+
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
        self.norm = Norm(normalization, **norm_kwargs)
-        self.mlp = Mlp(
+        MlpHead = Mlp if mlp_type == "linear" else ConvMlp
+
+        self.mlp = MlpHead(
            in_channels=in_channels,
            mlp_ratio=mlp_ratio,
            activation=activation,
            dropout=dropout,
            bias=bias,
-            **activation_kwargs
+            act_kwargs=act_kwargs,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
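Below is a quick usage sketch of the mixers this diff touches. It is not part of the change itself: the import path is a guess at where the module lives, and the printed shapes assume the default `out_channels` (i.e. `out_channels == in_channels`).

```python
import torch

# NOTE: assumed import path; point it at the module this diff edits.
from cellseg_models_pytorch.modules.mlp import ConvMlp, Mlp

B, C, H, W = 8, 64, 32, 32

# `Mlp` is nn.Linear-based and mixes the last (channel) dimension of a token tensor.
tokens = torch.rand(B, H * W, C)                # (B, N, C)
mlp = Mlp(in_channels=C, mlp_ratio=2)
print(mlp(tokens).shape)                        # (B, N, C) with the default out_channels

# `ConvMlp` works on channels-first, image-shaped tensors.
feats = torch.rand(B, C, H, W)                  # (B, C, H, W)
conv_mlp = ConvMlp(in_channels=C, mlp_ratio=2)
print(conv_mlp(feats).shape)                    # (B, C, H, W) with the default out_channels

# `MlpBlock` puts a Norm in front and picks the mixer via `mlp_type`
# ("linear" -> Mlp, "conv" -> ConvMlp); any other value raises ValueError.
```

Making `act_kwargs` an explicit dict (instead of the old `**act_kwargs`/`**activation_kwargs`) keeps the two constructor signatures uniform, so `MlpBlock` can forward the same keyword set to either head, while the added `**kwargs` lets unrelated keyword arguments pass through harmlessly.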