Commit 5d89980
refactor: change variable names for readability
1 parent 698242e

6 files changed, +78 -56 lines

cellseg_models_pytorch/decoders/decoder.py

Lines changed: 12 additions & 12 deletions
@@ -14,8 +14,8 @@ def __init__(
         enc_channels: Tuple[int, ...],
         out_channels: Tuple[int, ...] = (256, 128, 64, 32, 16),
         style_channels: int = None,
-        n_layers: Tuple[int, ...] = (1, 1, 1, 1, 1),
-        n_blocks: Tuple[Tuple[int, ...], ...] = ((2,), (2,), (2,), (2,), (2,)),
+        n_conv_layers: Tuple[int, ...] = (1, 1, 1, 1, 1),
+        n_conv_blocks: Tuple[Tuple[int, ...], ...] = ((2,), (2,), (2,), (2,), (2,)),
         long_skip: str = "unet",
         n_transformers: Tuple[int, ...] = None,
         n_transformer_blocks: Tuple[Tuple[int], ...] = ((1,), (1,), (1,), (1,), (1,)),
@@ -32,9 +32,9 @@ def __init__(
             Number of channels at each decoder layer output.
         style_channels : int, default=None
             Number of style vector channels. If None, style vectors are ignored.
-        n_layers : Tuple[int, ...], default=(1, 1, 1, 1, 1)
+        n_conv_layers : Tuple[int, ...], default=(1, 1, 1, 1, 1)
             The number of conv layers inside each of the decoder stages.
-        n_blocks : Tuple[Tuple[int, ...], ...], default=((2, ), (2, ), (2, ), (2, ), (2, ))
+        n_conv_blocks : Tuple[Tuple[int, ...], ...], default=((2, ), (2, ), (2, ), (2, ), (2, ))
             The number of blocks inside each conv-layer at each decoder stage.
         long_skip : str, default="unet"
             long skip method to be used. One of: "unet", "unetpp", "unet3p",
@@ -76,14 +76,14 @@ def __init__(
         # Build decoder
         for i in range(depth - 1):
             # number of conv layers
-            n_conv_layers = None
-            if n_layers is not None:
-                n_conv_layers = n_layers[i]
+            n_clayers = None
+            if n_conv_layers is not None:
+                n_clayers = n_conv_layers[i]

             # number of conv blocks inside each layer
-            n_conv_blocks = None
-            if n_blocks is not None:
-                n_conv_blocks = n_blocks[i]
+            n_cblocks = None
+            if n_conv_blocks is not None:
+                n_cblocks = n_conv_blocks[i]

             # number of transformer layers
             n_tr_layers = None
@@ -102,8 +102,8 @@ def __init__(
                 skip_channels=skip_channels,
                 style_channels=style_channels,
                 long_skip=long_skip,
-                n_layers=n_conv_layers,
-                n_blocks=n_conv_blocks,
+                n_conv_layers=n_clayers,
+                n_conv_blocks=n_cblocks,
                 n_transformers=n_tr_layers,
                 n_transformer_blocks=n_tr_blocks,
                 **stage_params[i] if stage_params is not None else {"k": None},
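
For context, a minimal usage sketch of the renamed keywords. The import path and encoder channel counts below are illustrative assumptions, not taken from this commit; the remaining constructor arguments keep the defaults shown in the diff.

# Hedged sketch: constructing a Decoder with the renamed keywords.
# The import path and channel counts are assumptions for illustration.
from cellseg_models_pytorch.decoders.decoder import Decoder

decoder = Decoder(
    enc_channels=(512, 256, 128, 64, 32),  # example encoder feature channels
    out_channels=(256, 128, 64, 32, 16),
    n_conv_layers=(1, 1, 1, 1, 1),  # previously: n_layers
    n_conv_blocks=((2,), (2,), (2,), (2,), (2,)),  # previously: n_blocks
    long_skip="unet",
)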

cellseg_models_pytorch/decoders/decoder_stage.py

Lines changed: 59 additions & 37 deletions
@@ -17,8 +17,8 @@ def __init__(
         dec_dims: Tuple[int, ...],
         skip_channels: Tuple[int, ...],
         style_channels: int = None,
-        n_layers: int = 1,
-        n_blocks: Tuple[int, ...] = (2,),
+        n_conv_layers: int = 1,
+        n_conv_blocks: Tuple[int, ...] = (2,),
         short_skips: Tuple[str, ...] = ("residual",),
         expand_ratios: Tuple[float, float] = ((1.0, 1.0),),
         block_types: Tuple[Tuple[str, ...], ...] = (("basic", "basic"),),
@@ -32,14 +32,15 @@ def __init__(
         kernel_sizes: Tuple[Tuple[int, ...]] = ((3, 3),),
         groups: Tuple[Tuple[int, ...]] = ((1, 1),),
         biases: Tuple[Tuple[bool, ...]] = ((False, False),),
+        layer_residual: bool = False,
         upsampling: str = "fixed-unpool",
         long_skip: str = "unet",
         merge_policy: str = "sum",
-        layer_residual: bool = False,
         skip_params: Optional[Dict[str, Any]] = None,
         n_transformers: Optional[int] = None,
         n_transformer_blocks: Optional[Tuple[int, ...]] = (1,),
-        self_attentions: Optional[Tuple[Tuple[str, ...], ...]] = (("basic",),),
+        transformer_blocks: Optional[Tuple[Tuple[str, ...], ...]] = (("exact",),),
+        transformer_computations: Optional[Tuple[Tuple[str, ...], ...]] = (("basic",),),
         transformer_biases: Optional[Tuple[Tuple[bool, ...], ...]] = ((False,),),
         transformer_dropouts: Optional[Tuple[Tuple[float, ...], ...]] = ((0.0,),),
         transformer_params: Optional[List[Dict[str, Any]]] = None,
@@ -68,48 +69,59 @@ def __init__(
             `long_skip` == None.
         style_channels : int, default=None
             Number of style vector channels. If None, style vectors are ignored.
-            Also, ignored if `n_layers` is None.
-        n_layers : int, default=1
+            Also, ignored if `n_conv_layers` is None.
+        n_conv_layers : int, default=1
             The number of conv layers inside one decoder stage.
-        n_blocks : Tuple[int, ...], default=(2,)
+        n_conv_blocks : Tuple[int, ...], default=(2,)
             Number of conv-blocks inside each conv layer. The tuple-length has to
-            match `n_layers`. Ignored if `n_layers` is None.
+            match `n_conv_layers`. Ignored if `n_conv_layers` is None.
         short_skips : Tuple[str, ...], default=("residual", )
             The short skip methods used inside the conv layers. Ignored if
-            `n_layers` is None.
+            `n_conv_layers` is None.
         expand_ratios : Tuple[float, ...], default=((1.0, 1.0),)
             Expansion/Squeeze ratios for the out channels of each conv block.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         block_types : Tuple[Tuple[str, ...], ...], default=(("basic", "basic"), )
             The type of the convolution blocks in the conv blocks inside the layers.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         normalizations : Tuple[Tuple[str, ...], ...], default=(("bn", "bn"), )
             Normalization methods used in the conv blocks inside the conv layers.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         activations : Tuple[Tuple[str, ...], ...], default=(("relu", "relu"), )
             Activation methods used inside the conv layers.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         attentions : Tuple[Tuple[str, ...], ...], default=((None, "se"), )
             Attention methods used inside the conv layers.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         preactivates : Tuple[Tuple[bool, ...], ...], default=((False, False), )
             Boolean flags for the conv layers to use pre-activation.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         preattends : Tuple[Tuple[bool, ...], ...], default=((False, False), )
             Boolean flags for the conv layers to use pre-attention.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         use_styles : Tuple[Tuple[bool, ...], ...], default=((False, False), )
             Boolean flags for the conv layers to add style vectors at each block.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         kernel_sizes : Tuple[Tuple[int, ...], ...], default=((3, 3),)
             The size of the convolution kernels in each conv block.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         groups : Tuple[Tuple[int, ...], ...], default=((1, 1),)
             Number of groups for the kernels in each convolution block.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         biases : Tuple[Tuple[bool, ...], ...], default=((False, False),)
             Include bias terms in the convolution blocks.
-            The tuple-length has to match `n_layers`. Ignored if `n_layers` is None.
+            The tuple-length has to match `n_conv_layers`. Ignored if
+            `n_conv_layers` is None.
         upsampling : str, default="fixed-unpool"
             Name of the upsampling method.
         long_skip : str, default="unet"
@@ -128,10 +140,16 @@ def __init__(
         n_transformer_blocks : Tuple[int, ...], default=(2, ), optional
             Number of multi-head self attention blocks used in the transformer
             layers. Ignored if `n_transformers` is None.
-        self_attentions : Tuple[Tuple[str, ...], ...], default=(("basic",),)
-            The self-attention mechanisms used in the transformer layers.
-            Allowed values: "basic", "slice", "flash". Ignored if `n_transformers`
-            is None.
+        transformer_blocks : Tuple[Tuple[str, ...], ...], default=(("exact",),)
+            The names of the SelfAttentionBlocks in the TransformerLayer(s).
+            Length of the tuple has to equal `n_transformer_blocks`. Allowed
+            names: "exact", "linformer". Ignored if `n_transformers` is None.
+        transformer_computations : Tuple[Tuple[str, ...], ...], default=(("basic",),)
+            The way of computing the attention matrices in the SelfAttentionBlocks
+            in the TransformerLayer(s). Length of the tuple has to equal
+            `n_transformer_blocks`. Allowed styles: "basic", "slice", "flash",
+            "memeff", "slice-memeff". Ignored if `n_transformers` is None.
         transformer_biases : Tuple[Tuple[bool, ...], ...], default=((False,),)
             Flags, whether to use biases in the transformer layers. Ignored if
             `n_transformers` is None.
@@ -146,13 +164,13 @@ def __init__(
         Raises
         ------
         ValueError:
-            If lengths of the conv layer tuple args are not equal to `n_layers`.
+            If lengths of the conv layer tuple args are not equal to `n_conv_layers`.
             If lengths of the transformer layer tuple args are not equal to
             `n_transformers`.
         """
         super().__init__()

-        self.n_layers = n_layers
+        self.n_conv_layers = n_conv_layers
         self.n_transformers = n_transformers
         self.long_skip = long_skip
         self.stage_ix = stage_ix
@@ -176,19 +194,20 @@ def __init__(

         # Set up n layers of conv blocks
         layer = None  # placeholder
-        if n_layers is not None:
+        if n_conv_layers is not None:

             # check that the conv-layer tuple-args are not illegal.
             self._check_tuple_args(
                 "conv-layer related",
-                "n_layers",
-                n_layers,
+                "n_conv_layers",
+                n_conv_layers,
                 all_args=locals(),
                 skip_args=(
                     skip_channels,
                     dec_channels,
                     dec_dims,
-                    self_attentions,
+                    transformer_blocks,
+                    transformer_computations,
                     n_transformer_blocks,
                     transformer_biases,
                     transformer_dropouts,
@@ -197,12 +216,12 @@ def __init__(

             # set up the conv-layers.
             self.conv_layers = nn.ModuleDict()
-            for i in range(n_layers):
+            for i in range(n_conv_layers):
                 n_in_feats = self.skip.out_channels if i == 0 else layer.out_channels
                 layer = ConvLayer(
                     in_channels=n_in_feats,
                     out_channels=self.out_channels,
-                    n_blocks=n_blocks[i],
+                    n_blocks=n_conv_blocks[i],
                     layer_residual=layer_residual,
                     style_channels=style_channels,
                     short_skip=short_skips[i],
@@ -225,7 +244,9 @@ def __init__(
             self.out_channels = layer.out_channels

         # set in_channels for final operations
-        in_channels = self.skip.out_channels if n_layers is None else self.out_channels
+        in_channels = (
+            self.skip.out_channels if n_conv_layers is None else self.out_channels
+        )

         if n_transformers is not None:

@@ -239,7 +260,7 @@ def __init__(
                     skip_channels,
                     dec_channels,
                     dec_dims,
-                    n_blocks,
+                    n_conv_blocks,
                     short_skips,
                     expand_ratios,
                     block_types,
@@ -262,7 +283,8 @@ def __init__(
                 tr = Transformer2D(
                     in_channels=in_channels,
                     n_blocks=n_transformer_blocks[i],
-                    block_types=self_attentions[i],
+                    block_types=transformer_blocks[i],
+                    computation_types=transformer_computations[i],
                     biases=transformer_biases[i],
                     dropouts=transformer_dropouts[i],
                     **transformer_params
@@ -272,7 +294,7 @@ def __init__(
                 self.transformers[f"tr_layer_{i + 1}"] = tr

         # add a channel pooling layer at the end if no conv-layers are set up
-        if n_layers is None:
+        if n_conv_layers is None:
             self.ch_pool = ChannelPool(
                 in_channels=in_channels,
                 out_channels=self.out_channels,
@@ -338,7 +360,7 @@ def forward(
         x = x[0] if self.long_skip == "unetpp" else x

         # conv layers
-        if self.n_layers is not None:
+        if self.n_conv_layers is not None:
             for conv_layer in self.conv_layers.values():
                 x = conv_layer(x, style)  # (B, out_channels, H, W)

@@ -348,7 +370,7 @@ def forward(
             x = transformer(x)  # (B, long_skip_channels/out_channels, H, W)

         # channel pool if conv-layers are skipped.
-        if self.n_layers is None:
+        if self.n_conv_layers is None:
             x = self.ch_pool(x)  # (B, out_channels, H, W)

         return x, extra_skips
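
To summarize the new interface, here is a minimal sketch of the transformer-related keywords a DecoderStage now accepts, mirroring the defaults in the diff above; the dict itself is illustrative only.

# Hedged sketch: "transformer_blocks" replaces the old "self_attentions"
# argument and names the attention block type, while the new
# "transformer_computations" selects how the attention matrices are computed.
stage_transformer_kwargs = dict(
    n_transformers=1,
    n_transformer_blocks=(1,),
    transformer_blocks=(("exact",),),  # allowed: "exact", "linformer"
    transformer_computations=(("basic",),),  # allowed: "basic", "slice",
    # "flash", "memeff", "slice-memeff"
    transformer_biases=((False,),),
    transformer_dropouts=((0.0,),),
)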

cellseg_models_pytorch/models/cellpose/cellpose.py

Lines changed: 2 additions & 2 deletions
@@ -174,8 +174,8 @@ def __init__(
                 out_channels=out_channels,
                 style_channels=style_channels,
                 long_skip=long_skip,
-                n_layers=n_layers,
-                n_blocks=n_blocks,
+                n_conv_layers=n_layers,
+                n_conv_blocks=n_blocks,
                 stage_params=dec_params[decoder_name],
             )
             self.add_module(f"{decoder_name}_decoder", decoder)

cellseg_models_pytorch/models/hovernet/hovernet.py

Lines changed: 2 additions & 2 deletions
@@ -168,8 +168,8 @@ def __init__(
                 style_channels=style_channels,
                 long_skip=long_skip,
                 merge_policy=merge_policy,
-                n_layers=n_layers,
-                n_blocks=n_blocks,
+                n_conv_layers=n_layers,
+                n_conv_blocks=n_blocks,
                 stage_params=dec_params[decoder_name],
             )
             self.add_module(f"{decoder_name}_decoder", decoder)

cellseg_models_pytorch/models/stardist/stardist.py

Lines changed: 2 additions & 2 deletions
@@ -169,8 +169,8 @@ def __init__(
                 out_channels=out_channels,
                 style_channels=style_channels,
                 long_skip=long_skip,
-                n_layers=n_layers,
-                n_blocks=n_blocks,
+                n_conv_layers=n_layers,
+                n_conv_blocks=n_blocks,
                 stage_params=dec_params[decoder_name],
             )
             self.add_module(f"{decoder_name}_decoder", decoder)
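
As the three model diffs above show, the public model constructors keep their old `n_layers`/`n_blocks` keywords and simply forward them under the decoder's renamed keywords, so user-facing code is unaffected. A minimal sketch of the forwarding pattern (the helper name is hypothetical):

# Hedged sketch (hypothetical helper): the old public keywords map onto the
# Decoder's renamed n_conv_layers/n_conv_blocks keywords.
def decoder_kwargs(n_layers, n_blocks, **common):
    return dict(n_conv_layers=n_layers, n_conv_blocks=n_blocks, **common)

kwargs = decoder_kwargs(
    n_layers=(1, 1, 1, 1, 1),
    n_blocks=((2,), (2,), (2,), (2,), (2,)),
    long_skip="unet",
)
print(kwargs["n_conv_layers"])  # (1, 1, 1, 1, 1)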

cellseg_models_pytorch/transforms/functional/hovernet.py

Lines changed: 1 addition & 1 deletion
@@ -78,7 +78,7 @@ def gen_hv_maps(inst_map: np.ndarray, min_size: int = 5) -> np.ndarray:
         inst = inst[y1:y2, x1:x2]

         # instance center of mass, rounded to nearest pixel
-        inst_com = list(ndi.measurements.center_of_mass(inst))
+        inst_com = list(ndi.center_of_mass(inst))
         inst_com[0] = int(inst_com[0] + 0.5)
         inst_com[1] = int(inst_com[1] + 0.5)
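
The one-line change above tracks SciPy's namespace cleanup: the nested `scipy.ndimage.measurements` module was deprecated (SciPy 1.8) in favor of the same functions exposed directly on `scipy.ndimage`. A self-contained sketch of the updated call, using a toy mask rather than the repository's data:

# Hedged sketch: ndi.center_of_mass replaces the deprecated
# ndi.measurements.center_of_mass; the rounding mirrors gen_hv_maps.
import numpy as np
import scipy.ndimage as ndi

inst = np.zeros((5, 5), dtype=np.uint8)
inst[1:4, 1:4] = 1  # a 3x3 square instance

inst_com = list(ndi.center_of_mass(inst))
inst_com[0] = int(inst_com[0] + 0.5)  # round row index to nearest pixel
inst_com[1] = int(inst_com[1] + 0.5)  # round column index to nearest pixel
print(inst_com)  # [2, 2]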
