@@ -17,8 +17,8 @@ def __init__(
         dec_dims: Tuple[int, ...],
         skip_channels: Tuple[int, ...],
         style_channels: int = None,
-        n_layers: int = 1,
-        n_blocks: Tuple[int, ...] = (2,),
+        n_conv_layers: int = 1,
+        n_conv_blocks: Tuple[int, ...] = (2,),
         short_skips: Tuple[str, ...] = ("residual",),
         expand_ratios: Tuple[float, float] = ((1.0, 1.0),),
         block_types: Tuple[Tuple[str, ...], ...] = (("basic", "basic"),),
@@ -32,14 +32,15 @@ def __init__(
         kernel_sizes: Tuple[Tuple[int, ...]] = ((3, 3),),
         groups: Tuple[Tuple[int, ...]] = ((1, 1),),
         biases: Tuple[Tuple[bool, ...]] = ((False, False),),
+        layer_residual: bool = False,
         upsampling: str = "fixed-unpool",
         long_skip: str = "unet",
         merge_policy: str = "sum",
-        layer_residual: bool = False,
         skip_params: Optional[Dict[str, Any]] = None,
         n_transformers: Optional[int] = None,
         n_transformer_blocks: Optional[Tuple[int, ...]] = (1,),
-        self_attentions: Optional[Tuple[Tuple[str, ...], ...]] = (("basic",),),
+        transformer_blocks: Optional[Tuple[Tuple[str, ...], ...]] = (("exact",),),
+        transformer_computations: Optional[Tuple[Tuple[str, ...], ...]] = (("basic",),),
         transformer_biases: Optional[Tuple[Tuple[bool, ...], ...]] = ((False,),),
         transformer_dropouts: Optional[Tuple[Tuple[float, ...], ...]] = ((0.0,),),
         transformer_params: Optional[List[Dict[str, Any]]] = None,
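
For callers of this stage class, the change is a keyword rename plus one new argument. A minimal before/after sketch (the class name `DecoderStage` and the elided arguments are placeholders for illustration, not taken from this diff):

# before this change
stage = DecoderStage(..., n_layers=1, n_blocks=(2,), self_attentions=(("basic",),))

# after this change
stage = DecoderStage(
    ...,
    n_conv_layers=1,
    n_conv_blocks=(2,),
    transformer_blocks=(("exact",),),        # names of the self-attention blocks
    transformer_computations=(("basic",),),  # how each block computes its attention matrix
)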
@@ -68,48 +69,59 @@ def __init__(
             `long_skip` == None.
         style_channels : int, default=None
             Number of style vector channels. If None, style vectors are ignored.
-            Also, ignored if `n_layers` is None.
-        n_layers : int, default=1
+            Also, ignored if `n_conv_layers` is None.
+        n_conv_layers : int, default=1
             The number of conv layers inside one decoder stage.
-        n_blocks : Tuple[int, ...], default=(2,)
+        n_conv_blocks : Tuple[int, ...], default=(2,)
             Number of conv-blocks inside each conv layer. The tuple-length has to
-            match `n_layers`. Ignored if `n_layers` is None.
+            match `n_conv_layers`. Ignored if `n_conv_layers` is None.
         short_skips : str, default=("residual", )
             The short skip methods used inside the conv layers. Ignored if
-            `n_layers` is None.
+            `n_conv_layers` is None.
         expand_ratios : Tuple[float, ...], default=((1.0, 1.0),):
             Expansion/Squeeze ratios for the out channels of each conv block.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         block_types : Tuple[Tuple[str, ...], ...], default=(("basic", "basic"), )
             The type of the convolution blocks in the conv blocks inside the layers.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         normalizations : Tuple[Tuple[str, ...], ...], default: (("bn", "bn"), )
             Normalization methods used in the conv blocks inside the conv layers.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         activations : Tuple[Tuple[str, ...], ...], default: (("relu", "relu"), )
             Activation methods used inside the conv layers.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         attentions : Tuple[Tuple[str, ...], ...], default: ((None, "se"), )
             Attention methods used inside the conv layers.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         preactivates Tuple[Tuple[bool, ...], ...], default: ((False, False), )
             Boolean flags for the conv layers to use pre-activation.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         preattends Tuple[Tuple[bool, ...], ...], default: ((False, False), )
             Boolean flags for the conv layers to use pre-activation.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         use_styles : Tuple[Tuple[bool, ...], ...], default=((False, False), )
             Boolean flags for the conv layers to add style vectors at each block.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         kernel_sizes : Tuple[int, ...], default=((3, 3),)
             The size of the convolution kernels in each conv block.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         groups : int, default=((1, 1),)
             Number of groups for the kernels in each convolution blocks.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         biases : bool, default=((False, False),)
             Include bias terms in the convolution blocks.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         upsampling : str, default="fixed-unpool"
             Name of the upsampling method.
         long_skip : str, default="unet"
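
To make the tuple-length rule concrete, here is a hedged sketch of a consistent conv-layer configuration for `n_conv_layers=2`; the values are taken from the defaults documented above, and only the keyword names shown in this diff are assumed:

# every per-layer tuple has length 2, matching n_conv_layers
conv_kwargs = dict(
    n_conv_layers=2,
    n_conv_blocks=(2, 2),                  # two conv-blocks inside each of the two layers
    short_skips=("residual", "residual"),
    block_types=(("basic", "basic"), ("basic", "basic")),
    normalizations=(("bn", "bn"), ("bn", "bn")),
    activations=(("relu", "relu"), ("relu", "relu")),
    attentions=((None, "se"), (None, "se")),
    kernel_sizes=((3, 3), (3, 3)),
    groups=((1, 1), (1, 1)),
    biases=((False, False), (False, False)),
)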
@@ -128,10 +140,16 @@ def __init__(
         n_transformer_blocks : int, default=(2, ), optional
             Number of multi-head self attention blocks used in the transformer
             layers. Ignored if `n_transformers` is None.
-        self_attentions : Tuple[Tuple[str, ...], ...], default=(("basic",),)
-            The self-attention mechanisms used in the transformer layers.
+        transformer_blocks : Tuple[Tuple[str, ...], ...], default=(("exact",),)
+            The names of the SelfAttentionBlocks in the TransformerLayer(s).
             Allowed values: "basic", "slice", "flash". Ignored if `n_transformers`
-            is None.
+            is None. Length of the tuple has to equal `n_transformer_blocks`.
+            Allowed names: ("exact", "linformer").
+        transformer_computations : Tuple[Tuple[str, ...], ...], default=(("basic",),)
+            The way of computing the attention matrices in the SelfAttentionBlocks
+            in the TransformerLayer(s). Length of the tuple has to equal
+            `n_transformer_blocks`. Allowed styles: "basic", "slice", "flash",
+            "memeff", "slice-memeff".
         transformer_biases : Tuple[Tuple[bool, ...], ...], default=((False,),)
             Flags, whether to use biases in the transformer layers. Ignored if
             `n_transformers` is None.
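
A hedged sketch of how the two new transformer arguments pair up, for one transformer layer containing two self-attention blocks (all names are drawn from the allowed values listed above):

# one transformer layer, two self-attention blocks inside it
transformer_kwargs = dict(
    n_transformers=1,
    n_transformer_blocks=(2,),
    transformer_blocks=(("exact", "linformer"),),      # block name per attention block
    transformer_computations=(("basic", "memeff"),),   # attention-matrix computation per block
    transformer_biases=((False, False),),
    transformer_dropouts=((0.0, 0.0),),
)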
@@ -146,13 +164,13 @@ def __init__(
         Raises
         ------
         ValueError:
-            If lengths of the conv layer tuple args are not equal to `n_layers`.
+            If lengths of the conv layer tuple args are not equal to `n_conv_layers`
             If lengths of the transformer layer tuple args are not equal to
             `n_transformers`.
         """
         super().__init__()

-        self.n_layers = n_layers
+        self.n_conv_layers = n_conv_layers
         self.n_transformers = n_transformers
         self.long_skip = long_skip
         self.stage_ix = stage_ix
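
A minimal sketch of the length check that the ValueError above describes; this only illustrates the documented rule and is not the actual `_check_tuple_args` implementation:

def check_conv_tuple_lengths(n_conv_layers: int, **tuple_args) -> None:
    # every per-layer tuple argument needs one entry per conv layer
    for name, arg in tuple_args.items():
        if len(arg) != n_conv_layers:
            raise ValueError(
                f"`{name}` has length {len(arg)}, expected `n_conv_layers`={n_conv_layers}."
            )

check_conv_tuple_lengths(2, n_conv_blocks=(2, 2), short_skips=("residual",))  # raises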
@@ -176,19 +194,20 @@ def __init__(

         # Set up n layers of conv blocks
         layer = None  # placeholder
-        if n_layers is not None:
+        if n_conv_layers is not None:

             # check that the conv-layer tuple-args are not illegal.
             self._check_tuple_args(
                 "conv-layer related",
-                "n_layers",
-                n_layers,
+                "n_conv_layers",
+                n_conv_layers,
                 all_args=locals(),
                 skip_args=(
                     skip_channels,
                     dec_channels,
                     dec_dims,
-                    self_attentions,
+                    transformer_blocks,
+                    transformer_computations,
                     n_transformer_blocks,
                     transformer_biases,
                     transformer_dropouts,
@@ -197,12 +216,12 @@ def __init__(

             # set up the conv-layers.
             self.conv_layers = nn.ModuleDict()
-            for i in range(n_layers):
+            for i in range(n_conv_layers):
                 n_in_feats = self.skip.out_channels if i == 0 else layer.out_channels
                 layer = ConvLayer(
                     in_channels=n_in_feats,
                     out_channels=self.out_channels,
-                    n_blocks=n_blocks[i],
+                    n_blocks=n_conv_blocks[i],
                     layer_residual=layer_residual,
                     style_channels=style_channels,
                     short_skip=short_skips[i],
@@ -225,7 +244,9 @@ def __init__(
             self.out_channels = layer.out_channels

         # set in_channels for final operations
-        in_channels = self.skip.out_channels if n_layers is None else self.out_channels
+        in_channels = (
+            self.skip.out_channels if n_conv_layers is None else self.out_channels
+        )

         if n_transformers is not None:

@@ -239,7 +260,7 @@ def __init__(
                     skip_channels,
                     dec_channels,
                     dec_dims,
-                    n_blocks,
+                    n_conv_blocks,
                     short_skips,
                     expand_ratios,
                     block_types,
@@ -262,7 +283,8 @@ def __init__(
                 tr = Transformer2D(
                     in_channels=in_channels,
                     n_blocks=n_transformer_blocks[i],
-                    block_types=self_attentions[i],
+                    block_types=transformer_blocks[i],
+                    computation_types=transformer_computations[i],
                     biases=transformer_biases[i],
                     dropouts=transformer_dropouts[i],
                     **transformer_params
@@ -272,7 +294,7 @@ def __init__(
                 self.transformers[f"tr_layer_{i + 1}"] = tr

         # add a channel pooling layer at the end if no conv-layers are set up
-        if n_layers is None:
+        if n_conv_layers is None:
             self.ch_pool = ChannelPool(
                 in_channels=in_channels,
                 out_channels=self.out_channels,
@@ -338,7 +360,7 @@ def forward(
         x = x[0] if self.long_skip == "unetpp" else x

         # conv layers
-        if self.n_layers is not None:
+        if self.n_conv_layers is not None:
             for conv_layer in self.conv_layers.values():
                 x = conv_layer(x, style)  # (B, out_channels, H, W)

@@ -348,7 +370,7 @@ def forward(
                 x = transformer(x)  # (B, long_skip_channels/out_channels, H, W)

         # channel pool if conv-layers are skipped.
-        if self.n_layers is None:
+        if self.n_conv_layers is None:
             x = self.ch_pool(x)  # (B, out_channels, H, W)

         return x, extra_skips
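
Reading the forward hunks together, a condensed sketch of the per-stage control flow after the rename (long-skip handling abbreviated; the attribute names come from the hunks above, while the free function itself is only an illustration):

def stage_forward_sketch(stage, x, style=None, extra_skips=None):
    # mirrors the branching shown in the forward() hunks above
    if stage.n_conv_layers is not None:
        for conv_layer in stage.conv_layers.values():
            x = conv_layer(x, style)   # (B, out_channels, H, W)
    if stage.n_transformers is not None:
        for transformer in stage.transformers.values():
            x = transformer(x)         # (B, long_skip_channels/out_channels, H, W)
    if stage.n_conv_layers is None:
        x = stage.ch_pool(x)           # (B, out_channels, H, W)
    return x, extra_skips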