@@ -129,19 +129,19 @@ def __init__(self, image_size: int, num_classes: int):
129
129
final_out_channel = 3 * (4 + 1 + num_classes )
130
130
131
131
self .darknet53 = self .make_darknet53 ()
132
- self .conv_block1 = self .make_conv_block (1024 , 512 )
133
- self .conv_final1 = self .make_conv_final (512 , final_out_channel )
134
- self .yolo_layer1 = YOLODetection (anchors ['scale1' ], image_size , num_classes )
132
+ self .conv_block3 = self .make_conv_block (1024 , 512 )
133
+ self .conv_final3 = self .make_conv_final (512 , final_out_channel )
134
+ self .yolo_layer3 = YOLODetection (anchors ['scale1' ], image_size , num_classes )
135
135
136
- self .upsample1 = self .make_upsample (512 , 256 , scale_factor = 2 )
136
+ self .upsample2 = self .make_upsample (512 , 256 , scale_factor = 2 )
137
137
self .conv_block2 = self .make_conv_block (768 , 256 )
138
138
self .conv_final2 = self .make_conv_final (256 , final_out_channel )
139
139
self .yolo_layer2 = YOLODetection (anchors ['scale2' ], image_size , num_classes )
140
140
141
- self .upsample2 = self .make_upsample (256 , 128 , scale_factor = 2 )
142
- self .conv_block3 = self .make_conv_block (384 , 128 )
143
- self .conv_final3 = self .make_conv_final (128 , final_out_channel )
144
- self .yolo_layer3 = YOLODetection (anchors ['scale3' ], image_size , num_classes )
141
+ self .upsample1 = self .make_upsample (256 , 128 , scale_factor = 2 )
142
+ self .conv_block1 = self .make_conv_block (384 , 128 )
143
+ self .conv_final1 = self .make_conv_final (128 , final_out_channel )
144
+ self .yolo_layer1 = YOLODetection (anchors ['scale3' ], image_size , num_classes )
145
145
146
146
self .yolo_layers = [self .yolo_layer1 , self .yolo_layer2 , self .yolo_layer3 ]
147
147
@@ -163,23 +163,23 @@ def forward(self, x, targets=None):
163
163
residual_output [key ] = x
164
164
165
165
# Yolov3 layer forward
166
- conv_b1 = self .conv_block1 (residual_output ['residual_5_4' ])
167
- scale1 = self .conv_final1 ( conv_b1 )
168
- yolo_output1 , layer_loss = self .yolo_layer1 ( scale1 , targets )
166
+ conv_block3 = self .conv_block3 (residual_output ['residual_5_4' ])
167
+ scale3 = self .conv_final3 ( conv_block3 )
168
+ yolo_output3 , layer_loss = self .yolo_layer3 ( scale3 , targets )
169
169
loss += layer_loss
170
170
171
- scale2 = self .upsample1 ( conv_b1 )
171
+ scale2 = self .upsample2 ( conv_block3 )
172
172
scale2 = torch .cat ((scale2 , residual_output ['residual_4_8' ]), dim = 1 )
173
- conv_b2 = self .conv_block2 (scale2 )
174
- scale2 = self .conv_final2 (conv_b2 )
173
+ conv_block2 = self .conv_block2 (scale2 )
174
+ scale2 = self .conv_final2 (conv_block2 )
175
175
yolo_output2 , layer_loss = self .yolo_layer2 (scale2 , targets )
176
176
loss += layer_loss
177
177
178
- scale3 = self .upsample2 ( conv_b2 )
179
- scale3 = torch .cat ((scale3 , residual_output ['residual_3_8' ]), dim = 1 )
180
- conv_b3 = self .conv_block3 ( scale3 )
181
- scale3 = self .conv_final3 ( conv_b3 )
182
- yolo_output3 , layer_loss = self .yolo_layer3 ( scale3 , targets )
178
+ scale1 = self .upsample1 ( conv_block2 )
179
+ scale1 = torch .cat ((scale1 , residual_output ['residual_3_8' ]), dim = 1 )
180
+ conv_block1 = self .conv_block1 ( scale1 )
181
+ scale1 = self .conv_final1 ( conv_block1 )
182
+ yolo_output1 , layer_loss = self .yolo_layer1 ( scale1 , targets )
183
183
loss += layer_loss
184
184
185
185
yolo_outputs = [yolo_output1 , yolo_output2 , yolo_output3 ]
@@ -288,17 +288,17 @@ def load_darknet_weights(self, weights_path: str):
288
288
289
289
# Load YOLOv3 weights
290
290
if weights_path .find ('yolov3.weights' ) != - 1 :
291
- for module in self .conv_block1 :
291
+ for module in self .conv_block3 :
292
292
ptr = self .load_bn_weights (module [1 ], weights , ptr )
293
293
ptr = self .load_conv_weights (module [0 ], weights , ptr )
294
294
295
- ptr = self .load_bn_weights (self .conv_final1 [0 ][1 ], weights , ptr )
296
- ptr = self .load_conv_weights (self .conv_final1 [0 ][0 ], weights , ptr )
297
- ptr = self .load_conv_bias (self .conv_final1 [1 ], weights , ptr )
298
- ptr = self .load_conv_weights (self .conv_final1 [1 ], weights , ptr )
295
+ ptr = self .load_bn_weights (self .conv_final3 [0 ][1 ], weights , ptr )
296
+ ptr = self .load_conv_weights (self .conv_final3 [0 ][0 ], weights , ptr )
297
+ ptr = self .load_conv_bias (self .conv_final3 [1 ], weights , ptr )
298
+ ptr = self .load_conv_weights (self .conv_final3 [1 ], weights , ptr )
299
299
300
- ptr = self .load_bn_weights (self .upsample1 [0 ][1 ], weights , ptr )
301
- ptr = self .load_conv_weights (self .upsample1 [0 ][0 ], weights , ptr )
300
+ ptr = self .load_bn_weights (self .upsample2 [0 ][1 ], weights , ptr )
301
+ ptr = self .load_conv_weights (self .upsample2 [0 ][0 ], weights , ptr )
302
302
303
303
for module in self .conv_block2 :
304
304
ptr = self .load_bn_weights (module [1 ], weights , ptr )
@@ -309,17 +309,17 @@ def load_darknet_weights(self, weights_path: str):
309
309
ptr = self .load_conv_bias (self .conv_final2 [1 ], weights , ptr )
310
310
ptr = self .load_conv_weights (self .conv_final2 [1 ], weights , ptr )
311
311
312
- ptr = self .load_bn_weights (self .upsample2 [0 ][1 ], weights , ptr )
313
- ptr = self .load_conv_weights (self .upsample2 [0 ][0 ], weights , ptr )
312
+ ptr = self .load_bn_weights (self .upsample1 [0 ][1 ], weights , ptr )
313
+ ptr = self .load_conv_weights (self .upsample1 [0 ][0 ], weights , ptr )
314
314
315
- for module in self .conv_block3 :
315
+ for module in self .conv_block1 :
316
316
ptr = self .load_bn_weights (module [1 ], weights , ptr )
317
317
ptr = self .load_conv_weights (module [0 ], weights , ptr )
318
318
319
- ptr = self .load_bn_weights (self .conv_final3 [0 ][1 ], weights , ptr )
320
- ptr = self .load_conv_weights (self .conv_final3 [0 ][0 ], weights , ptr )
321
- ptr = self .load_conv_bias (self .conv_final3 [1 ], weights , ptr )
322
- ptr = self .load_conv_weights (self .conv_final3 [1 ], weights , ptr )
319
+ ptr = self .load_bn_weights (self .conv_final1 [0 ][1 ], weights , ptr )
320
+ ptr = self .load_conv_weights (self .conv_final1 [0 ][0 ], weights , ptr )
321
+ ptr = self .load_conv_bias (self .conv_final1 [1 ], weights , ptr )
322
+ ptr = self .load_conv_weights (self .conv_final1 [1 ], weights , ptr )
323
323
324
324
# Load BN bias, weights, running mean and running variance
325
325
def load_bn_weights (self , bn_layer , weights , ptr : int ):
0 commit comments