@@ -10,11 +10,11 @@ class BAM(nn.Module):
    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.chanel_in = in_dim
-        self.key_channel = self.chanel_in // 8
+        self.key_channel = self.chanel_in // 8
        self.activation = activation
        self.ds = ds  #
        self.pool = nn.AvgPool2d(self.ds)
-        print('ds: ',ds)
+        print('ds: ', ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
@@ -35,7 +35,7 @@ def forward(self, input):
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B X C X (N)/(ds*ds)
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B X C x (W*H)/(ds*ds)
        energy = torch.bmm(proj_query, proj_key)  # transpose check
-        energy = (self.key_channel ** -.5) * energy
+        energy = (self.key_channel ** -.5) * energy

        attention = self.softmax(energy)  # B X (N) X (N)/(ds*ds)/(ds*ds)

@@ -44,9 +44,7 @@ def forward(self, input):
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)

-        out = F.interpolate(out, [width * self.ds,height * self.ds])
+        out = F.interpolate(out, [width * self.ds, height * self.ds])
        out = out + input

        return out
-
-
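
For context, here is the full module these hunks patch, as a minimal self-contained sketch reconstructed from the lines visible in the diff. The pieces that fall outside the hunks (the softmax, the pooling of input into x, and the value-projection reshape) are assumptions inferred from the surrounding code, not confirmed by this diff:

import torch
import torch.nn as nn
import torch.nn.functional as F

class BAM(nn.Module):
    # Self-attention block computed on a ds-times-downsampled feature map;
    # sketch only: softmax placement and the pooling step are assumptions.
    def __init__(self, in_dim, ds=8, activation=nn.ReLU):
        super(BAM, self).__init__()
        self.chanel_in = in_dim
        self.key_channel = self.chanel_in // 8
        self.activation = activation
        self.ds = ds
        self.pool = nn.AvgPool2d(self.ds)
        self.query_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.key_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim // 8, kernel_size=1)
        self.value_conv = nn.Conv2d(in_channels=in_dim, out_channels=in_dim, kernel_size=1)
        self.softmax = nn.Softmax(dim=-1)  # assumed: normalize each attention row

    def forward(self, input):
        x = self.pool(input)  # assumed: attend on the downsampled map to cut the N x N cost
        m_batchsize, C, width, height = x.size()
        proj_query = self.query_conv(x).view(m_batchsize, -1, width * height).permute(0, 2, 1)  # B x N x C'
        proj_key = self.key_conv(x).view(m_batchsize, -1, width * height)  # B x C' x N
        energy = torch.bmm(proj_query, proj_key)  # B x N x N similarity scores
        energy = (self.key_channel ** -.5) * energy  # scale by 1/sqrt(d_k), as in scaled dot-product attention
        attention = self.softmax(energy)
        proj_value = self.value_conv(x).view(m_batchsize, -1, width * height)  # assumed reshape: B x C x N
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(m_batchsize, C, width, height)
        out = F.interpolate(out, [width * self.ds, height * self.ds])  # upsample back to the input resolution
        out = out + input  # residual connection
        return out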
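
A quick shape check under hypothetical sizes (in_dim=64 and the 128x128 input are illustrative; the spatial dimensions must be divisible by ds):

bam = BAM(in_dim=64, ds=8)
x = torch.randn(2, 64, 128, 128)
out = bam(x)
print(out.shape)  # torch.Size([2, 64, 128, 128]), same shape as the input thanks to the upsample + residual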