
Commit 5cbf18e

feat(modules): add tformer patch emb & mlp blocks
1 parent e4ae5e0 commit 5cbf18e

File tree: 2 files changed (+365, −0 lines)


cellseg_models_pytorch/modules/mlp.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
from typing import Any, Dict

import torch
import torch.nn as nn

from .base_modules import Activation, Norm

__all__ = ["Mlp", "MlpBlock"]


class Mlp(nn.Module):
    def __init__(
        self,
        in_channels: int,
        mlp_ratio: int = 4,
        activation: str = "star_relu",
        dropout: float = 0.0,
        bias: bool = False,
        out_channels: int = None,
        **kwargs
    ) -> None:
        """MLP token mixer.

        - MetaFormer: https://arxiv.org/abs/2210.13452
        - MLP-Mixer: https://arxiv.org/abs/2105.01601

        - Input shape: (B, seq_len, in_channels)
        - Output shape: (B, seq_len, out_channels)

        Parameters
        ----------
        in_channels : int
            Number of input features.
        mlp_ratio : int, default=4
            Scaling factor to get the number of hidden features from `in_channels`.
        activation : str, default="star_relu"
            The name of the activation function.
        dropout : float, default=0.0
            Dropout ratio.
        bias : bool, default=False
            Flag whether to use bias terms in the nn.Linear modules.
        out_channels : int, optional
            Number of output channels. If None, `out_channels = in_channels`.
        """
        super().__init__()
        self.out_channels = in_channels if out_channels is None else out_channels
        hidden_channels = int(mlp_ratio * in_channels)

        self.fc1 = nn.Linear(in_channels, hidden_channels, bias=bias)
        self.act = Activation(activation)
        self.drop1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_channels, self.out_channels, bias=bias)
        self.drop2 = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Forward pass of the MLP token mixer."""
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)

        return x


class MlpBlock(nn.Module):
    def __init__(
        self,
        in_channels: int,
        mlp_ratio: int = 4,
        activation: str = "star_relu",
        dropout: float = 0.0,
        bias: bool = False,
        normalization: str = "ln",
        norm_kwargs: Dict[str, Any] = None,
    ) -> None:
        """Residual MLP block.

        I.e. norm -> mlp -> residual

        Parameters
        ----------
        in_channels : int
            Number of input features.
        mlp_ratio : int, default=4
            Scaling factor to get the number of hidden features from `in_channels`.
        activation : str, default="star_relu"
            The name of the activation function.
        dropout : float, default=0.0
            Dropout ratio.
        bias : bool, default=False
            Flag whether to use bias terms in the nn.Linear modules.
        normalization : str, default="ln"
            The name of the normalization method.
            One of: "bn", "bcn", "gn", "in", "ln", "lrn", None
        norm_kwargs : Dict[str, Any], optional
            Key-word args for the normalization layer. Ignored if normalization
            is None.
        """
        super().__init__()
        # guard against unpacking the default None
        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
        self.norm = Norm(normalization, **norm_kwargs)
        self.mlp = Mlp(
            in_channels=in_channels,
            mlp_ratio=mlp_ratio,
            activation=activation,
            dropout=dropout,
            bias=bias,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass of the MetaFormer MLP block."""
        residual = x

        x = self.norm(x)
        x = self.mlp(x)

        return x + residual
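
A minimal usage sketch of the two modules above. The import path follows the file path shown in this diff; the `norm_kwargs` passed to `MlpBlock` are an assumption about what the repo's `Norm` wrapper forwards to nn.LayerNorm, everything else follows the signatures in the code.

import torch

from cellseg_models_pytorch.modules.mlp import Mlp, MlpBlock

# A token sequence: batch of 2, 196 tokens, 128 features per token.
x = torch.rand([2, 196, 128])

# Plain MLP mixer: 128 -> 512 -> 128 features (mlp_ratio=4).
mlp = Mlp(in_channels=128, mlp_ratio=4)
print(mlp(x).shape)  # torch.Size([2, 196, 128])

# Residual block: ln -> mlp -> skip connection. The `normalized_shape`
# kwarg is assumed to be forwarded by `Norm` to nn.LayerNorm.
block = MlpBlock(
    in_channels=128,
    normalization="ln",
    norm_kwargs={"normalized_shape": 128},
)
print(block(x).shape)  # torch.Size([2, 196, 128])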
Lines changed: 248 additions & 0 deletions
@@ -0,0 +1,248 @@
from typing import Any, Dict

import torch
import torch.nn as nn

from .base_modules import Norm

__all__ = ["ContiguousEmbed", "PatchEmbed"]


class ContiguousEmbed(nn.Module):
    def __init__(
        self,
        in_channels: int,
        patch_size: int = 1,
        stride: int = 1,
        kernel_size: int = None,
        pad: int = 0,
        head_dim: int = 64,
        num_heads: int = 8,
        flatten: bool = True,
        normalization: str = None,
        norm_kwargs: Dict[str, Any] = None,
        **kwargs
    ) -> None:
        """Patch an image with nn.Conv2d and then embed.

        NOTE:
        The input is patched via nn.Conv2d, i.e. the patch dimensions are defined by
        the convolution parameters. The default values are set such that every pixel
        value is a patch. For big inputs this results in OOM errors when computing
        attention.

        If there is a need for bigger patches with no overlap, you can set for example
        `patch_size=16` and `stride=16` to get patches of size 16**2.

        - Input shape: (B, C, H, W)
        - Output shape: (B, H'*W', head_dim*num_heads)
          (If `patch_size=1` & `stride=1` -> H'*W' = H*W.)

        NOTE: Optional normalization of the input before patching and projecting.

        Parameters
        ----------
        in_channels : int
            Number of input channels in the input tensor. (3 for RGB.)
        patch_size : int, default=1
            Size of the patch. Defaults to 1, meaning that every pixel is a patch
            (given that stride is equal to 1). If `kernel_size` is given, this will
            be ignored.
        stride : int, default=1
            The sliding window stride. Defaults to 1, meaning that every pixel is a
            patch (given that patch_size is equal to 1).
        kernel_size : int, optional
            The kernel size for the convolution. If None, the `patch_size` is used.
        pad : int, default=0
            Size of the padding.
        head_dim : int, default=64
            Number of channels per each head.
        num_heads : int, default=8
            Number of heads in multi-head self-attention.
        flatten : bool, default=True
            If True, the output will be flattened to a sequence.
        normalization : str, optional
            The name of the normalization method.
            One of: "bn", "bcn", "gn", "in", "ln", "lrn", None
        norm_kwargs : Dict[str, Any], optional
            Key-word args for the normalization layer. Ignored if normalization
            is None.

        Examples
        --------
        >>> x = torch.rand([1, 3, 256, 256])

        >>> # per-pixel patches -> a sequence of 256*256 tokens
        >>> conv_patch = ContiguousEmbed(
                in_channels=3,
                patch_size=1,
                stride=1,
            )
        >>> print(conv_patch(x).shape)
        >>> # torch.Size([1, 65536, 512])

        >>> # 16*16 patches
        >>> conv_patch2 = ContiguousEmbed(
                in_channels=3,
                patch_size=16,
                stride=16,
            )
        >>> print(conv_patch2(x).shape)
        >>> # torch.Size([1, 256, 512])

        >>> # Downsampling the input to a 64*64 grid of patches
        >>> conv_patch3 = ContiguousEmbed(
                in_channels=3,
                stride=4,
                kernel_size=7,
                pad=2,
            )
        >>> print(conv_patch3(x).shape)
        >>> # torch.Size([1, 4096, 512])
        """
        super().__init__()
        self.flatten = flatten
        self.proj_dim = head_dim * num_heads
        self.kernel_size = patch_size if kernel_size is None else kernel_size
        self.pad = pad
        self.stride = stride

        # guard against unpacking the default None
        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
        self.norm = Norm(normalization, **norm_kwargs)
        self.proj = nn.Conv2d(
            in_channels,
            self.proj_dim,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.pad,
        )

    def get_patch_size(self, img_size: int) -> int:
        """Get the size of the output patch grid (H' or W') from the conv params."""
        return int(
            (((img_size + 2 * self.pad - (self.kernel_size - 1)) - 1) / self.stride) + 1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass for projection."""
        B, _, H, W = x.shape

        # 1. Normalize
        x = self.norm(x)

        # 2. Patch and project.
        x = self.proj(x)  # (B, proj_dim, H', W')

        # 3. Reshape to a sequence.
        # Every patch has been projected into a `proj_dim` long vector.
        if self.flatten:
            p_H = self.get_patch_size(H)
            p_W = self.get_patch_size(W)

            # flatten
            x = x.permute(0, 2, 3, 1).reshape(
                B, p_H * p_W, self.proj_dim
            )  # (B, H'*W', proj_dim)

        return x


class PatchEmbed(nn.Module):
    def __init__(
        self,
        in_channels: int,
        patch_size: int = 16,
        head_dim: int = 64,
        num_heads: int = 8,
        normalization: str = None,
        **norm_kwargs
    ) -> None:
        """Patch an input image and then embed/project.

        NOTE: This implementation first patches the input image by reshaping it
        and then embeds/projects it with nn.Linear.

        NOTE: Optional normalization of the input before patching and projecting.

        - Input shape: (B, C, H, W)
        - Patched shape: (B, H//patch_size * W//patch_size, C*patch_size**2)
        - Embedded output shape: (B, H//patch_size * W//patch_size, head_dim*num_heads)

        Parameters
        ----------
        in_channels : int
            Number of input channels in the input tensor.
        patch_size : int, default=16
            The H and W size of the patch.
        head_dim : int, default=64
            Number of channels per each head.
        num_heads : int, default=8
            Number of heads in multi-head self-attention.
        normalization : str, optional
            The name of the normalization method.
            One of: "bn", "bcn", "gn", "in", "ln", "lrn", None
        **norm_kwargs : Dict[str, Any]
            Key-word args for the normalization layer. Ignored if normalization
            is None.

        Examples
        --------
        >>> x = torch.rand([1, 3, 256, 256])

        >>> # patches of shape 16*16
        >>> lin_patch = PatchEmbed(
                in_channels=3,
                patch_size=16,
            )
        >>> print(lin_patch(x).shape)
        >>> # torch.Size([1, 256, 512])
        """
        super().__init__()
        self.proj_dim = head_dim * num_heads
        self.patch_size = patch_size
        self.norm = Norm(normalization, **norm_kwargs)
        self.proj = nn.Linear(in_channels * (patch_size**2), self.proj_dim)

    def img2patch(self, x: torch.Tensor) -> torch.Tensor:
        """Patch an input image of shape (B, C, H, W).

        Adapted from: PyTorch Lightning ViT tutorial.

        Parameters
        ----------
        x : torch.Tensor
            Input image of shape (B, C, H, W).

        Returns
        -------
        torch.Tensor:
            Patched and flattened input image.
            Shape: (B, H//patch_size * W//patch_size, C*patch_size**2)
        """
        B, C, H, W = x.shape
        x = x.reshape(
            B,
            C,
            H // self.patch_size,
            self.patch_size,
            W // self.patch_size,
            self.patch_size,
        )  # (B, C, H', patch_size, W', patch_size)

        x = x.permute(0, 2, 4, 1, 3, 5)  # (B, H', W', C, p_H, p_W)
        x = x.flatten(1, 2)  # (B, H'*W', C, p_H, p_W)
        x = x.flatten(2, 4)  # (B, H'*W', C*p_H*p_W)

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward patch embedding."""
        # 1. Normalize
        x = self.norm(x)  # (B, C, H, W)

        # 2. Patch
        x = self.img2patch(x)  # (B, H//patch_size * W//patch_size, C*patch_size**2)

        # 3. Project/Embed
        x = self.proj(x)

        return x
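
A quick shape check for the two embedding modules above, mirroring the docstring examples. The import path is an assumption, since this second file's name is not shown in the diff; the shapes follow from the conv output-size formula used in `get_patch_size` and from the reshape in `img2patch`.

import torch

# NOTE: module path assumed; the second file's name is not visible in this diff.
from cellseg_models_pytorch.modules.patch_embeddings import ContiguousEmbed, PatchEmbed

x = torch.rand([1, 3, 256, 256])

# Conv-based patching: 16x16 patches, no overlap -> (256/16)**2 = 256 tokens,
# each projected to head_dim*num_heads = 64*8 = 512 features.
conv_embed = ContiguousEmbed(in_channels=3, patch_size=16, stride=16)
print(conv_embed(x).shape)  # torch.Size([1, 256, 512])

# Reshape + nn.Linear patching with the same patch size gives the same shape.
lin_embed = PatchEmbed(in_channels=3, patch_size=16)
print(lin_embed(x).shape)  # torch.Size([1, 256, 512])

# The conv output-size arithmetic behind `get_patch_size`:
# H' = (H + 2*pad - kernel_size) // stride + 1
# e.g. stride=4, kernel_size=7, pad=2 on a 256x256 image:
# (256 + 4 - 7) // 4 + 1 = 64  ->  64*64 = 4096 tokens.
down_embed = ContiguousEmbed(in_channels=3, stride=4, kernel_size=7, pad=2)
print(down_embed(x).shape)  # torch.Size([1, 4096, 512])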
