Commit c05021e

style: minor docstring and code style fixes

1 parent 7a9b4a0 commit c05021e

8 files changed: +31 -24 lines changed

cellseg_models_pytorch/decoders/decoder.py

Lines changed: 5 additions & 5 deletions
@@ -49,14 +49,14 @@ def __init__(
             The number of convolution layers inside each of the decoder stages. The
             argument can be given as a tuple, where each value indicates the number
             of conv-layers inside each stage of the decoder allowing the mixing of
-            different sized layers inside the stages in the decoder. If set to None,
-            no conv-layers will be included in the decoder.
+            different sized layers inside the stages. If set to None, no conv-layers
+            will be included in the decoder.
         n_transformers : Union[None, int, Tuple[int, ...]] , optional
             The number of transformer layers inside each of the decoder stages. The
             argument can be given as a tuple, where each value indicates the number
-            of transformer-layers inside each stage of the decoder allowing the
-            mixing of different sized layers inside the stages in the decoder. If
-            set to None, no transformer layers will be included in the decoder.
+            of transformer-layers inside each decoder stage allowing the mixing
+            of different sized layers inside the stages. If set to None,
+            no transformer layers will be included in the decoder.
         n_conv_blocks : Union[int, Tuple[Tuple[int, ...], ...]], default=2
             The number of blocks inside each conv-layer at each decoder stage. The
             argument can be given as a nested tuple, where each value indicates the
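For context, `n_conv_layers` and `n_transformers` accept either a single int or a per-stage tuple. Below is a minimal sketch of how such an argument can be expanded to one value per decoder stage; `per_stage` is a hypothetical helper written for illustration, not a function from cellseg_models_pytorch.

from typing import Optional, Tuple, Union


def per_stage(
    arg: Union[None, int, Tuple[int, ...]], n_stages: int
) -> Tuple[Optional[int], ...]:
    """Normalize an int/tuple/None layer-count argument to one value per stage."""
    if arg is None:
        return (None,) * n_stages  # no layers of this type in any stage
    if isinstance(arg, int):
        return (arg,) * n_stages  # same count repeated for every stage
    if len(arg) != n_stages:
        raise ValueError("tuple length has to match the number of decoder stages")
    return tuple(arg)  # mixed counts, one per stage


# e.g. a 4-stage decoder: 1 conv-layer everywhere, transformers only in the last stage
print(per_stage(1, 4))             # (1, 1, 1, 1)
print(per_stage((0, 0, 0, 1), 4))  # (0, 0, 0, 1)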

cellseg_models_pytorch/decoders/decoder_stage.py

Lines changed: 3 additions & 2 deletions
@@ -109,9 +109,10 @@ def __init__(
             The tuple-length has to match `n_conv_layers`. Ignored if
             `n_conv_layers` is None.
         attentions : Tuple[Tuple[str, ...], ...], default: ((None, "se"), )
-            Attention methods used inside the conv layers.
+            Channel-attention methods used inside the conv layers.
             The tuple-length has to match `n_conv_layers`. Ignored if
-            `n_conv_layers` is None.
+            `n_conv_layers` is None. Allowed methods: "se", "scse", "gc", "eca",
+            "msca", None.
         preactivates Tuple[Tuple[bool, ...], ...], default: ((False, False), )
             Boolean flags for the conv layers to use pre-activation.
             The tuple-length has to match `n_conv_layers`. Ignored if

cellseg_models_pytorch/modules/conv_base.py

Lines changed: 9 additions & 9 deletions
@@ -64,7 +64,7 @@ def __init__(
         bias : bool, default=False,
             Include bias term in the convolution.
         attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattend : bool, default=False
             If True, Attention is applied at the beginning of forward pass.
         """
@@ -186,7 +186,7 @@ def __init__(
         bias : bool, default=False,
             Include bias term in the convolution.
         attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattend : bool, default=False
             If True, Attention is applied at the beginning of forward pass.
         """
@@ -336,7 +336,7 @@ def __init__(
         kernel_size : int, default=3
             The size of the convolution kernel.
         attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattend : bool, default=False
             If True, Attention is applied at the beginning of forward pass.
         """
@@ -375,7 +375,7 @@ def __init__(
         self.act2 = Activation(activation)
 
     def forward_features(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass with pre-activation."""
+        """Forward pass."""
         if self.preattend:
             x = self.att(x)
 
@@ -394,7 +394,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
     def forward_features_preact(self, x: torch.Tensor) -> torch.Tensor:
-        """Forward pass."""
+        """Forward pass with pre-activation."""
         if self.preattend:
             x = self.att(x)
 
@@ -459,7 +459,7 @@ def __init__(
         kernel_size : int, default=3
             The size of the convolution kernel.
         attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattend : bool, default=False
             If True, Attention is applied at the beginning of forward pass.
         """
@@ -615,7 +615,7 @@ def __init__(
         preactivate : bool, default=False
             If True, normalization will be applied before convolution.
         attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattend : bool, default=False
             If True, Attention is applied at the beginning of forward pass.
         """
@@ -750,7 +750,7 @@ def __init__(
         kernel_size : int, default=3
             The size of the convolution kernel.
         attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattend : bool, default=False
             If True, Attention is applied at the beginning of forward pass.
         """
@@ -880,7 +880,7 @@ def __init__(
         bias : bool, default=False,
             Include bias term in the convolution.
         attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattend : bool, default=False
             If True, Attention is applied at the beginning of forward pass.
         """

cellseg_models_pytorch/modules/conv_block.py

Lines changed: 1 addition & 1 deletion
@@ -134,7 +134,7 @@ def __init__(
         bias : bool, default=True,
             Include bias term in the convolution block. Only used for `BaasicConv`.
         attention : str, default=None
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattend : bool, default=False
             If True, Attention is applied at the beginning of forward pass.
         use_style : bool, default=False

cellseg_models_pytorch/modules/conv_layer.py

Lines changed: 1 addition & 1 deletion
@@ -73,7 +73,7 @@ def __init__(
         biases : bool, default=(True, True)
             Include bias terms in the convolution blocks.
         attentions : Tuple[str, ...], default=(None, None)
-            Attention method. One of: "se", "scse", "gc", "eca", None
+            Attention method. One of: "se", "scse", "gc", "eca", "msca", None
         preattends : Tuple[bool, ...], default=(False, False)
             If True, Attention is applied at the beginning of forward pass.
         use_styles : bool, default=(False, False)

cellseg_models_pytorch/modules/misc_modules.py

Lines changed: 5 additions & 0 deletions
@@ -28,13 +28,18 @@ def __init__(
             Flag, whether the scaling is an inplace operation.
         """
         super().__init__()
+        self.dim = dim
         self.inplace = inplace
         self.gamma = nn.Parameter(init_values * torch.ones(dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass of the layer scaling."""
         return x.mul_(self.gamma) if self.inplace else x * self.gamma
 
+    def extra_repr(self) -> str:
+        """Add extra to repr."""
+        return f"dim={self.dim}, inplace={self.inplace}"
+
 
 class ChannelPool(nn.Module):
     def __init__(
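`extra_repr` is the standard `torch.nn.Module` hook whose return value is printed inside the module's repr, so this change makes a printed `LayerScale` show its `dim` and `inplace` settings. A stand-alone sketch of the same pattern follows (the class body mirrors the diff; the `init_values=1e-5` default is an assumption, not taken from the package):

import torch
import torch.nn as nn


class LayerScale(nn.Module):
    """Learnable per-channel scaling, as in the diff above."""

    def __init__(self, dim: int, init_values: float = 1e-5, inplace: bool = False):
        super().__init__()
        self.dim = dim
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.mul_(self.gamma) if self.inplace else x * self.gamma

    def extra_repr(self) -> str:
        return f"dim={self.dim}, inplace={self.inplace}"


print(LayerScale(64))  # LayerScale(dim=64, inplace=False) instead of a bare LayerScale()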

cellseg_models_pytorch/modules/patch_embeddings.py

Lines changed: 4 additions & 1 deletion
@@ -59,7 +59,9 @@ def __init__(
         num_heads : int, default=8
             Number of heads in multi-head self-attention.
         flatten : bool, default=True
-            If True, the output will be flattened to a sequence.
+            If True, the output will be flattened to a sequence. After flattening
+            the output will have shape (B, H'*W', head_dim*num_heads). If False,
+            the output shape will remain (B, C, H', W').
         normalization : str, optional
             The name of the normalization method.
             One of: "bn", "bcn", "gn", "in", "ln", "lrn", None
@@ -105,6 +107,7 @@
         self.kernel_size = patch_size if kernel_size is None else kernel_size
         self.pad = pad
         self.stride = stride
+        norm_kwargs = norm_kwargs if norm_kwargs is not None else {}
 
         self.norm = Norm(normalization, **norm_kwargs)
         self.proj = nn.Conv2d(
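The expanded `flatten` docstring spells out the two output shapes. Here is a quick shape check of that convention with a generic conv patch projection (not the package's embedding class); `head_dim * num_heads` is assumed to equal the projection dim:

import torch
import torch.nn as nn

B, C, H, W = 2, 3, 64, 64
num_heads, head_dim, patch_size = 8, 32, 8
proj_dim = head_dim * num_heads  # 256

proj = nn.Conv2d(C, proj_dim, kernel_size=patch_size, stride=patch_size)

x = torch.randn(B, C, H, W)
feats = proj(x)                          # (B, proj_dim, H', W') = (2, 256, 8, 8)
flat = feats.flatten(2).transpose(1, 2)  # (B, H'*W', proj_dim) = (2, 64, 256)
print(feats.shape, flat.shape)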

cellseg_models_pytorch/modules/transformers.py

Lines changed: 3 additions & 5 deletions
@@ -135,7 +135,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
         B, _, H, W = x.shape
         residual = x
 
-        # 1. project
+        # 1. embed and project
         x = self.patch_embed(x)
 
         # 2. transformer
@@ -151,7 +151,7 @@ def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor
         # NOTE: the kernel_size, pad, & stride has to be set correctly for this to work
         if p_H < H:
             scale_factor = H // p_H
-            x = F.interpolate(x, scale_factor=scale_factor, mode="bilinear")
+            x = F.interpolate(x, scale_factor=int(scale_factor), mode="bilinear")
 
         # 4. project to original input channels
         x = self.proj_out(x)
@@ -268,15 +268,13 @@ def __init__(
             ls = LayerScale(query_dim) if layer_scales[i] else Identity()
             self.layer_scales.append(ls)
 
-            # self.tr_blocks[f"transformer_{block_types[i]}_{i + 1}"] = tr_block
-
         self.mlp = MlpBlock(
             in_channels=query_dim,
             mlp_ratio=mlp_ratio,
             activation=activation,
             normalization="ln",
             norm_kwargs={"normalized_shape": query_dim},
-            activation_kwargs=kwargs,
+            act_kwargs=kwargs,
         )
 
     def forward(self, x: torch.Tensor, context: torch.Tensor = None) -> torch.Tensor:
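The only functional tweak in the second hunk is the explicit `int(...)` cast on `scale_factor`, presumably to guarantee a plain Python int for `F.interpolate`. A toy version of that resize-back step, with made-up tensors in place of the module's real inputs:

import torch
import torch.nn.functional as F

x_in = torch.randn(2, 32, 64, 64)  # input to the transformer block: H = 64
x = torch.randn(2, 32, 16, 16)     # after patch embedding: p_H = 16

H, p_H = x_in.shape[-2], x.shape[-2]
if p_H < H:
    scale_factor = H // p_H  # 4
    x = F.interpolate(x, scale_factor=int(scale_factor), mode="bilinear")

print(x.shape)  # torch.Size([2, 32, 64, 64])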
