Commit 6e97ec5

feat(modules): add metaformers, token-mixers

1 parent 23a3c59 commit 6e97ec5

5 files changed, +489 −0 lines changed

cellseg_models_pytorch/modules/__init__.py

Lines changed: 5 additions & 0 deletions

@@ -2,8 +2,10 @@
 from .base_modules import Activation, Conv, Identity, Norm, Up
 from .conv_block import ConvBlock
 from .conv_layer import ConvLayer
+from .metaformer import MetaFormer
 from .misc_modules import ChannelPool
 from .self_attention_modules import SelfAttention, SelfAttentionBlock
+from .token_mixers import TokenMixer, TokenMixerBlock
 from .transformers import Transformer2D, TransformerLayer

 __all__ = [
@@ -20,4 +22,7 @@
     "TransformerLayer",
     "Transformer2D",
     "ChannelPool",
+    "MetaFormer",
+    "TokenMixer",
+    "TokenMixerBlock",
 ]
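With these re-exports, the new blocks can be imported straight from cellseg_models_pytorch.modules. A minimal usage sketch (the kwargs mirror the "mscan" example in the new MetaFormer docstring below; the concrete channel counts and input size are illustrative assumptions, not fixed defaults):

import torch

from cellseg_models_pytorch.modules import MetaFormer

head_dim, num_heads = 64, 8
query_dim = head_dim * num_heads  # 512, the patch-embedding projection dim

block = MetaFormer(
    in_channels=3,
    out_channels=32,
    embed_kwargs={
        "in_channels": 3,
        "kernel_size": 7,
        "stride": 4,
        "pad": 2,
        "head_dim": head_dim,
        "num_heads": num_heads,
    },
    mixer_kwargs={
        "token_mixer": "mscan",
        "normalization": "bn",
        "norm_kwargs": {"num_features": query_dim},
        "mixer_kwargs": {"in_channels": query_dim},
    },
    mlp_kwargs={
        "in_channels": query_dim,
        "norm_kwargs": {"normalized_shape": query_dim},
    },
)

out = block(torch.rand(2, 3, 64, 64))
print(out.shape)  # expected: torch.Size([2, 32, 64, 64]) per the (B, out_channels, H, W) contract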
Lines changed: 218 additions & 0 deletions

@@ -0,0 +1,218 @@
from typing import Any, Dict

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import DropPath

from .base_modules import Identity
from .misc_modules import LayerScale
from .mlp import MlpBlock
from .patch_embeddings import ContiguousEmbed
from .token_mixers import RESHAPE_LOOKUP, TokenMixerBlock


class MetaFormer(nn.Module):
    def __init__(
        self,
        in_channels: int,
        embed_kwargs: Dict[str, Any],
        mixer_kwargs: Dict[str, Any],
        mlp_kwargs: Dict[str, Any],
        out_channels: int = None,
        layer_scale: bool = False,
        dropout: float = 0.0,
        **kwargs
    ) -> None:
        """Create a generic MetaFormer block with any available token-mixer.

        Input shape: (B, in_channels, H, W)
        Output shape: (B, out_channels, H, W)

        Parameters
        ----------
        in_channels : int
            Number of input channels.
        embed_kwargs : Dict[str, Any]
            Key-word arguments for the patch embedding block.
        mixer_kwargs : Dict[str, Any]
            Key-word arguments for the token-mixer block.
        mlp_kwargs : Dict[str, Any]
            Key-word arguments for the final Mlp-block.
        out_channels : int, optional
            Number of output channels. Defaults to `in_channels` if not given.
        layer_scale : bool, default=False
            Flag, whether to use layer-scaling.
        dropout : float, default=0.0
            Drop-path probability.

        Examples
        --------
        MetaFormer with exact memory-efficient self-attention:
            >>> import torch
            >>> import torch.nn as nn

            >>> in_channels = 3
            >>> head_dim = 64
            >>> num_heads = 8
            >>> query_dim = head_dim * num_heads

            >>> # patch embedding kwargs
            >>> embed_kwargs = {
                    "in_channels": 3,
                    "kernel_size": 7,
                    "stride": 4,
                    "pad": 2,
                    "head_dim": head_dim,
                    "num_heads": num_heads,
                }

            >>> # token-mixer kwargs
            >>> mixer_kwargs = {
                    "token_mixer": "self-attention",
                    "normalization": "ln",
                    "residual": True,
                    "norm_kwargs": {"normalized_shape": query_dim},
                    "mixer_kwargs": {
                        "query_dim": query_dim,
                        "name": "exact",
                        "how": "memeff",
                        "cross_attention_dim": None,
                    },
                }

            >>> # mlp kwargs
            >>> mlp_kwargs = {
                    "in_channels": query_dim,
                    "norm_kwargs": {"normalized_shape": query_dim},
                }

            >>> # init metaformer
            >>> metaformer = MetaFormer(
                    in_channels=in_channels,
                    embed_kwargs=embed_kwargs,
                    mixer_kwargs=mixer_kwargs,
                    mlp_kwargs=mlp_kwargs,
                    layer_scale=True,
                    dropout=0.1,
                )

            >>> x = torch.rand([8, 3, 256, 256])
            >>> print(metaformer(x).shape)
            >>> # torch.Size([8, 3, 256, 256])

        MetaFormer with multi-scale convolutional attention:
            >>> import torch
            >>> import torch.nn as nn

            >>> in_channels = 3
            >>> head_dim = 64
            >>> num_heads = 8
            >>> query_dim = head_dim * num_heads
            >>> out_channels = 128

            >>> # patch embedding kwargs
            >>> embed_kwargs = {
                    "in_channels": 3,
                    "kernel_size": 7,
                    "stride": 4,
                    "pad": 2,
                    "head_dim": head_dim,
                    "num_heads": num_heads,
                }

            >>> # token-mixer kwargs
            >>> mixer_kwargs = {
                    "token_mixer": "mscan",
                    "normalization": "bn",
                    "norm_kwargs": {"num_features": query_dim},
                    "mixer_kwargs": {"in_channels": query_dim},
                }

            >>> # mlp kwargs
            >>> mlp_kwargs = {
                    "in_channels": query_dim,
                    "norm_kwargs": {"normalized_shape": query_dim},
                }

            >>> # init metaformer
            >>> metaformer = MetaFormer(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    embed_kwargs=embed_kwargs,
                    mixer_kwargs=mixer_kwargs,
                    mlp_kwargs=mlp_kwargs,
                    layer_scale=True,
                    dropout=0.1,
                )

            >>> x = torch.rand([8, 3, 256, 256])
            >>> print(metaformer(x).shape)
            >>> # torch.Size([8, 128, 256, 256])
        """
        super().__init__()
        self.out_channels = out_channels if out_channels is not None else in_channels
        mixer_name = mixer_kwargs["token_mixer"]

        self.patch_embed = ContiguousEmbed(
            **embed_kwargs, flatten=not RESHAPE_LOOKUP[mixer_name]
        )
        self.proj_dim = self.patch_embed.proj_dim

        self.mixer = TokenMixerBlock(**mixer_kwargs)
        self.mlp = MlpBlock(**mlp_kwargs)
        self.ls1 = (
            LayerScale(dim=mlp_kwargs["in_channels"]) if layer_scale else Identity()
        )
        self.ls2 = (
            LayerScale(dim=mlp_kwargs["in_channels"]) if layer_scale else Identity()
        )
        self.drop_path1 = DropPath(dropout) if dropout else Identity()
        self.drop_path2 = DropPath(dropout) if dropout else Identity()

        self.proj_out = nn.Conv2d(
            self.proj_dim, self.out_channels, kernel_size=1, stride=1, padding=0
        )

        self.downsample = Identity()
        if out_channels is not None:
            self.downsample = nn.Conv2d(in_channels, out_channels, 1)

    def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
        """Forward pass of the MetaFormer block."""
        B, _, H, W = x.shape
        residual = self.downsample(x)

        # 1. embed and project
        x = self.patch_embed(x)

        # 2. token-mixing
        x = self.drop_path1(self.ls1(self.mixer(x, **kwargs)))

        # 3. mlp
        x = self.drop_path2(self.ls2(self.mlp(x)))

        # 4. reshape back to an image-like shape
        p_H = self.patch_embed.get_patch_size(H)
        p_W = self.patch_embed.get_patch_size(W)
        x = x.reshape(B, p_H, p_W, self.proj_dim).permute(0, 3, 1, 2)

        # Upsample back to the input dims if the patch grid is smaller than the
        # original input. Assumes a square input.
        # NOTE: the kernel_size, pad, & stride have to be set correctly for this to work
        if p_H < H:
            scale_factor = H // p_H
            x = F.interpolate(x, scale_factor=int(scale_factor), mode="bilinear")

        # 5. project to the output channels
        x = self.proj_out(x)

        # 6. residual connection
        return x + residual
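The reshape/upsample bookkeeping in steps 4-5 of `forward` can be illustrated standalone. A small sketch of the shape arithmetic, assuming (as in the docstring examples) a patch embedding with kernel_size=7, stride=4, pad=2 that downsamples the spatial dims by 4; the tensors below are stand-ins, not the library API:

import torch
import torch.nn.functional as F

# Shape bookkeeping behind MetaFormer.forward steps 4-5 (illustrative values).
B, proj_dim, H, W = 8, 512, 256, 256
stride = 4                                 # patch embedding: kernel_size=7, stride=4, pad=2
p_H = p_W = (H + 2 * 2 - 7) // stride + 1  # conv output size -> 64 x 64 patch grid

tokens = torch.rand(B, p_H * p_W, proj_dim)                    # (B, N, proj_dim) after mixer + mlp
x = tokens.reshape(B, p_H, p_W, proj_dim).permute(0, 3, 1, 2)  # (B, proj_dim, p_H, p_W)

# Bilinear upsample back to the input resolution before the 1x1 output projection.
x = F.interpolate(x, scale_factor=H // p_H, mode="bilinear")
print(x.shape)  # torch.Size([8, 512, 256, 256])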
Lines changed: 79 additions & 0 deletions

@@ -0,0 +1,79 @@
import pytest
import torch

from cellseg_models_pytorch.modules import MetaFormer


@pytest.mark.parametrize("type", ["pool", "mscan", "mlp", "self-attention"])
def test_metaformer(type):
    in_channels = 3
    head_dim = 64
    num_heads = 8
    query_dim = head_dim * num_heads
    out_channels = 16

    embed_kwargs = {
        "in_channels": 3,
        "kernel_size": 7,
        "stride": 4,
        "pad": 2,
        "head_dim": head_dim,
        "num_heads": num_heads,
    }

    if type == "self-attention":
        mixer_kwargs = {
            "token_mixer": "self-attention",
            "normalization": "ln",
            "residual": True,
            "norm_kwargs": {"normalized_shape": query_dim},
            "mixer_kwargs": {
                "query_dim": query_dim,
                "name": "exact",
                "how": "basic",
                "cross_attention_dim": None,
            },
        }
    elif type == "mlp":
        mixer_kwargs = {
            "token_mixer": "mlp",
            "normalization": "ln",
            "norm_kwargs": {"normalized_shape": query_dim},
            "mixer_kwargs": {
                "in_channels": query_dim,
            },
        }
    elif type in ("pool", "mscan"):
        mixer_kwargs = {
            "token_mixer": type,
            "normalization": "bn",
            "norm_kwargs": {
                "num_features": query_dim,
            },
            "mixer_kwargs": {
                "kernel_size": 3,
                "in_channels": query_dim,
            },
        }

    mlp_kwargs = {
        "in_channels": query_dim,
        "norm_kwargs": {"normalized_shape": query_dim},
    }

    metaformer = MetaFormer(
        in_channels=in_channels,
        out_channels=out_channels,
        embed_kwargs=embed_kwargs,
        mixer_kwargs=mixer_kwargs,
        mlp_kwargs=mlp_kwargs,
        layer_scale=True,
        dropout=0.1,
    )

    x = torch.rand([8, 3, 32, 32])
    dd = metaformer(x)

    assert dd.shape == torch.Size([8, out_channels, 32, 32])
