Skip to content

Commit c601691

Browse files
committed
add stanet
1 parent 7f501b8 commit c601691

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

change_detection_pytorch/stanet/BAM.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,11 @@ class BAM(nn.Module):
1010
def __init__(self, in_dim, ds=8, activation=nn.ReLU):
1111
super(BAM, self).__init__()
1212
self.chanel_in = in_dim
13-
self.key_channel = self.chanel_in //8
13+
self.key_channel = self.chanel_in // 8
1414
self.activation = activation
1515
self.ds = ds #
1616
self.pool = nn.AvgPool2d(self.ds)
17-
print('ds: ',ds)
17+
print('ds: ', ds)
1818
self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
1919
self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
2020
self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
@@ -35,7 +35,7 @@ def forward(self, input):
3535
proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1) # B X C X (N)/(ds*ds)
3636
proj_key = self.key_conv(x).view(m_batchsize, -1, width * height) # B X C x (*W*H)/(ds*ds)
3737
energy = torch.bmm(proj_query, proj_key) # transpose check
38-
energy = (self.key_channel**-.5) * energy
38+
energy = (self.key_channel ** -.5) * energy
3939

4040
attention = self.softmax(energy) # BX (N) X (N)/(ds*ds)/(ds*ds)
4141

@@ -44,9 +44,7 @@ def forward(self, input):
4444
out = torch.bmm(proj_value, attention.permute(0, 2, 1))
4545
out = out.view(m_batchsize, C, width, height)
4646

47-
out = F.interpolate(out, [width*self.ds,height*self.ds])
47+
out = F.interpolate(out, [width * self.ds, height * self.ds])
4848
out = out + input
4949

5050
return out
51-
52-

change_detection_pytorch/stanet/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ class STANet(torch.nn.Module):
2727
https://www.mdpi.com/2072-4292/12/10/1662
2828
2929
"""
30+
3031
def __init__(
3132
self,
3233
encoder_name: str = "resnet",

0 commit comments

Comments
 (0)