Skip to content

Commit 10c6639

Browse files
committed
bug fix no viz
1 parent 8b7d990 commit 10c6639

7 files changed

+94
-60
lines changed

config.sample.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ WIDTH: 600 # OpenCV only supports 4:3 formats others will be
1818
HEIGHT: 600 # 600x600 leads to 640x480
1919
MAX_FRAMES: 5000 # only used if visualize==False
2020
FPS_INTERVAL: 5 # Interval [s] to print fps of the last interval in console
21-
DET_INTERVAL: 500 # intervall [frames] to print detections to console
22-
DET_TH: 0.5 # detection threshold for det_intervall
21+
PRINT_INTERVAL: 500 # intervall [frames] to print detections to console
22+
PRINT_TH: 0.5 # detection threshold for det_intervall
2323
## speed hack
2424
SPLIT_MODEL: True # Splits Model into a GPU and CPU session (currently only works for ssd_mobilenets)
2525
SSD_SHAPE: 300 # used for the split model algorithm (currently only supports ssd networks trained on 300x300 and 600x600 input)

rod/config.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class Config(object):
4141
HEIGHT = cfg['HEIGHT'] # 600x600 leads to 640x480
4242
MAX_FRAMES = cfg['MAX_FRAMES'] # only used if visualize==False
4343
FPS_INTERVAL = cfg['FPS_INTERVAL'] # Interval [s] to print fps of the last interval in console
44-
DET_INTERVAL = cfg['DET_INTERVAL'] # intervall [frames] to print detections to console
45-
DET_TH = cfg['DET_TH'] # detection threshold for det_intervall
44+
PRINT_INTERVAL = cfg['PRINT_INTERVAL'] # intervall [frames] to print detections to console
45+
PRINT_TH = cfg['PRINT_TH'] # detection threshold for det_intervall
4646
## speed hack
4747
SPLIT_MODEL = cfg['SPLIT_MODEL'] # Splits Model into a GPU and CPU session (currently only works for ssd_mobilenets)
4848
SSD_SHAPE = cfg['SSD_SHAPE'] # used for the split model algorithm (currently only supports ssd networks trained on 300x300 and 600x600 input)

rod/vis_utils.py

Lines changed: 71 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -315,12 +315,56 @@ def draw_text_on_image(image, string, position, color = (77, 255, 9)):
315315
cv2.FONT_HERSHEY_SIMPLEX, 0.75, color, 2)
316316

317317

