1
- from typing import Dict , List , Optional , Tuple
1
+ from typing import Any , Dict , List , Optional , Tuple , Union
2
2
3
3
import torch
4
4
import torch .nn as nn
@@ -13,50 +13,77 @@ def __init__(
13
13
self ,
14
14
enc_channels : Tuple [int , ...],
15
15
out_channels : Tuple [int , ...] = (256 , 128 , 64 , 32 , 16 ),
16
- style_channels : int = None ,
17
- n_conv_layers : Tuple [int , ...] = (1 , 1 , 1 , 1 , 1 ),
18
- n_conv_blocks : Tuple [Tuple [int , ...], ...] = ((2 ,), (2 ,), (2 ,), (2 ,), (2 ,)),
19
- long_skip : str = "unet" ,
20
- n_transformers : Tuple [int , ...] = None ,
21
- n_transformer_blocks : Tuple [Tuple [int ], ...] = ((1 ,), (1 ,), (1 ,), (1 ,), (1 ,)),
16
+ long_skip : Union [None , str , Tuple [str , ...]] = "unet" ,
17
+ n_conv_layers : Union [None , int , Tuple [int , ...]] = 1 ,
18
+ n_transformers : Union [None , int , Tuple [int , ...]] = None ,
19
+ n_conv_blocks : Union [int , Tuple [Tuple [int , ...], ...]] = 2 ,
20
+ n_transformer_blocks : Union [int , Tuple [Tuple [int ], ...]] = 1 ,
22
21
stage_params : Optional [Tuple [Dict , ...]] = None ,
22
+ style_channels : int = None ,
23
23
** kwargs ,
24
24
) -> None :
25
25
"""Build a generic U-net-like decoder.
26
26
27
+ I.e. stack decoder stages that are composed as follows:
28
+
29
+ DecoderStage:
30
+ - UpSample(up_method)
31
+ - LongSkip(long_skip_method)
32
+ - ConvLayer (optional)
33
+ - ConvBlock(conv_block_method)
34
+ - TransformerLayer (optional)
35
+ - TransformerBlock(transformer_block_method)
36
+
27
37
Parameters
28
38
----------
29
39
enc_channels : Tuple[int, ...]
30
40
Number of channels at each encoder layer.
31
41
out_channels : Tuple[int, ...], default=(256, 128, 64, 32, 16)
32
42
Number of channels at each decoder layer output.
33
- style_channels : int, default=None
34
- Number of style vector channels. If None, style vectors are ignored.
35
- n_conv_layers : Tuple[int, ...], default=(1, 1, 1, 1, 1)
36
- The number of conv layers inside each of the decoder stages.
37
- n_conv_blocks : Tuple[Tuple[int, ...], ...] =((2, ),(2, ),(2, ),(2, ),(2, ))
38
- The number of blocks inside each conv-layer at each decoder stage.
39
- long_skip : str, default="unet"
40
- long skip method to be used. One of: "unet", "unetpp", "unet3p",
41
- "unet3p-lite", None
42
- n_transformers : Tuple[int, ...], optional, default=None
43
- The number of transformer layers inside each of the decoder stages.
44
- n_transformer_blocks : Tuple[Tuple[int]] = ((1, ),(1, ),(1, ),(1, ),(1, ))
43
+ long_skip : Union[None, str, Tuple[str, ...]], default="unet"
44
+ long skip method to be used. The argument can be given as a tuple, where
45
+ each value indicates the long-skip method for each stage of the decoder,
46
+ allowing the mixing of long-skip methods in the decoder.
47
+ Allowed: "cross-attn", "unet", "unetpp", "unet3p", "unet3p-lite", None
48
+ n_conv_layers : Union[None, int, Tuple[int, ...]], default=1
49
+ The number of convolution layers inside each of the decoder stages. The
50
+ argument can be given as a tuple, where each value indicates the number
51
+ of conv-layers inside each stage of the decoder allowing the mixing of
52
+ different sized layers inside the stages in the decoder. If set to None,
53
+ no conv-layers will be included in the decoder.
54
+ n_transformers : Union[None, int, Tuple[int, ...]] , optional
55
+ The number of transformer layers inside each of the decoder stages. The
56
+ argument can be given as a tuple, where each value indicates the number
57
+ of transformer-layers inside each stage of the decoder allowing the
58
+ mixing of different sized layers inside the stages in the decoder. If
59
+ set to None, no transformer layers will be included in the decoder.
60
+ n_conv_blocks : Union[int, Tuple[Tuple[int, ...], ...]], default=2
61
+ The number of blocks inside each conv-layer at each decoder stage. The
62
+ argument can be given as a nested tuple, where each value indicates the
63
+ number of `ConvBlock`s inside a single `ConvLayer` allowing different
64
+ sized blocks inside each conv-layer in the decoder.
65
+ n_transformer_blocks : Union[int, Tuple[Tuple[int], ...]], default=1
45
66
The number of transformer blocks inside each transformer-layer at each
46
- decoder stage.
67
+ decoder stage. The argument can be given as a nested tuple, where each
68
+ value indicates the number of `SelfAttention`s inside a single
69
+ `TransformerLayer` allowing different sized transformer blocks inside
70
+ each transformer-layer in the decoder.
47
71
stage_params : Optional[Tuple[Dict, ...]], default=None
48
72
The keyword args for each of the distinct decoder stages. Includes the
49
73
parameters for the long skip connections, convolutional layers of the
50
74
decoder and transformer layers itself. See the `DecoderStage`
51
75
documentation for more info.
76
+ style_channels : int, default=None
77
+ Number of style vector channels. If None, style vectors are ignored.
78
+ If `n_conv_layers` is None, this is ignored since style vectors are
79
+ applied inside `ConvBlocks`.
52
80
53
81
Raises
54
82
------
55
83
ValueError:
56
84
If there is a mismatch between encoder and decoder channel lengths.
57
85
"""
58
86
super ().__init__ ()
59
- self .long_skip = long_skip
60
87
61
88
if not len (out_channels ) == len (enc_channels ):
62
89
raise ValueError (
@@ -70,66 +97,105 @@ def __init__(
70
97
71
98
# scaling factor assumed to be 2 for the spatial dims and the input
72
99
# has to be divisible by 32. 256 used here just for convenience.
73
- depth = len (out_channels )
74
- out_dims = [256 // 2 ** i for i in range (depth )][::- 1 ]
100
+ self . depth = len (out_channels )
101
+ out_dims = [256 // 2 ** i for i in range (self . depth )][::- 1 ]
75
102
76
- # Build decoder
77
- for i in range (depth - 1 ):
78
- # number of conv layers
79
- n_clayers = None
80
- if n_conv_layers is not None :
81
- n_clayers = n_conv_layers [i ]
82
-
83
- # number of conv blocks inside each layer
84
- n_cblocks = None
85
- if n_conv_blocks is not None :
86
- n_cblocks = n_conv_blocks [i ]
87
-
88
- # number of transformer layers
89
- n_tr_layers = None
90
- if n_transformers is not None :
91
- n_tr_layers = n_transformers [i ]
92
-
93
- # number of transformer blocks inside transformer layers
94
- n_tr_blocks = None
95
- if n_transformer_blocks is not None :
96
- n_tr_blocks = n_transformer_blocks [i ]
103
+ # set layer-level tuple-args
104
+ self .long_skips = self ._layer_tuple (long_skip )
105
+ n_conv_layers = self ._layer_tuple (n_conv_layers )
106
+ n_transformers = self ._layer_tuple (n_transformers )
97
107
108
+ # set block-level tuple-args
109
+ n_conv_blocks = self ._block_tuple (n_conv_blocks , n_conv_layers )
110
+ n_transformer_blocks = self ._block_tuple (n_transformer_blocks , n_transformers )
111
+
112
+ # Build decoder
113
+ for i in range (self .depth - 1 ):
98
114
decoder_block = DecoderStage (
99
115
stage_ix = i ,
100
116
dec_channels = tuple (out_channels ),
101
117
dec_dims = tuple (out_dims ),
102
118
skip_channels = skip_channels ,
119
+ long_skip = self ._tup_arg (self .long_skips , i ),
120
+ n_conv_layers = self ._tup_arg (n_conv_layers , i ),
121
+ n_conv_blocks = self ._tup_arg (n_conv_blocks , i ),
122
+ n_transformers = self ._tup_arg (n_transformers , i ),
123
+ n_transformer_blocks = self ._tup_arg (n_transformer_blocks , i ),
103
124
style_channels = style_channels ,
104
- long_skip = long_skip ,
105
- n_conv_layers = n_clayers ,
106
- n_conv_blocks = n_cblocks ,
107
- n_transformers = n_tr_layers ,
108
- n_transformer_blocks = n_tr_blocks ,
109
125
** stage_params [i ] if stage_params is not None else {"k" : None },
110
126
)
111
127
self .add_module (f"decoder_stage{ i + 1 } " , decoder_block )
112
128
113
129
self .out_channels = decoder_block .out_channels
114
130
131
+ def _tup_arg (self , tup : Tuple [Any , ...], ix : int ) -> Union [None , int , str ]:
132
+ """Return None if given tuple-arg is None, else, return the value at ix."""
133
+ ret = None
134
+ if tup is not None :
135
+ ret = tup [ix ]
136
+ return ret
137
+
138
+ def _layer_tuple (
139
+ self , arg : Union [None , str , int , Tuple [Any , ...]]
140
+ ) -> Union [None , Tuple [Any , ...]]:
141
+ """Return a non-nested tuple or None for layer-related arguments."""
142
+ ret = None
143
+ if isinstance (arg , (list , tuple )):
144
+ ret = tuple (arg )
145
+ elif isinstance (arg , (str , int )):
146
+ ret = tuple ([arg ] * self .depth )
147
+ elif arg is None :
148
+ ret = ret
149
+ else :
150
+ raise ValueError (
151
+ f"Given arg: { arg } should be None, str, int or a Tuple of ints or strs."
152
+ )
153
+
154
+ return ret
155
+
156
+ def _block_tuple (
157
+ self ,
158
+ arg : Union [int , None , Tuple [Tuple [int , ...], ...]],
159
+ n_layers : Tuple [int , ...],
160
+ ) -> Union [None , Tuple [Tuple [int , ...], ...]]:
161
+ """Return a nested tuple or None for block-related arguments."""
162
+ ret = None
163
+ if isinstance (arg , (list , tuple )):
164
+ if not all ([isinstance (a , (tuple , list )) for a in arg ]):
165
+ raise ValueError (
166
+ f"Given arg: { arg } should be a nested sequence. Got: { arg } ."
167
+ )
168
+ ret = tuple (arg )
169
+ elif isinstance (arg , int ):
170
+ if n_layers is not None :
171
+ ret = tuple ([tuple ([arg ] * i ) for i in n_layers ])
172
+ else :
173
+ ret = None
174
+ elif arg is None :
175
+ ret = ret
176
+ else :
177
+ raise ValueError (f"Given arg: { arg } should be None, int or a nested tuple." )
178
+
179
+ return ret
180
+
115
181
def forward_features (
116
182
self , features : Tuple [torch .Tensor ], style : torch .Tensor = None
117
183
) -> List [torch .Tensor ]:
118
184
"""Forward pass of the decoder. Returns all the decoder stage feats."""
119
185
head = features [0 ]
120
186
skips = features [1 :]
121
- extra_skips = [head ] if self .long_skip == "unet3p" else []
187
+ extra_skips = [head ] if self .long_skips [ 0 ] == "unet3p" else []
122
188
ret_feats = []
123
189
124
190
x = head
125
- for decoder_stage in self .values ():
191
+ for i , decoder_stage in enumerate ( self .values () ):
126
192
x , extra = decoder_stage (
127
193
x , skips = skips , extra_skips = extra_skips , style = style
128
194
)
129
195
130
- if self .long_skip == "unetpp" :
196
+ if self .long_skips [ i ] == "unetpp" :
131
197
extra_skips = extra
132
- elif self .long_skip == "unet3p" :
198
+ elif self .long_skips [ i ] == "unet3p" :
133
199
extra_skips .append (x )
134
200
135
201
ret_feats .append (x )
0 commit comments