-from typing import Tuple
+from typing import Any, Dict, Optional, Tuple

 import torch
 import torch.nn as nn
+import torch.nn.functional as F

-from cellseg_models_pytorch.modules import SelfAttentionBlock
-
+from .base_modules import Identity
+from .misc_modules import LayerScale
 from .mlp import MlpBlock
 from .patch_embeddings import ContiguousEmbed
+from .self_attention_modules import SelfAttentionBlock

 __all__ = ["Transformer2D", "TransformerLayer"]

@@ -23,10 +25,12 @@ def __init__(
         computation_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
+        layer_scales: Tuple[bool, ...] = (False, False),
         activation: str = "star_relu",
         num_groups: int = 32,
-        slice_size: int = 4,
-        mlp_ratio: int = 4,
+        mlp_ratio: int = 2,
+        slice_size: Optional[int] = 4,
+        patch_embed_kwargs: Optional[Dict[str, Any]] = None,
         **kwargs,
     ) -> None:
         """Create a transformer for 2D-image-like (B, C, H, W) inputs.
@@ -49,7 +53,7 @@ def __init__(
         n_blocks : int, default=2
             Number of Multihead attention blocks in the transformer.
         block_types : Tuple[str, ...], default=("exact", "exact")
-            The name of the SelfAttentionBlocks in the TransformerLayer.
+            The names/types of the SelfAttentionBlocks in the TransformerLayer.
             Length of the tuple has to equal `n_blocks`.
             Allowed names: ("exact", "linformer").
         computation_types : Tuple[str, ...], default=("basic", "basic")
@@ -60,18 +64,23 @@ def __init__(
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
             Include bias terms in the SelfAttention blocks.
+        layer_scales : bool, default=(False, False)
+            Learnable layer weights for the self-attention matrix.
         activation : str, default="star_relu"
             The activation function applied at the end of the transformer layer fc.
             One of ("geglu", "approximate_gelu", "star_relu").
         num_groups : int, default=32
             Number of groups in the first group-norm op before the input is
             projected to be suitable for self-attention.
-        slice_size : int, default=4
+        mlp_ratio : int, default=2
+            Scaling factor for the number of input features to get the number of
+            hidden features in the final `Mlp` layer of the transformer.
+        slice_size : int, optional, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        mlp_ratio : int, default=4
-            Multiplier that defines the out dimension of the final fc projection
-            layer.
+        patch_embed_kwargs : Dict[str, Any], optional
+            Extra key-word arguments for the patch embedding module. See the
+            `ContiguousEmbed` module for more info.
         """
         super().__init__()
         patch_norm = "gn" if in_channels >= 32 else None
@@ -82,6 +91,7 @@ def __init__(
             num_heads=num_heads,
             normalization=patch_norm,
             norm_kwargs={"num_features": in_channels, "num_groups": num_groups},
+            **patch_embed_kwargs if patch_embed_kwargs is not None else {},
         )
         self.proj_dim = self.patch_embed.proj_dim

@@ -95,6 +105,7 @@ def __init__(
             computation_types=computation_types,
             dropouts=dropouts,
             biases=biases,
+            layer_scales=layer_scales,
             activation=activation,
             slice_size=slice_size,
             mlp_ratio=mlp_ratio,
@@ -130,11 +141,22 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
         # 2. transformer
         x = self.transformer(x, context)

-        # 3. Reshape back to image-like shape and project to original input channels.
-        x = x.reshape(B, H, W, self.proj_dim).permute(0, 3, 1, 2)
+        # 3. Reshape back to image-like shape.
+        p_H = self.patch_embed.get_patch_size(H)
+        p_W = self.patch_embed.get_patch_size(W)
+        x = x.reshape(B, p_H, p_W, self.proj_dim).permute(0, 3, 1, 2)
+
+        # Upsample to the input dims if the patch grid is smaller than the original
+        # input size. Assumes a square input.
+        # NOTE: the kernel_size, pad & stride have to be set correctly for this to work.
+        if p_H < H:
+            scale_factor = H // p_H
+            x = F.interpolate(x, scale_factor=scale_factor, mode="bilinear")
+
+        # 4. project to original input channels
         x = self.proj_out(x)

-        # 4. residual
+        # 5. residual
         return x + residual

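As a standalone illustration of the reshape-and-upsample step in the hunk above, here is a minimal sketch. The concrete sizes and the assumed patch-grid size are hypothetical (the real values come from `ContiguousEmbed.get_patch_size`); only `torch` is required.

    import torch
    import torch.nn.functional as F

    B, proj_dim, H, W = 1, 64, 32, 32   # hypothetical image-like dims
    p_H = p_W = 16                      # assumed patch-grid size (stand-in for get_patch_size)

    # tokens coming out of the transformer: (B, p_H*p_W, proj_dim)
    tokens = torch.rand(B, p_H * p_W, proj_dim)

    # reshape back to an image-like layout: (B, proj_dim, p_H, p_W)
    x = tokens.reshape(B, p_H, p_W, proj_dim).permute(0, 3, 1, 2)

    # bilinearly upsample back to the original spatial size when the patch grid is smaller
    if p_H < H:
        x = F.interpolate(x, scale_factor=H // p_H, mode="bilinear")

    assert x.shape == (B, proj_dim, H, W)
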
@@ -151,8 +173,9 @@ def __init__(
         computation_types: Tuple[str, ...] = ("basic", "basic"),
         dropouts: Tuple[float, ...] = (0.0, 0.0),
         biases: Tuple[bool, ...] = (False, False),
-        slice_size: int = 4,
-        mlp_ratio: int = 4,
+        layer_scales: Tuple[bool, ...] = (False, False),
+        mlp_ratio: int = 2,
+        slice_size: Optional[int] = 4,
         **kwargs,
     ) -> None:
         """Chain transformer blocks to compose a full generic transformer.
@@ -191,12 +214,14 @@ def __init__(
             Dropout probabilities for the SelfAttention blocks.
         biases : bool, default=(True, True)
             Include bias terms in the SelfAttention blocks.
-        slice_size : int, default=4
+        layer_scales : bool, default=(False, False)
+            Learnable layer weights for the self-attention matrix.
+        mlp_ratio : int, default=2
+            Scaling factor for the number of input features to get the number of
+            hidden features in the final `Mlp` layer of the transformer.
+        slice_size : int, optional, default=4
             Slice size for sliced self-attention. This is used only if
             `name = "slice"` for a SelfAttentionBlock.
-        mlp_proj : int, default=4
-            Multiplier that defines the out dimension of the final fc projection
-            layer.
         **kwargs:
             Arbitrary key-word arguments.

@@ -218,7 +243,9 @@ def __init__(
                 f"Illegal args: {illegal_args}"
             )

-        self.tr_blocks = nn.ModuleDict()
+        # self.tr_blocks = nn.ModuleDict()
+        self.tr_blocks = nn.ModuleList()
+        self.layer_scales = nn.ModuleList()
         blocks = list(range(n_blocks))
         for i in blocks:
             cross_dim = cross_attention_dim if i == blocks[-1] else None
@@ -235,7 +262,13 @@ def __init__(
                 slice_size=slice_size,
                 **kwargs,
             )
-            self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = att_block
+            self.tr_blocks.append(att_block)
+
+            # add layer scale. (Optional)
+            ls = LayerScale(query_dim) if layer_scales[i] else Identity()
+            self.layer_scales.append(ls)
+
+            # self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = tr_block

         self.mlp = MlpBlock(
             in_channels=query_dim,
@@ -263,12 +296,14 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
             Self-attended input tensor. Shape (B, H*W, query_dim).
         """
         n_blocks = len(self.tr_blocks)
-        for i, tr_block in enumerate(self.tr_blocks.values(), 1):
+
+        for i, (tr_block, ls) in enumerate(zip(self.tr_blocks, self.layer_scales), 1):
             # apply context only at the last transformer block
             con = None
             if i == n_blocks:
                 con = context

             x = tr_block(x, con)
+            x = ls(x)

         return self.mlp(x) + x
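The per-block LayerScale wiring added to `TransformerLayer` above boils down to the pattern sketched below. This is a minimal, self-contained illustration: `TinyLayerScale` and the `nn.Linear` blocks are stand-ins, while the real `LayerScale` and `Identity` modules come from `.misc_modules` and `.base_modules` and the real blocks are `SelfAttentionBlock`s.

    import torch
    import torch.nn as nn

    class TinyLayerScale(nn.Module):
        """Stand-in for LayerScale: learnable per-channel scaling of a block's output."""

        def __init__(self, dim: int, init_value: float = 1e-5) -> None:
            super().__init__()
            self.gamma = nn.Parameter(init_value * torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return x * self.gamma

    query_dim = 64
    layer_scales = (True, False)          # one flag per transformer block

    # mirror the ModuleList bookkeeping from __init__ above
    blocks = nn.ModuleList([nn.Linear(query_dim, query_dim) for _ in layer_scales])
    scales = nn.ModuleList(
        [TinyLayerScale(query_dim) if ls else nn.Identity() for ls in layer_scales]
    )

    x = torch.rand(1, 16, query_dim)      # (B, H*W, query_dim)
    for block, ls in zip(blocks, scales):
        x = ls(block(x))                  # scale each block's output, as in forward()
    print(x.shape)                        # torch.Size([1, 16, 64])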