
Commit cacaabe

fix(modules): adjust transformer layers
1 parent 5cbf18e commit cacaabe

5 files changed (+36 -56 lines)

cellseg_models_pytorch/decoders/decoder_stage.py (5 additions, 5 deletions)

@@ -328,9 +328,9 @@ def forward(
             Output torch.Tensor and extra skip torch.Tensors. If no extra
             skips are present, returns None as the second return value.
         """
-        x = self.upsample(x)
+        x = self.upsample(x)  # (B, in_channels, H, W)

-        # long skip
+        # long skip (B, in_channels(+skip_channels), H, W)
         x = self.skip(x, ix=self.stage_ix, skips=skips, extra_skips=extra_skips)

         # unetpp returns extra skips
@@ -340,15 +340,15 @@ def forward(
         # conv layers
         if self.n_layers is not None:
             for conv_layer in self.conv_layers.values():
-                x = conv_layer(x, style)
+                x = conv_layer(x, style)  # (B, out_channels, H, W)

         # transformer layers
         if self.n_transformers is not None:
             for transformer in self.transformers.values():
-                x = transformer(x)
+                x = transformer(x)  # (B, long_skip_channels/out_channels, H, W)

         # channel pool if conv-layers are skipped.
         if self.n_layers is None:
-            x = self.ch_pool(x)
+            x = self.ch_pool(x)  # (B, out_channels, H, W)

         return x, extra_skips
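
For orientation, here is a minimal shape-flow sketch of the path these new comments annotate: upsample, concatenate the long skip, then convolve down to out_channels. It uses plain torch stand-ins (Upsample, cat, Conv2d) rather than the library's upsample/skip/conv modules, and the channel sizes are made up for illustration.

# Minimal shape-flow sketch with illustrative torch stand-ins, not DecoderStage itself.
import torch
import torch.nn as nn

B, in_channels, skip_channels, out_channels = 1, 64, 32, 48
x = torch.randn(B, in_channels, 16, 16)
skip = torch.randn(B, skip_channels, 32, 32)

x = nn.Upsample(scale_factor=2)(x)                # (B, in_channels, H, W)
x = torch.cat([x, skip], dim=1)                   # long skip: (B, in_channels + skip_channels, H, W)
x = nn.Conv2d(in_channels + skip_channels, out_channels, 3, padding=1)(x)  # (B, out_channels, H, W)
print(x.shape)  # torch.Size([1, 48, 32, 32])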

cellseg_models_pytorch/modules/base_modules.py (2 additions, 31 deletions)

@@ -1,12 +1,12 @@
 import torch
 import torch.nn as nn

-from .act import ACT_LOOKUP, TR_ACT_LOOKUP
+from .act import ACT_LOOKUP
 from .conv import CONV_LOOKUP
 from .norm import NORM_LOOKUP
 from .upsample import UP_LOOKUP

-__all__ = ["Activation", "Norm", "Up", "Conv", "Identity", "TransformerAct"]
+__all__ = ["Activation", "Norm", "Up", "Conv", "Identity"]


 class Identity(nn.Module):
@@ -151,32 +151,3 @@ def __init__(self, name: str, **kwargs) -> None:
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass for the convolution function."""
         return self.conv(x)
-
-
-class TransformerAct(nn.Module):
-    def __init__(self, name: str, **kwargs) -> None:
-        """Activation function for transformer outputs wrapper class.
-
-        Parameters:
-        -----------
-            name : str
-                Name of the transformer activation method.
-
-        Raises
-        ------
-            ValueError: if the transformer activation name is illegal.
-        """
-        super().__init__()
-
-        allowed = list(TR_ACT_LOOKUP.keys())
-        if name not in allowed:
-            raise ValueError(
-                "Illegal transformer activation method given. "
-                f"Allowed: {allowed}. Got: '{name}'"
-            )
-
-        self.tr_act = TR_ACT_LOOKUP[name](**kwargs)
-
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass for the convolution function."""
-        return self.tr_act(x)
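
For reference, the wrappers that remain exported (Activation, Norm, Up, Conv) follow the same name-lookup pattern the removed TransformerAct used. Below is a self-contained sketch of that pattern; the two-entry table and class name are illustrative stand-ins, not the library's ACT_LOOKUP or Activation class.

import torch
import torch.nn as nn

ACT_LOOKUP_SKETCH = {"relu": nn.ReLU, "gelu": nn.GELU}  # illustrative stand-in for ACT_LOOKUP


class ActivationSketch(nn.Module):
    def __init__(self, name: str, **kwargs) -> None:
        """Resolve an activation by name, mirroring the lookup-wrapper pattern."""
        super().__init__()
        allowed = list(ACT_LOOKUP_SKETCH.keys())
        if name not in allowed:
            raise ValueError(f"Illegal activation method given. Allowed: {allowed}. Got: '{name}'")
        self.act = ACT_LOOKUP_SKETCH[name](**kwargs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(x)


print(ActivationSketch("gelu")(torch.randn(2, 4)).shape)  # torch.Size([2, 4])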

cellseg_models_pytorch/modules/norm/__init__.py (4 additions, 2 deletions)

@@ -1,12 +1,13 @@
-from torch.nn import BatchNorm2d, InstanceNorm2d, SyncBatchNorm
+from torch.nn import BatchNorm2d, InstanceNorm2d, LayerNorm, SyncBatchNorm

 from .bcn import BCNorm
 from .gn import GroupNorm
 from .ln import LayerNorm2d

 NORM_LOOKUP = {
     "bn": BatchNorm2d,
-    "ln": LayerNorm2d,
+    "ln2d": LayerNorm2d,
+    "ln": LayerNorm,
     "bcn": BCNorm,
     "gn": GroupNorm,
     "in": InstanceNorm2d,
@@ -20,4 +21,5 @@
     "InstanceNorm2d",
     "SyncBatchNorm",
     "LayerNorm2d",
+    "LayerNorm",
 ]
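
The practical difference between the two keys: "ln" now resolves to torch.nn.LayerNorm, which normalizes channels-last tensors such as token sequences of shape (B, N, C), while "ln2d" keeps the channels-first variant for (B, C, H, W) feature maps. A small sketch follows; the LayerNorm2dSketch below is a permute-based stand-in, not the class imported from .ln.

import torch
import torch.nn as nn


class LayerNorm2dSketch(nn.Module):
    """Channels-first layer norm: permute to channels-last, normalize over C, permute back."""

    def __init__(self, num_features: int) -> None:
        super().__init__()
        self.ln = nn.LayerNorm(num_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.ln(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)


tokens = torch.randn(2, 196, 64)          # "ln": token sequence (B, N, C)
fmap = torch.randn(2, 64, 14, 14)         # "ln2d": feature map (B, C, H, W)
print(nn.LayerNorm(64)(tokens).shape)     # torch.Size([2, 196, 64])
print(LayerNorm2dSketch(64)(fmap).shape)  # torch.Size([2, 64, 14, 14])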

cellseg_models_pytorch/modules/self_attention_modules.py (2 additions, 0 deletions)

@@ -19,6 +19,7 @@ def __init__(
         dropout: float = 0.0,
         bias: bool = False,
         slice_size: int = 4,
+        **kwargs,
     ) -> None:
         """Compute self-attention.

@@ -52,6 +53,7 @@
             `self_attention = "slice"`.
         """
         super().__init__()
+        self.out_channels = query_dim
         proj_channels = head_dim * num_heads

         # cross attention dim
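
A small sketch of what the two added lines buy: **kwargs lets one shared keyword dict be forwarded to different block types without a TypeError on keys a given block ignores, and out_channels exposes the block's output width to downstream code. The class and keyword names below are toy stand-ins, not SelfAttentionBlock's real signature.

class AttentionSketch:
    def __init__(self, query_dim: int, num_heads: int = 8, **kwargs) -> None:
        # extra keys (e.g. an mlp_ratio meant for another module) are silently ignored
        self.out_channels = query_dim  # downstream stages can read the output width


shared_kwargs = {"query_dim": 512, "num_heads": 8, "mlp_ratio": 4}
block = AttentionSketch(**shared_kwargs)  # no TypeError despite the unused mlp_ratio
print(block.out_channels)  # 512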

cellseg_models_pytorch/modules/transformers.py (23 additions, 18 deletions)

@@ -5,8 +5,8 @@

 from cellseg_models_pytorch.modules import SelfAttentionBlock

-from .base_modules import TransformerAct
-from .misc_modules import Proj2Attention
+from .mlp import MlpBlock
+from .patch_embeddings import ContiguousEmbed

 __all__ = ["Transformer2D", "TransformerLayer"]

@@ -25,7 +25,7 @@ def __init__(
         act: str = "geglu",
         num_groups: int = 32,
         slice_size: int = 4,
-        fc_projection_mult: int = 4,
+        mlp_ratio: int = 4,
         **kwargs,
     ) -> None:
         """Create a transformer for 2D-image-like (B, C, H, W) inputs.
@@ -69,15 +69,19 @@ def __init__(
             layer.
         """
         super().__init__()
-        self.proj_in = Proj2Attention(
+        patch_norm = "gn" if in_channels >= 32 else None
+        self.patch_embed = ContiguousEmbed(
             in_channels=in_channels,
-            num_groups=num_groups,
+            patch_size=1,
             head_dim=head_dim,
             num_heads=num_heads,
+            normalization=patch_norm,
+            norm_kwargs={"num_features": in_channels, "num_groups": num_groups},
         )
+        self.proj_dim = self.patch_embed.proj_dim

         self.transformer = TransformerLayer(
-            query_dim=self.proj_in.proj_dim,
+            query_dim=self.proj_dim,
             num_heads=num_heads,
             head_dim=head_dim,
             cross_attention_dim=cross_attention_dim,
@@ -87,11 +91,11 @@ def __init__(
             biases=biases,
             act=act,
             slice_size=slice_size,
-            fc_projection_mult=fc_projection_mult,
+            mlp_ratio=mlp_ratio,
         )

         self.proj_out = nn.Conv2d(
-            self.proj_in.proj_dim, in_channels, kernel_size=1, stride=1, padding=0
+            self.proj_dim, in_channels, kernel_size=1, stride=1, padding=0
         )

     def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
@@ -114,13 +118,13 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
         residual = x

         # 1. project
-        x = self.proj_in(x)
+        x = self.patch_embed(x)

         # 2. transformer
         x = self.transformer(x, context)

         # 3. Reshape back to image-like shape and project to original input channels.
-        x = x.reshape(B, H, W, self.proj_in.proj_dim).permute(0, 3, 1, 2)
+        x = x.reshape(B, H, W, self.proj_dim).permute(0, 3, 1, 2)
         x = self.proj_out(x)

         # 4. residual
@@ -134,13 +138,13 @@ def __init__(
         num_heads: int = 8,
         head_dim: int = 64,
         cross_attention_dim: int = None,
-        act: str = "geglu",
+        activation: str = "star_relu",
         n_blocks: int = 2,
         block_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
         slice_size: int = 4,
-        fc_projection_mult: int = 4,
+        mlp_ratio: int = 4,
         **kwargs,
     ) -> None:
         """Chain transformer blocks to compose a full generic transformer.
@@ -217,11 +221,12 @@ def __init__(
             )
             self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = att_block

-        proj_dim = int(query_dim * fc_projection_mult)
-        self.fc = nn.Sequential(
-            nn.LayerNorm(query_dim),
-            TransformerAct(act, dim_in=query_dim, dim_out=proj_dim),
-            nn.Linear(proj_dim, query_dim),
+        self.mlp = MlpBlock(
+            in_channels=query_dim,
+            mlp_ratio=mlp_ratio,
+            activation=activation,
+            normalization="ln",
+            norm_kwargs={"normalized_shape": query_dim},
         )

     def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
@@ -249,4 +254,4 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:

             x = tr_block(x, con)

-        return self.fc(x) + x
+        return self.mlp(x) + x
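
Putting the transformers.py changes together, here is a self-contained sketch of the Transformer2D forward flow after this commit: embed the (B, C, H, W) input into tokens, run an attention block and an MLP block, reshape back to image shape, project to the input channels with a 1x1 conv, and add the residual. The attention and MLP bodies are plain torch stand-ins for SelfAttentionBlock and MlpBlock, and the default sizes are illustrative.

import torch
import torch.nn as nn


class Transformer2DSketch(nn.Module):
    def __init__(self, in_channels: int, num_heads: int = 8, head_dim: int = 64, mlp_ratio: int = 4) -> None:
        super().__init__()
        proj_dim = num_heads * head_dim
        self.proj_dim = proj_dim
        self.patch_embed = nn.Conv2d(in_channels, proj_dim, kernel_size=1)  # stand-in for ContiguousEmbed
        self.attn = nn.MultiheadAttention(proj_dim, num_heads, batch_first=True)  # stand-in for SelfAttentionBlock
        self.mlp = nn.Sequential(  # stand-in for MlpBlock with normalization="ln"
            nn.LayerNorm(proj_dim),
            nn.Linear(proj_dim, proj_dim * mlp_ratio),
            nn.GELU(),
            nn.Linear(proj_dim * mlp_ratio, proj_dim),
        )
        self.proj_out = nn.Conv2d(proj_dim, in_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, H, W = x.shape
        residual = x
        # 1. project to a token sequence (B, H*W, proj_dim)
        x = self.patch_embed(x).flatten(2).transpose(1, 2)
        # 2. transformer block with its own residual connections
        x = self.attn(x, x, x, need_weights=False)[0] + x
        x = self.mlp(x) + x
        # 3. reshape back to image-like shape and project to the input channels
        x = x.reshape(B, H, W, self.proj_dim).permute(0, 3, 1, 2)
        x = self.proj_out(x)
        # 4. residual
        return x + residual


print(Transformer2DSketch(64)(torch.randn(1, 64, 16, 16)).shape)  # torch.Size([1, 64, 16, 16])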
