Skip to content

Commit 550c9d5

Browse files
committed
atten
1 parent 1d398a2 commit 550c9d5

File tree

5 files changed

+11
-13
lines changed

5 files changed

+11
-13
lines changed

config/yolov4_config.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@
44
PROJECT_PATH = osp.abspath(osp.join(osp.dirname(__file__), '..'))
55

66
DATA_PATH = osp.join(PROJECT_PATH, 'data')
7-
# PROJECT_PATH = "E:\YOLOV4/data"
8-
# PROJECT_PATH = "E:\YOLOV4/"
97

108

119
MODEL_TYPE = {

eval/evaluator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,7 @@ def __predict(self, img, test_shape, valid_scale, mode):
122122
with torch.no_grad():
123123
start_time = current_milli_time()
124124
if self.showatt:
125-
_, p_d, beta = self.model(img)
125+
_, p_d, atten = self.model(img)
126126
else:
127127
_, p_d = self.model(img)
128128
self.inference_time += current_milli_time() - start_time
@@ -131,7 +131,7 @@ def __predict(self, img, test_shape, valid_scale, mode):
131131
pred_bbox, test_shape, (org_h, org_w), valid_scale
132132
)
133133
if self.showatt and len(img) and mode == 'det':
134-
self.__show_heatmap(beta, org_img)
134+
self.__show_heatmap(atten, org_img)
135135
return bboxes
136136

137137
def __show_heatmap(self, beta, img):

model/YOLOv4.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -278,14 +278,14 @@ def __init__(self, weight_path=None, out_channels=255, resume=False, showatt=Fal
278278
self.predict_net = PredictNet(feature_channels, out_channels)
279279

280280
def forward(self, x):
281-
beta = None
281+
atten = None
282282
features = self.backbone(x)
283283
if self.showatt:
284-
features[-1], beta = self.attention(features[-1])
284+
features[-1], atten = self.attention(features[-1])
285285
features[-1] = self.spp(features[-1])
286286
features = self.panet(features)
287287
predicts = self.predict_net(features)
288-
return predicts, beta
288+
return predicts, atten
289289

290290

291291
if __name__ == "__main__":

model/build_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self, weight_path=None, resume=False, showatt=False):
4848

4949
def forward(self, x):
5050
out = []
51-
[x_s, x_m, x_l], beta = self.__yolov4(x)
51+
[x_s, x_m, x_l], atten = self.__yolov4(x)
5252

5353
out.append(self.__head_s(x_s))
5454
out.append(self.__head_m(x_m))
@@ -60,7 +60,7 @@ def forward(self, x):
6060
else:
6161
p, p_d = list(zip(*out))
6262
if self.__showatt:
63-
return p, torch.cat(p_d, 0), beta
63+
return p, torch.cat(p_d, 0), atten
6464
return p, torch.cat(p_d, 0)
6565

6666

model/layers/global_context_block.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def spatial_pool(self, x):
4848
context_mask = self.softmax(context_mask)
4949
beta1 = context_mask
5050
beta2 = torch.transpose(beta1, 1, 2)
51-
beta = torch.matmul(beta2, beta1)
51+
atten = torch.matmul(beta2, beta1)
5252

5353
# [N, 1, H * W, 1]
5454
context_mask = context_mask.unsqueeze(3)
@@ -57,13 +57,13 @@ def spatial_pool(self, x):
5757
# [N, C, 1, 1]
5858
context = context.view(batch, channel, 1, 1)
5959

60-
return context, beta
60+
return context, atten
6161

6262
def forward(self, x):
6363
# [N, C, 1, 1]
64-
context, beta = self.spatial_pool(x)
64+
context, atten = self.spatial_pool(x)
6565
# [N, C, 1, 1]
6666
channel_add_term = self.channel_add_conv(context)
6767
out = x + channel_add_term
6868

69-
return out, beta
69+
return out, atten

0 commit comments

Comments
 (0)