
Commit ad90d6d

feat(modules): add conv-mlp head for metaformers
1 parent: c05021e

File tree: cellseg_models_pytorch/modules (1 file changed, +85, −8 lines)

cellseg_models_pytorch/modules/mlp.py

Lines changed: 85 additions & 8 deletions
@@ -5,7 +5,7 @@

 from .base_modules import Activation, Norm

-__all__ = ["Mlp", "MlpBlock"]
+__all__ = ["Mlp", "ConvMlp", "MlpBlock"]


 class Mlp(nn.Module):
@@ -17,7 +17,8 @@ def __init__(
         dropout: float = 0.0,
         bias: bool = False,
         out_channels: int = None,
-        **act_kwargs
+        act_kwargs: Dict[str, Any] = None,
+        **kwargs,
     ) -> None:
         """MLP token mixer.

@@ -32,7 +33,7 @@ def __init__(
         in_channels : int
             Number of input features.
         mlp_ratio : int, default=2
-            Scaling factor to get the number hidden features from the `in_features`.
+            Scaling factor to get the number hidden features from the `in_channels`.
         activation : str, default="star_relu"
             The name of the activation function.
         dropout : float, default=0.0
@@ -41,10 +42,11 @@ def __init__(
             Flag whether to use bias terms in the nn.Linear modules.
         out_channels : int, optional
             Number of out channels. If None `out_channels = in_channels`
-        **act_kwargs:
+        act_kwargs : Dict[str, Any], optional
             Arbitrary key-word arguments for the activation function.
         """
         super().__init__()
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
         self.out_channels = in_channels if out_channels is None else out_channels
         hidden_channels = int(mlp_ratio * in_channels)

@@ -65,13 +67,73 @@ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
         return x


+class ConvMlp(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        mlp_ratio: int = 2,
+        activation: str = "star_relu",
+        dropout: float = 0.0,
+        bias: bool = False,
+        out_channels: int = None,
+        act_kwargs: Dict[str, Any] = None,
+        **kwargs,
+    ) -> None:
+        """Mlp layer implemented with dws convolution.
+
+        Input shape: (B, in_channels, H, W).
+        Output shape: (B, out_channels, H, W).
+
+        Parameters
+        ----------
+        in_channels : int
+            Number of input features.
+        mlp_ratio : int, default=2
+            Scaling factor to get the number hidden features from the `in_channels`.
+        activation : str, default="star_relu"
+            The name of the activation function.
+        dropout : float, default=0.0
+            Dropout ratio.
+        bias : bool, default=False
+            Flag whether to use bias terms in the nn.Linear modules.
+        out_channels : int, optional
+            Number of out channels. If None `out_channels = in_channels`
+        act_kwargs : Dict[str, Any], optional
+            Arbitrary key-word arguments for the activation function.
+        """
+        super().__init__()
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
+        self.out_channels = in_channels if out_channels is None else out_channels
+        self.hidden_channels = int(mlp_ratio * in_channels)
+        self.fc1 = nn.Conv2d(in_channels, self.hidden_channels, 1, bias=bias)
+        self.dwconv = nn.Conv2d(
+            self.hidden_channels, self.hidden_channels, 3, 1, 1, bias=bias, groups=self.hidden_channels
+        )
+        self.act = Activation(activation, **act_kwargs)
+        self.fc2 = nn.Conv2d(self.hidden_channels, self.out_channels, 1, bias=bias)
+        self.drop = nn.Dropout(dropout)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of conv-mlp."""
+        x = self.fc1(x)
+
+        x = self.dwconv(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+
+        return x
+
+
 class MlpBlock(nn.Module):
     def __init__(
         self,
         in_channels: int,
+        mlp_type: str = "linear",
         mlp_ratio: int = 2,
         activation: str = "star_relu",
-        activation_kwargs: Dict[str, Any] = None,
+        act_kwargs: Dict[str, Any] = None,
         dropout: float = 0.0,
         bias: bool = False,
         normalization: str = "ln",
@@ -85,10 +147,15 @@ def __init__(
         ----------
         in_channels : int
             Number of input features.
+        mlp_type : str, default="linear"
+            Flag for either nn.Linear or nn.Conv2d mlp-layer.
+            One of "conv", "linear".
         mlp_ratio : int, default=2
-            Scaling factor to get the number hidden features from the `in_features`.
+            Scaling factor to get the number hidden features from the `in_channels`.
         activation : str, default="star_relu"
             The name of the activation function.
+        act_kwargs : Dict[str, Any], optional
+            key-word args for the activation module.
         dropout : float, default=0.0
             Dropout ratio.
         bias : bool, default=False
@@ -101,14 +168,24 @@ def __init__(
             is None.
         """
         super().__init__()
+        allowed = ("conv", "linear")
+        if mlp_type not in allowed:
+            raise ValueError(
+                f"Illegal `mlp_type` given. Got: {mlp_type}. Allowed: {allowed}."
+            )
+
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
+        act_kwargs = act_kwargs if act_kwargs is not None else {}
         self.norm = Norm(normalization, **norm_kwargs)
-        self.mlp = Mlp(
+        MlpHead = Mlp if mlp_type == "linear" else ConvMlp
+
+        self.mlp = MlpHead(
             in_channels=in_channels,
             mlp_ratio=mlp_ratio,
             activation=activation,
             dropout=dropout,
             bias=bias,
-            **activation_kwargs
+            act_kwargs=act_kwargs,
         )

     def forward(self, x: torch.Tensor) -> torch.Tensor:
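
For reference, a minimal usage sketch of the new conv-mlp path (not part of the commit; it assumes the package is importable under the module path shown above and that the "star_relu" default is available in the Activation registry):

import torch

from cellseg_models_pytorch.modules.mlp import ConvMlp

# A (B, in_channels, H, W) feature map, matching the ConvMlp docstring.
x = torch.rand(8, 64, 32, 32)

# A 1x1 conv expands the channels by `mlp_ratio`, a depthwise 3x3 conv mixes
# spatial context per channel, and a second 1x1 conv projects to `out_channels`.
conv_mlp = ConvMlp(in_channels=64, mlp_ratio=2, activation="star_relu", dropout=0.1)
out = conv_mlp(x)
print(out.shape)  # torch.Size([8, 64, 32, 32])

# MlpBlock(in_channels=64, mlp_type="conv") would dispatch to this head instead
# of the nn.Linear-based Mlp; any other `mlp_type` raises a ValueError.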
