Commit af55aa0: multiprocessing test
argusswift authored and committed
1 parent c508af3

File tree: 4 files changed, +52 -49 lines

4 files changed

+52
-49
lines changed

config/yolov4_config.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 
 
 MODEL_TYPE = {
-    "TYPE": "Mobilenetv3-YOLOv4"
+    "TYPE": "Mobilenet-YOLOv4"
 }  # YOLO type: YOLOv4, Mobilenet-YOLOv4 or Mobilenetv3-YOLOv4
 
 CONV_TYPE = {"TYPE": "DO_CONV"}  # conv type: DO_CONV or GENERAL
@@ -16,7 +16,7 @@
 
 # train
 TRAIN = {
-    "DATA_TYPE": "Customer",  # DATA_TYPE: VOC, COCO or Customer
+    "DATA_TYPE": "VOC",  # DATA_TYPE: VOC, COCO or Customer
     "TRAIN_IMG_SIZE": 416,
     "AUGMENT": True,
     "BATCH_SIZE": 1,

eval/evaluator.py

Lines changed: 46 additions & 41 deletions
@@ -7,7 +7,9 @@
 from utils.heatmap import imshowAtt
 import config.yolov4_config as cfg
 import time
-
+import multiprocessing
+from multiprocessing.dummy import Pool as ThreadPool  # thread pool
+from collections import defaultdict
 current_milli_time = lambda: int(round(time.time() * 1000))
 
 
@@ -28,11 +30,14 @@ def __init__(self, model=None, showatt=False):
         self.val_shape = cfg.VAL["TEST_IMG_SIZE"]
         self.model = model
         self.device = next(model.parameters()).device
-        self.__visual_imgs = 0
+        self.visual_imgs = 0
+        self.multi_scale_test = cfg.VAL["MULTI_SCALE_VAL"]
+        self.flip_test = cfg.VAL["FLIP_VAL"]
         self.showatt = showatt
         self.inference_time = 0.0
+        self.final_result = defaultdict(list)
 
-    def APs_voc(self, multi_test=False, flip_test=False):
+    def APs_voc(self):
         img_inds_file = os.path.join(
             self.val_data_path, "ImageSets", "Main", "test.txt"
         )
@@ -47,47 +52,47 @@ def APs_voc(self, multi_test=False, flip_test=False):
         if not os.path.exists(output_path):
             os.mkdir(output_path)
         os.mkdir(self.pred_result_path)
-        for img_ind in tqdm(img_inds):
-            img_path = os.path.join(
-                self.val_data_path, "JPEGImages", img_ind + ".jpg"
-            )
-            img = cv2.imread(img_path)
-            bboxes_prd = self.get_bbox(img, multi_test, flip_test)
-
-            f = open("./output/" + img_ind + ".txt", "w")
-            for bbox in bboxes_prd:
-                coor = np.array(bbox[:4], dtype=np.int32)
-                score = bbox[4]
-                class_ind = int(bbox[5])
-
-                class_name = self.classes[class_ind]
-                score = "%.4f" % score
-                xmin, ymin, xmax, ymax = map(str, coor)
-                s = " ".join([img_ind, score, xmin, ymin, xmax, ymax]) + "\n"
-
-                with open(
-                    os.path.join(
-                        self.pred_result_path,
-                        "comp4_det_test_" + class_name + ".txt",
-                    ),
-                    "a",
-                ) as r:
-                    r.write(s)
-                f.write(
-                    "%s %s %s %s %s %s\n"
-                    % (
-                        class_name,
-                        score,
-                        str(xmin),
-                        str(ymin),
-                        str(xmax),
-                        str(ymax),
-                    )
-                )
-            f.close()
+        imgs_count = len(img_inds)
+        cpu_nums = multiprocessing.cpu_count()
+        pool = ThreadPool(cpu_nums)
+        with tqdm(total=imgs_count) as pbar:
+            for i, _ in enumerate(pool.imap_unordered(self.Single_APs_voc, img_inds)):
+                pbar.update()
+        for class_name in self.final_result:
+            with open(os.path.join(self.pred_result_path, 'comp4_det_test_' + class_name + '.txt'), 'a') as f:
+                str_result = ''.join(self.final_result[class_name])
+                f.write(str_result)
         self.inference_time = 1.0 * self.inference_time / len(img_inds)
         return self.__calc_APs(), self.inference_time
 
+    def Single_APs_voc(self, img_ind):
+        img_path = os.path.join(self.val_data_path, 'JPEGImages', img_ind + '.jpg')
+        img = cv2.imread(img_path)
+        bboxes_prd = self.get_bbox(img, self.multi_scale_test, self.flip_test)
+
+        if bboxes_prd.shape[0] != 0 and self.visual_imgs < 100:
+            boxes = bboxes_prd[..., :4]
+            class_inds = bboxes_prd[..., 5].astype(np.int32)
+            scores = bboxes_prd[..., 4]
+
+            visualize_boxes(image=img, boxes=boxes, labels=class_inds, probs=scores, class_labels=self.classes)
+            path = os.path.join(cfg.PROJECT_PATH, "data/results/{}.jpg".format(self.visual_imgs))
+            cv2.imwrite(path, img)
+
+            self.visual_imgs += 1
+
+        for bbox in bboxes_prd:
+            coor = np.array(bbox[:4], dtype=np.int32)
+            score = bbox[4]
+            class_ind = int(bbox[5])
+
+            class_name = self.classes[class_ind]
+            score = '%.4f' % score
+            xmin, ymin, xmax, ymax = map(str, coor)
+            result = ' '.join([img_ind, score, xmin, ymin, xmax, ymax]) + '\n'
+
+            self.final_result[class_name].append(result)
+
     def get_bbox(self, img, multi_test=False, flip_test=False, mode=None):
         if multi_test:
             test_input_sizes = range(320, 640, 96)
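
The heart of this commit is replacing the serial per-image loop with a pool of worker threads sized to the CPU count, with tqdm counting completions as imap_unordered yields them and a defaultdict accumulating per-class result lines. A minimal standalone sketch of the same pattern; fake_detect and the sample image IDs are illustrative stand-ins, not repo code:

# Sketch of the ThreadPool + imap_unordered + tqdm pattern used in APs_voc.
import multiprocessing
from collections import defaultdict
from multiprocessing.dummy import Pool as ThreadPool  # threads, not processes
from tqdm import tqdm

final_result = defaultdict(list)

def fake_detect(img_ind):
    # Stand-in for Single_APs_voc: pretend each image yields one "dog" box.
    final_result["dog"].append("%s 0.9000 10 10 50 50\n" % img_ind)

img_inds = ["%06d" % i for i in range(100)]
pool = ThreadPool(multiprocessing.cpu_count())
with tqdm(total=len(img_inds)) as pbar:
    for _ in pool.imap_unordered(fake_detect, img_inds):
        pbar.update()
pool.close()
pool.join()

for class_name, results in final_result.items():
    print(class_name, len(results))  # -> dog 100

Because multiprocessing.dummy.Pool is thread-based, every worker shares one interpreter, which is why Single_APs_voc can append to the shared self.final_result without explicit locking (CPython's GIL makes list.append effectively atomic). With a process-based multiprocessing.Pool, each worker would mutate its own copy of the dict and the results would be lost.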

eval_voc.py

Lines changed: 1 addition & 3 deletions
@@ -25,8 +25,6 @@ def __init__(
         self.__conf_threshold = cfg.VAL["CONF_THRESH"]
         self.__nms_threshold = cfg.VAL["NMS_THRESH"]
         self.__device = gpu.select_device(gpu_id)
-        self.__multi_scale_val = cfg.VAL["MULTI_SCALE_VAL"]
-        self.__flip_val = cfg.VAL["FLIP_VAL"]
         self.__showatt = showatt
         self.__visiual = visiual
         self.__eval = eval
@@ -57,7 +55,7 @@ def val(self):
         with torch.no_grad():
             APs, inference_time = Evaluator(
                 self.__model, showatt=False
-            ).APs_voc(self.__multi_scale_val, self.__flip_val)
+            ).APs_voc()
         for i in APs:
             logger.info("{} --> mAP : {}".format(i, APs[i]))
             mAP += APs[i]

train.py

Lines changed: 3 additions & 3 deletions
@@ -318,7 +318,7 @@ def train(self):
     parser.add_argument(
         "--weight_path",
         type=str,
-        default="weight/mobilenetv3.pth",
+        default="weight/mobilenetv2.pth",
         help="weight file path",
     )  # weight/darknet53_448.weights
     parser.add_argument(
@@ -330,7 +330,7 @@ def train(self):
     parser.add_argument(
         "--gpu_id",
         type=int,
-        default=-1,
+        default=0,
         help="whither use GPU(0) or CPU(-1)",
     )
     parser.add_argument("--log_path", type=str, default="log/", help="log path")
@@ -349,7 +349,7 @@ def train(self):
     parser.add_argument(
         "--showatt",
         type=bool,
-        default=True,
+        default=False,
         help="whether to show attention map"
     )
     opt = parser.parse_args()
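
A side note on the --showatt flag above: argparse's type=bool calls bool() on the raw argument string, so any non-empty value, even the literal string "False", parses as True; flipping the default is the only reliable way to turn the flag off. A standalone sketch (assumed illustration, not repo code) of the post-commit defaults and this pitfall:

# Standalone sketch of the defaults after this commit, plus the type=bool pitfall.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--weight_path", type=str, default="weight/mobilenetv2.pth")
parser.add_argument("--gpu_id", type=int, default=0, help="GPU id, or -1 for CPU")
parser.add_argument("--showatt", type=bool, default=False)

print(parser.parse_args([]).showatt)                       # False (the new default)
print(parser.parse_args(["--showatt", "False"]).showatt)   # True! bool("False") is truthy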
