7
7
from utils .heatmap import imshowAtt
8
8
import config .yolov4_config as cfg
9
9
import time
10
-
10
+ import multiprocessing
11
+ from multiprocessing .dummy import Pool as ThreadPool # 线程池
12
+ from collections import defaultdict
11
13
current_milli_time = lambda : int (round (time .time () * 1000 ))
12
14
13
15
@@ -28,11 +30,14 @@ def __init__(self, model=None, showatt=False):
28
30
self .val_shape = cfg .VAL ["TEST_IMG_SIZE" ]
29
31
self .model = model
30
32
self .device = next (model .parameters ()).device
31
- self .__visual_imgs = 0
33
+ self .visual_imgs = 0
34
+ self .multi_scale_test = cfg .VAL ["MULTI_SCALE_VAL" ]
35
+ self .flip_test = cfg .VAL ["FLIP_VAL" ]
32
36
self .showatt = showatt
33
37
self .inference_time = 0.0
38
+ self .final_result = defaultdict (list )
34
39
35
- def APs_voc (self , multi_test = False , flip_test = False ):
40
+ def APs_voc (self ):
36
41
img_inds_file = os .path .join (
37
42
self .val_data_path , "ImageSets" , "Main" , "test.txt"
38
43
)
@@ -47,47 +52,47 @@ def APs_voc(self, multi_test=False, flip_test=False):
47
52
if not os .path .exists (output_path ):
48
53
os .mkdir (output_path )
49
54
os .mkdir (self .pred_result_path )
50
- for img_ind in tqdm (img_inds ):
51
- img_path = os .path .join (
52
- self .val_data_path , "JPEGImages" , img_ind + ".jpg"
53
- )
54
- img = cv2 .imread (img_path )
55
- bboxes_prd = self .get_bbox (img , multi_test , flip_test )
56
-
57
- f = open ("./output/" + img_ind + ".txt" , "w" )
58
- for bbox in bboxes_prd :
59
- coor = np .array (bbox [:4 ], dtype = np .int32 )
60
- score = bbox [4 ]
61
- class_ind = int (bbox [5 ])
62
-
63
- class_name = self .classes [class_ind ]
64
- score = "%.4f" % score
65
- xmin , ymin , xmax , ymax = map (str , coor )
66
- s = " " .join ([img_ind , score , xmin , ymin , xmax , ymax ]) + "\n "
67
-
68
- with open (
69
- os .path .join (
70
- self .pred_result_path ,
71
- "comp4_det_test_" + class_name + ".txt" ,
72
- ),
73
- "a" ,
74
- ) as r :
75
- r .write (s )
76
- f .write (
77
- "%s %s %s %s %s %s\n "
78
- % (
79
- class_name ,
80
- score ,
81
- str (xmin ),
82
- str (ymin ),
83
- str (xmax ),
84
- str (ymax ),
85
- )
86
- )
87
- f .close ()
55
+ imgs_count = len (img_inds )
56
+ cpu_nums = multiprocessing .cpu_count ()
57
+ pool = ThreadPool (cpu_nums )
58
+ with tqdm (total = imgs_count ) as pbar :
59
+ for i , _ in enumerate (pool .imap_unordered (self .Single_APs_voc , img_inds )):
60
+ pbar .update ()
61
+ for class_name in self .final_result :
62
+ with open (os .path .join (self .pred_result_path , 'comp4_det_test_' + class_name + '.txt' ), 'a' ) as f :
63
+ str_result = '' .join (self .final_result [class_name ])
64
+ f .write (str_result )
88
65
self .inference_time = 1.0 * self .inference_time / len (img_inds )
89
66
return self .__calc_APs (), self .inference_time
90
67
68
+ def Single_APs_voc (self , img_ind ):
69
+ img_path = os .path .join (self .val_data_path , 'JPEGImages' , img_ind + '.jpg' )
70
+ img = cv2 .imread (img_path )
71
+ bboxes_prd = self .get_bbox (img , self .multi_scale_test , self .flip_test )
72
+
73
+ if bboxes_prd .shape [0 ] != 0 and self .visual_imgs < 100 :
74
+ boxes = bboxes_prd [..., :4 ]
75
+ class_inds = bboxes_prd [..., 5 ].astype (np .int32 )
76
+ scores = bboxes_prd [..., 4 ]
77
+
78
+ visualize_boxes (image = img , boxes = boxes , labels = class_inds , probs = scores , class_labels = self .classes )
79
+ path = os .path .join (cfg .PROJECT_PATH , "data/results/{}.jpg" .format (self .visual_imgs ))
80
+ cv2 .imwrite (path , img )
81
+
82
+ self .visual_imgs += 1
83
+
84
+ for bbox in bboxes_prd :
85
+ coor = np .array (bbox [:4 ], dtype = np .int32 )
86
+ score = bbox [4 ]
87
+ class_ind = int (bbox [5 ])
88
+
89
+ class_name = self .classes [class_ind ]
90
+ score = '%.4f' % score
91
+ xmin , ymin , xmax , ymax = map (str , coor )
92
+ result = ' ' .join ([img_ind , score , xmin , ymin , xmax , ymax ]) + '\n '
93
+
94
+ self .final_result [class_name ].append (result )
95
+
91
96
def get_bbox (self , img , multi_test = False , flip_test = False , mode = None ):
92
97
if multi_test :
93
98
test_input_sizes = range (320 , 640 , 96 )
0 commit comments