Commit 5d712ff

feat: add layer scale + style, docs, t-hint fixes
1 parent 3a28df0 commit 5d712ff

File tree

5 files changed: +245 / -148 lines changed

cellseg_models_pytorch/decoders/decoder.py

Lines changed: 119 additions & 53 deletions
@@ -1,4 +1,4 @@
-from typing import Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -13,50 +13,77 @@ def __init__(
         self,
         enc_channels: Tuple[int, ...],
         out_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
-        style_channels: int = None,
-        n_conv_layers: Tuple[int, ...] = (1, 1, 1, 1, 1),
-        n_conv_blocks: Tuple[Tuple[int, ...], ...] = ((2,), (2,), (2,), (2,), (2,)),
-        long_skip: str = "unet",
-        n_transformers: Tuple[int, ...] = None,
-        n_transformer_blocks: Tuple[Tuple[int], ...] = ((1,), (1,), (1,), (1,), (1,)),
+        long_skip: Union[None, str, Tuple[str, ...]] = "unet",
+        n_conv_layers: Union[None, int, Tuple[int, ...]] = 1,
+        n_transformers: Union[None, int, Tuple[int, ...]] = None,
+        n_conv_blocks: Union[int, Tuple[Tuple[int, ...], ...]] = 2,
+        n_transformer_blocks: Union[int, Tuple[Tuple[int], ...]] = 1,
         stage_params: Optional[Tuple[Dict, ...]] = None,
+        style_channels: int = None,
         **kwargs,
     ) -> None:
         """Build a generic U-net-like decoder.

+        I.e. stack decoder stages that are composed as follows:
+
+        DecoderStage:
+            - UpSample(up_method)
+            - LongSkip(long_skip_method)
+            - ConvLayer (optional)
+                - ConvBlock(conv_block_method)
+            - TransformerLayer (optional)
+                - TransformerBlock(transformer_block_method)
+
         Parameters
         ----------
         enc_channels : Tuple[int, ...]
             Number of channels at each encoder layer.
         out_channels : Tuple[int, ...], default=(256, 128, 64, 32, 16)
             Number of channels at each decoder layer output.
-        style_channels : int, default=None
-            Number of style vector channels. If None, style vectors are ignored.
-        n_conv_layers : Tuple[int, ...], default=(1, 1, 1, 1, 1)
-            The number of conv layers inside each of the decoder stages.
-        n_conv_blocks : Tuple[Tuple[int, ...], ...] =((2, ),(2, ),(2, ),(2, ),(2, ))
-            The number of blocks inside each conv-layer at each decoder stage.
-        long_skip : str, default="unet"
-            long skip method to be used. One of: "unet", "unetpp", "unet3p",
-            "unet3p-lite", None
-        n_transformers : Tuple[int, ...], optional, default=None
-            The number of transformer layers inside each of the decoder stages.
-        n_transformer_blocks : Tuple[Tuple[int]] = ((1, ),(1, ),(1, ),(1, ),(1, ))
+        long_skip : Union[None, str, Tuple[str, ...]], default="unet"
+            Long skip method to be used. The argument can be given as a tuple, where
+            each value indicates the long-skip method for each stage of the decoder,
+            allowing the mixing of long-skip methods in the decoder.
+            Allowed: "cross-attn", "unet", "unetpp", "unet3p", "unet3p-lite", None
+        n_conv_layers : Union[None, int, Tuple[int, ...]], default=1
+            The number of convolution layers inside each of the decoder stages. The
+            argument can be given as a tuple, where each value indicates the number
+            of conv-layers inside each stage of the decoder, allowing the mixing of
+            different sized layers inside the stages of the decoder. If set to None,
+            no conv-layers will be included in the decoder.
+        n_transformers : Union[None, int, Tuple[int, ...]], optional
+            The number of transformer layers inside each of the decoder stages. The
+            argument can be given as a tuple, where each value indicates the number
+            of transformer-layers inside each stage of the decoder, allowing the
+            mixing of different sized layers inside the stages of the decoder. If
+            set to None, no transformer layers will be included in the decoder.
+        n_conv_blocks : Union[int, Tuple[Tuple[int, ...], ...]], default=2
+            The number of blocks inside each conv-layer at each decoder stage. The
+            argument can be given as a nested tuple, where each value indicates the
+            number of `ConvBlock`s inside a single `ConvLayer`, allowing different
+            sized blocks inside each conv-layer in the decoder.
+        n_transformer_blocks : Union[int, Tuple[Tuple[int], ...]], default=1
             The number of transformer blocks inside each transformer-layer at each
-            decoder stage.
+            decoder stage. The argument can be given as a nested tuple, where each
+            value indicates the number of `SelfAttention`s inside a single
+            `TransformerLayer`, allowing different sized transformer blocks inside
+            each transformer-layer in the decoder.
         stage_params : Optional[Tuple[Dict, ...]], default=None
             The keyword args for each of the distinct decoder stages. Includes the
             parameters for the long skip connections, convolutional layers of the
             decoder and transformer layers themselves. See the `DecoderStage`
             documentation for more info.
+        style_channels : int, default=None
+            Number of style vector channels. If None, style vectors are ignored.
+            If `n_conv_layers` is None, this is ignored since style vectors are
+            applied inside `ConvBlock`s.

         Raises
         ------
         ValueError:
             If there is a mismatch between encoder and decoder channel lengths.
         """
         super().__init__()
-        self.long_skip = long_skip

         if not len(out_channels) == len(enc_channels):
             raise ValueError(
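
The reworked signature above accepts either a scalar or a per-stage tuple for the layer- and block-level arguments. A minimal usage sketch (the class name `Decoder` and the channel tuples below are illustrative assumptions, not shown in this diff):

    # Hypothetical instantiation; scalars are broadcast to every stage,
    # tuples are applied stage by stage.
    decoder = Decoder(
        enc_channels=(512, 256, 128, 64, 32),
        out_channels=(256, 128, 64, 32, 16),
        long_skip="unet",               # one long-skip method for every stage
        n_conv_layers=(1, 1, 1, 2, 2),  # per-stage conv-layer counts
        n_conv_blocks=2,                # broadcast to ((2,), (2,), (2,), (2, 2), (2, 2))
        n_transformers=None,            # no transformer layers in any stage
    )
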
@@ -70,66 +97,105 @@ def __init__(

         # scaling factor assumed to be 2 for the spatial dims and the input
         # has to be divisible by 32. 256 used here just for convenience.
-        depth = len(out_channels)
-        out_dims = [256 // 2**i for i in range(depth)][::-1]
+        self.depth = len(out_channels)
+        out_dims = [256 // 2**i for i in range(self.depth)][::-1]

-        # Build decoder
-        for i in range(depth - 1):
-            # number of conv layers
-            n_clayers = None
-            if n_conv_layers is not None:
-                n_clayers = n_conv_layers[i]
-
-            # number of conv blocks inside each layer
-            n_cblocks = None
-            if n_conv_blocks is not None:
-                n_cblocks = n_conv_blocks[i]
-
-            # number of transformer layers
-            n_tr_layers = None
-            if n_transformers is not None:
-                n_tr_layers = n_transformers[i]
-
-            # number of transformer blocks inside transformer layers
-            n_tr_blocks = None
-            if n_transformer_blocks is not None:
-                n_tr_blocks = n_transformer_blocks[i]
+        # set layer-level tuple-args
+        self.long_skips = self._layer_tuple(long_skip)
+        n_conv_layers = self._layer_tuple(n_conv_layers)
+        n_transformers = self._layer_tuple(n_transformers)

+        # set block-level tuple-args
+        n_conv_blocks = self._block_tuple(n_conv_blocks, n_conv_layers)
+        n_transformer_blocks = self._block_tuple(n_transformer_blocks, n_transformers)
+
+        # Build decoder
+        for i in range(self.depth - 1):
             decoder_block = DecoderStage(
                 stage_ix=i,
                 dec_channels=tuple(out_channels),
                 dec_dims=tuple(out_dims),
                 skip_channels=skip_channels,
+                long_skip=self._tup_arg(self.long_skips, i),
+                n_conv_layers=self._tup_arg(n_conv_layers, i),
+                n_conv_blocks=self._tup_arg(n_conv_blocks, i),
+                n_transformers=self._tup_arg(n_transformers, i),
+                n_transformer_blocks=self._tup_arg(n_transformer_blocks, i),
                 style_channels=style_channels,
-                long_skip=long_skip,
-                n_conv_layers=n_clayers,
-                n_conv_blocks=n_cblocks,
-                n_transformers=n_tr_layers,
-                n_transformer_blocks=n_tr_blocks,
                 **stage_params[i] if stage_params is not None else {"k": None},
             )
             self.add_module(f"decoder_stage{i + 1}", decoder_block)

         self.out_channels = decoder_block.out_channels

+    def _tup_arg(self, tup: Tuple[Any, ...], ix: int) -> Union[None, int, str]:
+        """Return None if given tuple-arg is None, else, return the value at ix."""
+        ret = None
+        if tup is not None:
+            ret = tup[ix]
+        return ret
+
+    def _layer_tuple(
+        self, arg: Union[None, str, int, Tuple[Any, ...]]
+    ) -> Union[None, Tuple[Any, ...]]:
+        """Return a non-nested tuple or None for layer-related arguments."""
+        ret = None
+        if isinstance(arg, (list, tuple)):
+            ret = tuple(arg)
+        elif isinstance(arg, (str, int)):
+            ret = tuple([arg] * self.depth)
+        elif arg is None:
+            ret = ret
+        else:
+            raise ValueError(
+                f"Given arg: {arg} should be None, str, int or a Tuple of ints or strs."
+            )
+
+        return ret
+
+    def _block_tuple(
+        self,
+        arg: Union[int, None, Tuple[Tuple[int, ...], ...]],
+        n_layers: Tuple[int, ...],
+    ) -> Union[None, Tuple[Tuple[int, ...], ...]]:
+        """Return a nested tuple or None for block-related arguments."""
+        ret = None
+        if isinstance(arg, (list, tuple)):
+            if not all([isinstance(a, (tuple, list)) for a in arg]):
+                raise ValueError(
+                    f"Given arg: {arg} should be a nested sequence. Got: {arg}."
+                )
+            ret = tuple(arg)
+        elif isinstance(arg, int):
+            if n_layers is not None:
+                ret = tuple([tuple([arg] * i) for i in n_layers])
+            else:
+                ret = None
+        elif arg is None:
+            ret = ret
+        else:
+            raise ValueError(f"Given arg: {arg} should be None, int or a nested tuple.")
+
+        return ret
+
     def forward_features(
         self, features: Tuple[torch.Tensor], style: torch.Tensor = None
     ) -> List[torch.Tensor]:
         """Forward pass of the decoder. Returns all the decoder stage feats."""
         head = features[0]
         skips = features[1:]
-        extra_skips = [head] if self.long_skip == "unet3p" else []
+        extra_skips = [head] if self.long_skips[0] == "unet3p" else []
         ret_feats = []

         x = head
-        for decoder_stage in self.values():
+        for i, decoder_stage in enumerate(self.values()):
             x, extra = decoder_stage(
                 x, skips=skips, extra_skips=extra_skips, style=style
             )

-            if self.long_skip == "unetpp":
+            if self.long_skips[i] == "unetpp":
                 extra_skips = extra
-            elif self.long_skip == "unet3p":
+            elif self.long_skips[i] == "unet3p":
                 extra_skips.append(x)

             ret_feats.append(x)
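
The new `_layer_tuple` and `_block_tuple` helpers normalize scalar arguments into per-stage (and per-layer) tuples before the stages are built. A self-contained sketch of the same broadcasting rules, simplified to the scalar case (the names and the depth value here are illustrative, not part of the library API):

    from typing import Any, Optional, Tuple, Union

    DEPTH = 5  # illustrative number of decoder stages, i.e. len(out_channels)

    def layer_tuple(arg: Union[None, str, int, Tuple[Any, ...]]) -> Optional[Tuple[Any, ...]]:
        """Broadcast a scalar layer-level arg to one value per decoder stage."""
        if arg is None:
            return None
        if isinstance(arg, (list, tuple)):
            return tuple(arg)
        return tuple([arg] * DEPTH)

    def block_tuple(arg: Optional[int], n_layers: Optional[Tuple[int, ...]]) -> Optional[Tuple[Tuple[int, ...], ...]]:
        """Broadcast a scalar block-level arg to one inner tuple per layer."""
        if arg is None or n_layers is None:
            return None
        return tuple(tuple([arg] * n) for n in n_layers)

    print(layer_tuple("unet"))             # ('unet', 'unet', 'unet', 'unet', 'unet')
    print(layer_tuple((1, 1, 1, 2, 2)))    # (1, 1, 1, 2, 2) – per-stage tuples pass through
    print(block_tuple(2, layer_tuple(1)))  # ((2,), (2,), (2,), (2,), (2,))
    print(block_tuple(1, None))            # None – no layers requested, no blocks built
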

cellseg_models_pytorch/decoders/decoder_stage.py

Lines changed: 40 additions & 33 deletions
@@ -16,33 +16,34 @@ def __init__(
         dec_channels: Tuple[int, ...],
         dec_dims: Tuple[int, ...],
         skip_channels: Tuple[int, ...],
-        style_channels: int = None,
-        n_conv_layers: int = 1,
-        n_conv_blocks: Tuple[int, ...] = (2,),
-        short_skips: Tuple[str, ...] = ("residual",),
-        expand_ratios: Tuple[float, float] = ((1.0, 1.0),),
-        block_types: Tuple[Tuple[str, ...], ...] = (("basic", "basic"),),
-        normalizations: Tuple[Tuple[str, ...], ...] = (("bn", "bn"),),
-        activations: Tuple[Tuple[str, ...], ...] = (("relu", "relu"),),
-        convolutions: Tuple[Tuple[str, ...], ...] = (("conv", "conv"),),
-        attentions: Tuple[Tuple[str, ...], ...] = ((None, "se"),),
-        preactivates: Tuple[Tuple[bool, ...], ...] = ((False, False),),
-        preattends: Tuple[Tuple[bool, ...], ...] = ((False, False),),
-        use_styles: Tuple[Tuple[bool, ...], ...] = ((False, False),),
-        kernel_sizes: Tuple[Tuple[int, ...]] = ((3, 3),),
-        groups: Tuple[Tuple[int, ...]] = ((1, 1),),
-        biases: Tuple[Tuple[bool, ...]] = ((False, False),),
-        layer_residual: bool = False,
-        upsampling: str = "fixed-unpool",
         long_skip: str = "unet",
         merge_policy: str = "sum",
         skip_params: Optional[Dict[str, Any]] = None,
+        upsampling: str = "fixed-unpool",
+        n_conv_layers: Optional[int] = 1,
+        style_channels: Optional[int] = None,
+        layer_residual: Optional[bool] = False,
+        n_conv_blocks: Optional[Tuple[int, ...]] = (2,),
+        short_skips: Optional[Tuple[str, ...]] = ("residual",),
+        expand_ratios: Optional[Tuple[float, float]] = ((1.0, 1.0),),
+        block_types: Optional[Tuple[Tuple[str, ...], ...]] = (("basic", "basic"),),
+        normalizations: Optional[Tuple[Tuple[str, ...], ...]] = (("bn", "bn"),),
+        activations: Optional[Tuple[Tuple[str, ...], ...]] = (("relu", "relu"),),
+        convolutions: Optional[Tuple[Tuple[str, ...], ...]] = (("conv", "conv"),),
+        attentions: Optional[Tuple[Tuple[str, ...], ...]] = ((None, "se"),),
+        preactivates: Optional[Tuple[Tuple[bool, ...], ...]] = ((False, False),),
+        preattends: Optional[Tuple[Tuple[bool, ...], ...]] = ((False, False),),
+        use_styles: Optional[Tuple[Tuple[bool, ...], ...]] = ((False, False),),
+        kernel_sizes: Optional[Tuple[Tuple[int, ...]]] = ((3, 3),),
+        groups: Optional[Tuple[Tuple[int, ...]]] = ((1, 1),),
+        biases: Optional[Tuple[Tuple[bool, ...]]] = ((False, False),),
         n_transformers: Optional[int] = None,
         n_transformer_blocks: Optional[Tuple[int, ...]] = (1,),
         transformer_blocks: Optional[Tuple[Tuple[str, ...], ...]] = (("exact",),),
         transformer_computations: Optional[Tuple[Tuple[str, ...], ...]] = (("basic",),),
         transformer_biases: Optional[Tuple[Tuple[bool, ...], ...]] = ((False,),),
         transformer_dropouts: Optional[Tuple[Tuple[float, ...], ...]] = ((0.0,),),
+        transformer_layer_scales: Optional[Tuple[Tuple[bool, ...], ...]] = ((False,),),
         transformer_params: Optional[List[Dict[str, Any]]] = None,
         **kwargs,
     ) -> None:
@@ -67,15 +68,28 @@ def __init__(
         skip_channels : Tuple[int, ...]
             List of the number of channels in the encoder skip tensors. Ignored if
             `long_skip` == None.
+        long_skip : str, default="unet"
+            Long skip method to be used.
+            Allowed: "cross-attn", "unet", "unetpp", "unet3p", "unet3p-lite", None
+        merge_policy : str, default="sum"
+            The long skip merge policy. One of: "sum", "cat"
+        skip_params : Optional[Dict]
+            Extra keyword arguments for the skip-connection module. These depend
+            on the skip module. Refer to specific skip modules for more info.
+        upsampling : str, default="fixed-unpool"
+            Name of the upsampling method.
+        n_conv_layers : int, default=1
+            The number of conv layers inside one decoder stage.
         style_channels : int, default=None
             Number of style vector channels. If None, style vectors are ignored.
             Also, ignored if `n_conv_layers` is None.
-        n_conv_layers : int, default=1
-            The number of conv layers inside one decoder stage.
+        layer_residual : bool, optional, default=False
+            Apply a layer level residual short skip at each layer. I.e. x + layer(x).
+            Ignored if `n_conv_layers` is None.
         n_conv_blocks : Tuple[int, ...], default=(2,)
             Number of conv-blocks inside each conv layer. The tuple-length has to
             match `n_conv_layers`. Ignored if `n_conv_layers` is None.
-        short_skips : str, default=("residual", )
+        short_skips : str, optional, default=("residual", )
             The short skip methods used inside the conv layers. Ignored if
             `n_conv_layers` is None.
         expand_ratios : Tuple[float, ...], default=((1.0, 1.0),):
@@ -122,18 +136,6 @@ def __init__(
             Include bias terms in the convolution blocks.
             The tuple-length has to match `n_conv_layers`. Ignored if
             `n_conv_layers` is None.
-        upsampling : str, default="fixed-unpool"
-            Name of the upsampling method.
-        long_skip : str, default="unet"
-            long skip method to be used. One of: "unet", "unetpp", "unet3p",
-            "unet3p-lite", None,
-        merge_policy : str, default="sum"
-            The long skip merge policy. One of: "sum", "cat"
-        layer_residual : bool, default=False
-            Apply a layer level residual skip at each layer. I.e x + layer(x)
-        skip_params : Optional[Dict]
-            Extra keyword arguments for the skip-connection module. These depend
-            on the skip module. Refer to specific skip modules for more info.
         n_transformers : int, optional
             Number of self-attention transformers applied after the conv-layer.
             If this is None, no transformers will be added.
@@ -156,6 +158,9 @@ def __init__(
         transformer_dropouts : Tuple[Tuple[float, ...], ...], default=((0.0,),)
             Dropout probabilities in the transformer layers. Ignored if
             `n_transformers` is None.
+        transformer_layer_scales : Tuple[Tuple[bool, ...], ...], default=((False,),)
+            Flags for whether to use layer scales in the transformer layers. Ignored
+            if `n_transformers` is None.
         transformer_params : List[Dict[str, Any]]
             Extra keyword arguments for the transformer layers. Refer to
             `Transformer2D` module for more info. Ignored if `n_transformers`
@@ -211,6 +216,7 @@ def __init__(
                 n_transformer_blocks,
                 transformer_biases,
                 transformer_dropouts,
+                transformer_layer_scales,
             ),
         )

@@ -287,6 +293,7 @@ def __init__(
                 computation_types=transformer_computations[i],
                 biases=transformer_biases[i],
                 dropouts=transformer_dropouts[i],
+                layer_scales=transformer_layer_scales[i],
                 **transformer_params
                 if transformer_params is not None
                 else {"k": None},
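
For reference, the `transformer_layer_scales` flags added here toggle a layer-scale term in the transformer blocks. A minimal, generic sketch of the technique (CaiT-style learnable per-channel scaling of the residual branch; a sketch of the general idea, not the library's actual implementation):

    import torch
    import torch.nn as nn

    class LayerScale(nn.Module):
        """Learnable per-channel scaling of a residual branch (generic sketch)."""

        def __init__(self, dim: int, init_value: float = 1e-5) -> None:
            super().__init__()
            # small initial values keep the scaled branch close to identity at init
            self.gamma = nn.Parameter(init_value * torch.ones(dim))

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (B, N, dim) token features; scale each channel of the branch output
            return self.gamma * x

    # Inside a transformer block the residual update then becomes roughly
    #   x = x + layer_scale(attention(norm(x)))
    # instead of
    #   x = x + attention(norm(x))
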
