
Commit 2b40ec8

feat(modules): add upsampling for patch embeds in tformer
1 parent 5d712ff commit 2b40ec8

File tree

1 file changed: +57 -22 lines changed


cellseg_models_pytorch/modules/transformers.py

Lines changed: 57 additions & 22 deletions
@@ -1,12 +1,14 @@
-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple
 
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
 
-from cellseg_models_pytorch.modules import SelfAttentionBlock
-
+from .base_modules import Identity
+from .misc_modules import LayerScale
 from .mlp import MlpBlock
 from .patch_embeddings import ContiguousEmbed
+from .self_attention_modules import SelfAttentionBlock
 
 __all__ = ["Transformer2D", "TransformerLayer"]
 
@@ -23,10 +25,12 @@ def __init__(
         computation_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
+        layer_scales: Tuple[bool, ...] = (False, False),
         activation: str = "star_relu",
         num_groups: int = 32,
-        slice_size: int = 4,
-        mlp_ratio: int = 4,
+        mlp_ratio: int = 2,
+        slice_size: Optional[int] = 4,
+        patch_embed_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         """Create a transformer for 2D-image-like (B, C, H, W) inputs.
@@ -49,7 +53,7 @@ def __init__(
         n_blocks : int, default=2
             Number of Multihead attention blocks in the transformer.
         block_types : Tuple[str, ...], default=("exact", "exact")
-            The name of the SelfAttentionBlocks in the TransformerLayer.
+            The names/types of the SelfAttentionBlocks in the TransformerLayer.
             Length of the tuple has to equal `n_blocks`.
             Allowed names: ("exact", "linformer").
         computation_types : Tuple[str, ...], default=("basic", "basic")
@@ -60,18 +64,23 @@ def __init__(
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
             Include bias terms in the SelfAttention blocks.
+        layer_scales : bool, default=(False, False)
+            Learnable layer weights for the self-attention matrix.
         activation : str, default="star_relu"
             The activation function applied at the end of the transformer layer fc.
             One of ("geglu", "approximate_gelu", "star_relu").
         num_groups : int, default=32
             Number of groups in the first group-norm op before the input is
             projected to be suitable for self-attention.
-        slice_size : int, default=4
+        mlp_ratio : int, default=2
+            Scaling factor for the number of input features to get the number of
+            hidden features in the final `Mlp` layer of the transformer.
+        slice_size : int, optional, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        mlp_ratio : int, default=4
-            Multiplier that defines the out dimension of the final fc projection
-            layer.
+        patch_embed_kwargs: Dict[str, Any], optional
+            Extra key-word arguments for the patch embedding module. See the
+            `ContiguousEmbed` module for more info.
         """
         super().__init__()
         patch_norm = "gn" if in_channels >= 32 else None
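
To make the new `mlp_ratio` semantics concrete (based only on the docstring wording above): with 128 input features and `mlp_ratio=2`, the final `Mlp` layer would use about 2 * 128 = 256 hidden features. A minimal sketch of such a block, not the repo's actual `MlpBlock`:

    import torch.nn as nn

    def make_mlp(in_features: int, mlp_ratio: int = 2) -> nn.Sequential:
        # hidden width is a multiple of the input width, per the docstring
        hidden = mlp_ratio * in_features  # e.g. 2 * 128 = 256
        return nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.GELU(),  # stand-in; the repo offers "star_relu", "geglu", etc.
            nn.Linear(hidden, in_features),
        )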
@@ -82,6 +91,7 @@ def __init__(
             num_heads=num_heads,
             normalization=patch_norm,
             norm_kwargs={"num_features": in_channels, "num_groups": num_groups},
+            **patch_embed_kwargs if patch_embed_kwargs is not None else {},
         )
         self.proj_dim = self.patch_embed.proj_dim
 
@@ -95,6 +105,7 @@ def __init__(
             computation_types=computation_types,
             dropouts=dropouts,
             biases=biases,
+            layer_scales=layer_scales,
             activation=activation,
             slice_size=slice_size,
             mlp_ratio=mlp_ratio,
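
For reference, the `**patch_embed_kwargs if patch_embed_kwargs is not None else {}` line added to the `ContiguousEmbed(...)` call above relies on conditional `**`-unpacking. A tiny standalone sketch with a hypothetical stand-in function (not the real `ContiguousEmbed` signature):

    # The conditional expression binds before the unpacking, i.e. the call reads as
    # **(patch_embed_kwargs if patch_embed_kwargs is not None else {}).
    def fake_embed(dim, **kwargs):  # hypothetical stand-in, not ContiguousEmbed
        return {"dim": dim, **kwargs}

    patch_embed_kwargs = None
    fake_embed(256, **patch_embed_kwargs if patch_embed_kwargs is not None else {})
    # -> {"dim": 256}

    patch_embed_kwargs = {"kernel_size": 7, "stride": 4}  # hypothetical keys
    fake_embed(256, **patch_embed_kwargs if patch_embed_kwargs is not None else {})
    # -> {"dim": 256, "kernel_size": 7, "stride": 4}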
@@ -130,11 +141,22 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
         # 2. transformer
         x = self.transformer(x, context)
 
-        # 3. Reshape back to image-like shape and project to original input channels.
-        x = x.reshape(B, H, W, self.proj_dim).permute(0, 3, 1, 2)
+        # 3. Reshape back to image-like shape.
+        p_H = self.patch_embed.get_patch_size(H)
+        p_W = self.patch_embed.get_patch_size(W)
+        x = x.reshape(B, p_H, p_W, self.proj_dim).permute(0, 3, 1, 2)
+
+        # Upsample to input dims if patch size less than orig inp size
+        # assumes that the input is square mat.
+        # NOTE: the kernel_size, pad, & stride has to be set correctly for this to work
+        if p_H < H:
+            scale_factor = H // p_H
+            x = F.interpolate(x, scale_factor=scale_factor, mode="bilinear")
+
+        # 4. project to original input channels
         x = self.proj_out(x)
 
-        # 4. residual
+        # 5. residual
         return x + residual
 
 
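
A small self-contained sketch of the new reshape-and-upsample step, with made-up shapes; `p_H`/`p_W` stand in for the values returned by `self.patch_embed.get_patch_size(...)`, which is not shown in this diff:

    import torch
    import torch.nn.functional as F

    B, H, W, proj_dim = 2, 32, 32, 128
    p_H = p_W = 8  # hypothetical patch-grid size, e.g. from a stride-4 embedding
    tokens = torch.rand(B, p_H * p_W, proj_dim)  # transformer output (B, N, C)

    # tokens -> (B, C, p_H, p_W) image-like grid
    x = tokens.reshape(B, p_H, p_W, proj_dim).permute(0, 3, 1, 2)

    # upsample back to the original spatial size when the grid is smaller
    if p_H < H:
        x = F.interpolate(x, scale_factor=H // p_H, mode="bilinear")

    assert x.shape == (B, proj_dim, H, W)

As the NOTE in the diff says, the interpolation only recovers (H, W) exactly when the embedding's kernel_size, pad and stride produce a patch grid that divides the (square) input evenly.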

@@ -151,8 +173,9 @@ def __init__(
         computation_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
-        slice_size: int = 4,
-        mlp_ratio: int = 4,
+        layer_scales: Tuple[bool, ...] = (False, False),
+        mlp_ratio: int = 2,
+        slice_size: Optional[int] = 4,
         **kwargs,
     ) -> None:
         """Chain transformer blocks to compose a full generic transformer.
@@ -191,12 +214,14 @@ def __init__(
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
             Include bias terms in the SelfAttention blocks.
-        slice_size : int, default=4
+        layer_scales : bool, default=(False, False)
+            Learnable layer weights for the self-attention matrix.
+        mlp_ratio : int, default=2
+            Scaling factor for the number of input features to get the number of
+            hidden features in the final `Mlp` layer of the transformer.
+        slice_size : int, optional, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        mlp_proj : int, default=4
-            Multiplier that defines the out dimension of the final fc projection
-            layer.
         **kwargs:
             Arbitrary key-word arguments.
 
@@ -218,7 +243,9 @@ def __init__(
                 f"Illegal args: {illegal_args}"
             )
 
-        self.tr_blocks = nn.ModuleDict()
+        # self.tr_blocks = nn.ModuleDict()
+        self.tr_blocks = nn.ModuleList()
+        self.layer_scales = nn.ModuleList()
         blocks = list(range(n_blocks))
         for i in blocks:
             cross_dim = cross_attention_dim if i == blocks[-1] else None
@@ -235,7 +262,13 @@ def __init__(
                 slice_size=slice_size,
                 **kwargs,
             )
-            self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = att_block
+            self.tr_blocks.append(att_block)
+
+            # add layer scale. (Optional)
+            ls = LayerScale(query_dim) if layer_scales[i] else Identity()
+            self.layer_scales.append(ls)
+
+            # self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = tr_block
 
         self.mlp = MlpBlock(
             in_channels=query_dim,
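
The `LayerScale` module itself lives in `misc_modules` and is not part of this diff. As a rough orientation only, a CaiT-style layer scale is typically a learnable per-channel weight, roughly like the sketch below (an assumption, not the repository's implementation):

    import torch
    import torch.nn as nn

    class LayerScaleSketch(nn.Module):
        """Assumed behaviour: scale each feature channel by a learnable weight."""

        def __init__(self, dim: int, init_value: float = 1e-5) -> None:
            super().__init__()
            # one learnable weight per feature channel, initialised small
            self.gamma = nn.Parameter(init_value * torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, N, dim) token tensor
            return self.gamma * x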
@@ -263,12 +296,14 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
             Self-attended input tensor. Shape (B, H*W, query_dim).
         """
         n_blocks = len(self.tr_blocks)
-        for i, tr_block in enumerate(self.tr_blocks.values(), 1):
+
+        for i, (tr_block, ls) in enumerate(zip(self.tr_blocks, self.layer_scales), 1):
             # apply context only at the last transformer block
             con = None
             if i == n_blocks:
                 con = context
 
             x = tr_block(x, con)
+            x = ls(x)
 
         return self.mlp(x) + x
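
To illustrate how the block/scale pairing and the last-block-only context behave, here is a runnable sketch with stand-in modules (a hypothetical `FakeBlock`, not the real `SelfAttentionBlock`):

    import torch
    import torch.nn as nn

    class FakeBlock(nn.Module):  # hypothetical stand-in for a SelfAttentionBlock
        def forward(self, x, context=None):
            return x if context is None else x + context.mean()

    blocks = nn.ModuleList([FakeBlock(), FakeBlock()])
    scales = nn.ModuleList([nn.Identity(), nn.Identity()])  # Identity when layer_scales[i] is False

    x = torch.rand(2, 64, 128)
    context = torch.rand(2, 64, 128)
    n_blocks = len(blocks)
    for i, (blk, ls) in enumerate(zip(blocks, scales), 1):
        con = context if i == n_blocks else None  # context only at the last block
        x = ls(blk(x, con))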