318-
def visualize_objectdetection(image, boxes, classes, scores, masks, category_index,
319-
fps=None, visualize=False, det_interval=5, det_th=0.5,
320-
max_frames=500, cur_frame = None, model='realtime_object_detection'):
318+
def exit_visualization(millis):
321319
"""
322-
complicated visualization function for object detections
323-
TODO: CLEAN UP
320+
returns false if openCV exit is requested
321+
"""
322+
if cv2.waitKey(millis) & 0xFF == ord('q'):
323+
return False
324+
return True
325+
326+
def exit_print(cur_frame,max_frames):
327+
"""
328+
returns false if max frames are reached
329+
"""
330+
if cur_frame >= max_frames:
331+
return False
332+
return True
333+
334+
def print_detection(boxes,scores,classes,category_index,cur_frame,max_frames=500,print_interval=100,print_th=0.5):
335+
"""
336+
prints detection result above threshold to console
337+
"""
338+
for box, score, _class in zip(boxes, scores, classes):
339+
if cur_frame%print_interval==0 and score > print_th:
340+
label = category_index[_class]['name']
341+
print("label: {}\nscore: {}\nbox: {}".format(label, score, box))
342+
343+
def draw_single_box_on_image(image,box,label):
344+
"""
345+
draws single box and label on image
346+
"""
347+
p1 = (box[1], box[0])
348+
p2 = (box[3], box[2])
349+
cv2.rectangle(image, p1, p2, (77,255,9), 2)
350+
draw_text_on_image(image,label,(p1[0],p1[1]-10))
351+
352+
353+
def visualize_objectdetection(image,
354+
boxes,
355+
classes,
356+
scores,
357+
masks,
358+
category_index,
359+
cur_frame,
360+
max_frames=500,
361+
fps='N/A',
362+
print_interval=100,
363+
print_th=0.5,
364+
model_name='realtime_object_detection',
365+
visualize=True):
366+
"""
367+
visualization function for object_detection
324368
"""
325369
if visualize:
326370
visualize_boxes_and_labels_on_image(
@@ -332,44 +376,31 @@ def visualize_objectdetection(image, boxes, classes, scores, masks, category_ind
332376
instance_masks=masks,
333377
use_normalized_coordinates=True,
334378
line_thickness=2)
335-
if fps:
336-
draw_text_on_image(image,"fps: {}".format(fps), (10,30))
337-
cv2.imshow(model, image)
338-
elif not visualize and cur_frame:
339-
# Exit after max frames if no visualization
340-
for box, score, _class in zip(boxes, scores, classes):
341-
if cur_frame%det_interval==0 and score > det_th:
342-
label = category_index[_class]['name']
343-
print("label: {}\nscore: {}\nbox: {}".format(label, score, box))
344-
elif fps == "console":
345-
for box, score, _class in zip(boxes, scores, classes):
346-
if score > det_th:
347-
label = category_index[_class]['name']
348-
print("label: {}\nscore: {}\nbox: {}".format(label, score, box))
349-
# Exit Option
350-
if visualize:
351-
if cv2.waitKey(1) & 0xFF == ord('q'):
352-
return False
353-
elif not visualize and fps:
354-
if cur_frame >= max_frames:
355-
return False
356-
return True
357-
358-
def draw_single_box_on_image(image,box,label):
359-
p1 = (box[1], box[0])
360-
p2 = (box[3], box[2])
361-
cv2.rectangle(image, p1, p2, (77,255,9), 2)
362-
draw_text_on_image(image,label,(p1[0],p1[1]-10))
363-
379+
draw_text_on_image(image,"fps: {}".format(fps), (10,30))
380+
cv2.imshow(model_name, image)
381+
vis = exit_visualization(1)
382+
else:
383+
print_detection(boxes,scores,classes,category_index,cur_frame,max_frames,print_interval,print_th)
384+
vis = exit_print(cur_frame,max_frames)
385+
return vis
364386

365-
def visualize_deeplab(image,seg_map,model_name,fps,visualize=True):
387+
def visualize_deeplab(image,
388+
seg_map,
389+
cur_frame,
390+
max_frames=500,
391+
fps='N/A',
392+
print_interval=100,
393+
print_th=0.5,
394+
model_name='DeepLab',
395+
visualize=True):
396+
"""
397+
visualization function for deeplab
398+
"""
366399
if visualize:
367400
draw_mask_on_image(image, seg_map)
368401
draw_text_on_image(image,"fps: {}".format(fps),(10,30))
369402
cv2.imshow(model_name,image)
370-
if cv2.waitKey(1) & 0xFF == ord('q'):
371-
return False
372-
else:
373-
return True
403+
vis = exit_visualization(1)
374404
else:
375-
return True
405+
vis = exit_print(cur_frame,max_frames)
406+
return vis

run_deeplab.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,8 @@ def segmentation(model,config):
4646
if config.VISUALIZE:
4747
draw_single_box_on_image(frame,box,label)
4848

49-
vis = visualize_deeplab(frame,seg_map,config.OD_MODEL_NAME+config._DEV+config._OPT,
50-
fps.fps_local(),config.VISUALIZE)
49+
vis = visualize_deeplab(frame,seg_map,fps._glob_numFrames,config.MAX_FRAMES,fps.fps_local(),
50+
config.PRINT_INTERVAL,config.PRINT_TH,config.OD_MODEL_NAME+config._DEV+config._OPT,config.VISUALIZE)
5151
if not vis:
5252
break
5353
fps.update()

run_objectdetection.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,9 @@ def detection(model,config):
120120
scores = np.squeeze(scores)
121121

122122
# Visualization
123-
vis = visualize_objectdetection(frame, boxes, classes, scores, masks, category_index, fps.fps_local(),
124-
config.VISUALIZE, config.DET_INTERVAL, config.DET_TH, config.MAX_FRAMES,
125-
fps._glob_numFrames, config.OD_MODEL_NAME)
123+
vis = visualize_objectdetection(frame,boxes,classes,scores,masks,category_index,fps._glob_numFrames,
124+
config.MAX_FRAMES,fps.fps_local(),config.PRINT_INTERVAL,config.PRINT_TH,
125+
config.OD_MODEL_NAME+config._DEV+config._OPT,config.VISUALIZE)
126126
if not vis:
127127
break
128128

@@ -146,9 +146,9 @@ def detection(model,config):
146146
for idx,tracker in enumerate(trackers):
147147
tracker_box = tracker.update(frame)
148148
tracker_boxes[idx,:] = conv_track2detect(tracker_box, vs.real_width, vs.real_height)
149-
vis = visualize_objectdetection(frame, tracker_boxes, classes, scores, masks, category_index, fps.fps_local(),
150-
config.VISUALIZE, config.DET_INTERVAL, config.DET_TH, config.MAX_FRAMES,
151-
fps._glob_numFrames, config.OD_MODEL_NAME+config._DEV+config._OPT)
149+
vis = visualize_objectdetection(frame,tracker_boxes,classes,scores,masks,category_index,fps._glob_numFrames,
150+
config.MAX_FRAMES,fps.fps_local(),config.PRINT_INTERVAL,config.PRINT_TH,
151+
config.OD_MODEL_NAME+config._DEV+config._OPT,config.VISUALIZE)
152152
if not vis:
153153
break
154154

@@ -168,8 +168,11 @@ def detection(model,config):
168168
cpu_worker.stop()
169169

170170

171-
if __name__ == '__main__':
171+
def main():
172172
config = Config()
173173
model = Model('od',config.OD_MODEL_NAME,config.OD_MODEL_PATH,config.LABEL_PATH,
174174
config.NUM_CLASSES,config.SPLIT_MODEL, config.SSD_SHAPE).prepare_od_model()
175175
detection(model, config)
176+
177+
if __name__ == '__main__':
178+
main()

test_deeplab.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,8 @@ def segmentation(model,config):
6363
if config.VISUALIZE:
6464
draw_single_box_on_image(frame,box,label)
6565

66-
vis = visualize_deeplab(frame,seg_map,config.OD_MODEL_NAME+config._DEV+config._OPT,
67-
timer.get_fps(),config.VISUALIZE)
66+
vis = visualize_deeplab(frame,seg_map,timer.get_frame(),config.MAX_FRAMES,timer.get_fps(),
67+
config.PRINT_INTERVAL,config.PRINT_TH,config.OD_MODEL_NAME+config._DEV+config._OPT,config.VISUALIZE)
6868
if not vis:
6969
break
7070
cv2.destroyAllWindows()

test_objectdetection.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from rod.helper import Timer, WebcamVideoStream, SessionWorker, TimeLiner, load_images
1515
from rod.model import Model
1616
from rod.config import Config
17-
from rod.utils import ops as utils_ops
1817
from rod.vis_utils import visualize_objectdetection
18+
from rod.tf_utils import reframe_box_masks_to_image_masks
1919

2020

2121
def detection(model,config):
@@ -38,7 +38,7 @@ def detection(model,config):
3838
real_num_detection = tf.cast(tensor_dict['num_detections'][0], tf.int32)
3939
detection_boxes = tf.slice(detection_boxes, [0, 0], [real_num_detection, -1])
4040
detection_masks = tf.slice(detection_masks, [0, 0, 0], [real_num_detection, -1, -1])
41-
detection_masks_reframed = utils_ops.reframe_box_masks_to_image_masks(
41+
detection_masks_reframed = reframe_box_masks_to_image_masks(
4242
detection_masks, detection_boxes, config.HEIGHT, config.WIDTH)
4343
detection_masks_reframed = tf.cast(tf.greater(detection_masks_reframed, 0.5), tf.uint8)
4444
# Follow the convention by adding back the batch dimension
@@ -120,9 +120,9 @@ def detection(model,config):
120120
scores = np.squeeze(scores)
121121

122122
# Visualization
123-
vis = visualize_objectdetection(frame, boxes, classes, scores, masks, category_index, timer.get_fps(),
124-
config.VISUALIZE, config.DET_INTERVAL, config.DET_TH, config.MAX_FRAMES, None,
125-
config.OD_MODEL_NAME+config._DEV+config._OPT)
123+
vis = visualize_objectdetection(frame,boxes,classes,scores,masks,category_index,timer.get_frame(),
124+
config.MAX_FRAMES,timer.get_fps(),config.PRINT_INTERVAL,config.PRINT_TH,
125+
config.OD_MODEL_NAME+config._DEV+config._OPT,config.VISUALIZE)
126126
if not vis:
127127
break
128128

0 commit comments

Comments
 (0)