
Commit 23a3c59

feat(modules): Add msca conv-attention
1 parent ad90d6d commit 23a3c59

1 file changed: +64 -0 lines changed

cellseg_models_pytorch/modules/attention_modules.py

@@ -36,9 +36,72 @@
     "SCSqueezeAndExcite",
     "ECA",
     "GlobalContext",
+    "MSCA",
 ]


+class MSCA(nn.Module):
+    def __init__(self, in_channels: int, **kwargs) -> None:
+        """Multi-scale convolutional attention (MSCA).
+
+        - SegNeXt: http://arxiv.org/abs/2209.08575
+
+        Parameters
+        ----------
+        in_channels : int
+            The number of input channels.
+        """
+        super().__init__()
+        # depth-wise projection
+        self.proj = nn.Conv2d(
+            in_channels, in_channels, 5, padding=2, groups=in_channels
+        )
+
+        # scale1
+        self.conv0_1 = nn.Conv2d(
+            in_channels, in_channels, (1, 7), padding=(0, 3), groups=in_channels
+        )
+        self.conv0_2 = nn.Conv2d(
+            in_channels, in_channels, (7, 1), padding=(3, 0), groups=in_channels
+        )
+
+        # scale2
+        self.conv1_1 = nn.Conv2d(
+            in_channels, in_channels, (1, 11), padding=(0, 5), groups=in_channels
+        )
+        self.conv1_2 = nn.Conv2d(
+            in_channels, in_channels, (11, 1), padding=(5, 0), groups=in_channels
+        )
+
+        # scale3
+        self.conv2_1 = nn.Conv2d(
+            in_channels, in_channels, (1, 21), padding=(0, 10), groups=in_channels
+        )
+        self.conv2_2 = nn.Conv2d(
+            in_channels, in_channels, (21, 1), padding=(10, 0), groups=in_channels
+        )
+        self.conv3 = nn.Conv2d(in_channels, in_channels, 1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the MSCA-attention."""
+        residual = x
+        attn = self.proj(x)
+
+        attn_0 = self.conv0_1(attn)
+        attn_0 = self.conv0_2(attn_0)
+
+        attn_1 = self.conv1_1(attn)
+        attn_1 = self.conv1_2(attn_1)
+
+        attn_2 = self.conv2_1(attn)
+        attn_2 = self.conv2_2(attn_2)
+        attn = attn + attn_0 + attn_1 + attn_2
+
+        attn = self.conv3(attn)
+
+        return attn * residual
+
+
 class SqueezeAndExcite(nn.Module):
     def __init__(
         self,
@@ -304,6 +367,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     "scse": SCSqueezeAndExcite,
     "eca": ECA,
     "gc": GlobalContext,
+    "msca": MSCA,
 }
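For context, MSCA builds its attention map entirely from depth-wise convolutions: a 5x5 depth-wise projection followed by three strip-convolution branches (1x7 + 7x1, 1x11 + 11x1, 1x21 + 21x1) that approximate large square kernels at three scales. The branch outputs are summed with the projection, mixed across channels by a final 1x1 convolution, and the result gates the input element-wise (attn * residual). The commit also registers the module under the "msca" key in the file's attention lookup dict. Below is a minimal usage sketch, assuming the module is importable from the file touched by this commit (the exact public import path is an assumption, not documented API):

    # Minimal MSCA smoke test; import path assumed from this commit's file.
    import torch

    from cellseg_models_pytorch.modules.attention_modules import MSCA

    att = MSCA(in_channels=64)     # in_channels must match the input tensor
    x = torch.rand(2, 64, 32, 32)  # (batch, channels, height, width)

    out = att(x)                   # multi-scale attention map * residual input
    assert out.shape == x.shape    # all convs are padded, so shape is preserved

Because every convolution except the final 1x1 channel mixer is depth-wise (groups=in_channels), the module stays cheap even with the 21-wide strip kernels.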