     "SCSqueezeAndExcite",
     "ECA",
     "GlobalContext",
+    "MSCA",
 ]


+class MSCA(nn.Module):
+    def __init__(self, in_channels: int, **kwargs) -> None:
+        """Multi-scale convolutional attention (MSCA).
+
+        - SegNeXt: http://arxiv.org/abs/2209.08575
+
+        Parameters
+        ----------
+        in_channels : int
+            The number of input channels.
+        """
+        super().__init__()
+        # 5x5 depth-wise projection
+        self.proj = nn.Conv2d(
+            in_channels, in_channels, 5, padding=2, groups=in_channels
+        )
+
+        # scale 1: a 7x7 depth-wise conv decomposed into 1x7 and 7x1 strip convs
+        self.conv0_1 = nn.Conv2d(
+            in_channels, in_channels, (1, 7), padding=(0, 3), groups=in_channels
+        )
+        self.conv0_2 = nn.Conv2d(
+            in_channels, in_channels, (7, 1), padding=(3, 0), groups=in_channels
+        )
+
+        # scale 2: an 11x11 depth-wise conv decomposed into strip convs
+        self.conv1_1 = nn.Conv2d(
+            in_channels, in_channels, (1, 11), padding=(0, 5), groups=in_channels
+        )
+        self.conv1_2 = nn.Conv2d(
+            in_channels, in_channels, (11, 1), padding=(5, 0), groups=in_channels
+        )
+
+        # scale 3: a 21x21 depth-wise conv decomposed into strip convs
+        self.conv2_1 = nn.Conv2d(
+            in_channels, in_channels, (1, 21), padding=(0, 10), groups=in_channels
+        )
+        self.conv2_2 = nn.Conv2d(
+            in_channels, in_channels, (21, 1), padding=(10, 0), groups=in_channels
+        )
+        # 1x1 conv to mix channels of the aggregated attention map
+        self.conv3 = nn.Conv2d(in_channels, in_channels, 1)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        """Forward pass of the MSCA attention."""
+        residual = x
+        attn = self.proj(x)
+
+        attn_0 = self.conv0_1(attn)
+        attn_0 = self.conv0_2(attn_0)
+
+        attn_1 = self.conv1_1(attn)
+        attn_1 = self.conv1_2(attn_1)
+
+        attn_2 = self.conv2_1(attn)
+        attn_2 = self.conv2_2(attn_2)
+        # sum the projection and the three multi-scale branches
+        attn = attn + attn_0 + attn_1 + attn_2
+
+        attn = self.conv3(attn)
+
+        # weight the input with the aggregated attention map
+        return attn * residual
+
+
 class SqueezeAndExcite(nn.Module):
     def __init__(
         self,
@@ -304,6 +367,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
     "scse": SCSqueezeAndExcite,
     "eca": ECA,
     "gc": GlobalContext,
+    "msca": MSCA,
 }

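For context, a minimal smoke test of the new module. This is an illustrative sketch, not part of the commit; it assumes the module file's existing torch imports, and the input shape is arbitrary:

    # Illustrative only: MSCA is assumed importable from the edited module.
    import torch

    msca = MSCA(in_channels=64)
    x = torch.rand(2, 64, 32, 32)  # (batch, channels, height, width)
    out = msca(x)
    assert out.shape == x.shape  # every conv is padded, so MSCA preserves shape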
|
|
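The second hunk registers the class under the "msca" key, so callers can construct the module by name. A hypothetical sketch, assuming the mapping shown above is exposed as a lookup table (the dict's actual name is not visible in this hunk):

    # `ATTENTION_LOOKUP` is a placeholder name for the dict edited above.
    att = ATTENTION_LOOKUP["msca"](in_channels=64)  # equivalent to MSCA(in_channels=64)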