Skip to content
This repository was archived by the owner on Nov 21, 2023. It is now read-only.

Commit 678ffa6

Browse files
authored
First working fusion
- WIP : the models run but there are differences in the results
1 parent 2a21855 commit 678ffa6

File tree

1 file changed

+66
-97
lines changed

1 file changed

+66
-97
lines changed

tools/convert_pkl_to_pb.py

Lines changed: 66 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -338,8 +338,8 @@ def add_bbox_ops(args, net, blobs):
338338
new_ops.extend([op_nms])
339339
new_external_outputs.extend(['score_nms', 'bbox_nms', 'class_nms'])
340340

341-
net.Proto().op.extend(new_ops)
342-
net.Proto().external_output.extend(new_external_outputs)
341+
net.op.extend(new_ops)
342+
net.external_output.extend(new_external_outputs)
343343

344344

345345
def convert_model_gpu(args, net, init_net):
@@ -522,9 +522,8 @@ def _prepare_blobs(
522522
return blobs
523523

524524

525-
def run_model_pb(args, models_pb, im, check_blobs):
525+
def run_model_pb(args, net, init_net, im, check_blobs):
526526
workspace.ResetWorkspace()
527-
net, init_net = models_pb['net']
528527
workspace.RunNetOnce(init_net)
529528
mutils.create_input_blobs_for_net(net.Proto())
530529
workspace.CreateNet(net)
@@ -551,97 +550,37 @@ def run_model_pb(args, models_pb, im, check_blobs):
551550
classids = workspace.FetchBlob(core.ScopedName('class_nms'))
552551
boxes = workspace.FetchBlob(core.ScopedName('bbox_nms'))
553552
except Exception as e:
554-
print('Running pb model failed.\n{}'.format(e))
553+
logger.warn('Running pb model failed.\n{}'.format(e))
555554
R = 0
556555
scores = np.zeros((R,), dtype=np.float32)
557556
boxes = np.zeros((R, 4), dtype=np.float32)
558557
classids = np.zeros((R,), dtype=np.float32)
559558

560-
cls_keyps, cls_segms = None, None
561-
562-
if 'keypoint_net' in models_pb:
563-
keypoint_net, init_keypoint_net = models_pb['keypoint_net']
564-
workspace.RunNetOnce(init_keypoint_net)
565-
mutils.create_input_blobs_for_net(keypoint_net.Proto())
566-
keypoint_net.Proto().external_input.extend(['rpn_rois', 'bbox_pred', 'im_info', 'cls_prob'])
567-
workspace.CreateNet(keypoint_net)
568-
569-
im_scale = input_blobs['im_info'][0][2]
570-
input_blobs = {'keypoint_rois': test._get_rois_blob(boxes, im_scale)}
571-
572-
# Add multi-level rois for FPN
573-
if cfg.FPN.MULTILEVEL_ROIS:
574-
test._add_multilevel_rois_for_test(input_blobs, 'keypoint_rois')
575-
576-
gpu_blobs = []
577-
if args.device == 'gpu':
578-
gpu_blobs = ['data']
579-
for k, v in list(input_blobs.items()):
580-
workspace.FeedBlob(
581-
core.ScopedName(k),
582-
v,
583-
mutils.get_device_option_cuda() if k in gpu_blobs else
584-
mutils.get_device_option_cpu()
585-
)
586-
587-
try:
588-
workspace.RunNet(keypoint_net)
589-
pred_heatmaps = workspace.FetchBlob(core.ScopedName('kps_score')).squeeze()
590-
# In case of 1
591-
if pred_heatmaps.ndim == 3:
592-
pred_heatmaps = np.expand_dims(pred_heatmaps, axis=0)
593-
except Exception as e:
594-
print('Running pb model failed.\n{}'.format(e))
595-
R, M = 0, cfg.KRCNN.HEATMAP_SIZE
596-
pred_heatmaps = np.zeros((R, cfg.KRCNN.NUM_KEYPOINTS, M, M), np.float32)
559+
cls_segms, cls_keyps = None, None
597560

561+
if net.BlobIsDefined(core.ScopedName('kps_score')):
562+
pred_heatmaps = workspace.FetchBlob(core.ScopedName('kps_score')).squeeze()
563+
# In case of 1
564+
if pred_heatmaps.ndim == 3:
565+
pred_heatmaps = np.expand_dims(pred_heatmaps, axis=0)
598566
xy_preds = keypoint_utils.heatmaps_to_keypoints(pred_heatmaps, boxes)
599567
cls_keyps = [[] for _ in range(cfg.MODEL.NUM_CLASSES)]
600568
cls_keyps[1] = [xy_preds[i] for i in range(xy_preds.shape[0])]
569+
else:
570+
logger.info('Keypoint blob is not defined')
601571

602-
if 'mask_net' in models_pb:
603-
mask_net, init_mask_net = models_pb['mask_net']
604-
workspace.RunNetOnce(init_mask_net)
605-
mutils.create_input_blobs_for_net(mask_net.Proto())
606-
mask_net.Proto().external_input.extend(['rpn_rois', 'bbox_pred', 'im_info', 'cls_prob'])
607-
workspace.CreateNet(mask_net)
608-
609-
im_scale = input_blobs['im_info'][0][2]
610-
input_blobs = {'mask_rois': test._get_rois_blob(boxes, im_scale)}
611-
612-
# Add multi-level rois for FPN
613-
if cfg.FPN.MULTILEVEL_ROIS:
614-
test._add_multilevel_rois_for_test(input_blobs, 'mask_rois')
615-
616-
gpu_blobs = []
617-
if args.device == 'gpu':
618-
gpu_blobs = ['data']
619-
for k, v in list(input_blobs.items()):
620-
workspace.FeedBlob(
621-
core.ScopedName(k),
622-
v,
623-
mutils.get_device_option_cuda() if k in gpu_blobs else
624-
mutils.get_device_option_cpu()
625-
)
572+
if net.BlobIsDefined(core.ScopedName('mask_fcn_probs')):
573+
# Fetch masks
574+
pred_masks = workspace.FetchBlob(core.ScopedName('mask_fcn_probs')).squeeze()
626575
M = cfg.MRCNN.RESOLUTION
627-
try:
628-
workspace.RunNet(mask_net)
629-
# Fetch masks
630-
pred_masks = workspace.FetchBlob(core.ScopedName('mask_fcn_probs')).squeeze()
631-
if cfg.MRCNN.CLS_SPECIFIC_MASK:
632-
pred_masks = pred_masks.reshape([-1, cfg.MODEL.NUM_CLASSES, M, M])
633-
else:
634-
pred_masks = pred_masks.reshape([-1, 1, M, M])
635-
except Exception as e:
636-
print('Running pb model failed.\n{}'.format(e))
637-
R = 0
638-
if cfg.MRCNN.CLS_SPECIFIC_MASK:
639-
pred_masks = np.zeros((R, cfg.MODEL.NUM_CLASSES, M, M), dtype=np.float32)
640-
else:
641-
pred_masks = np.zeros((R, 1, M, M), dtype=np.float32)
642-
576+
if cfg.MRCNN.CLS_SPECIFIC_MASK:
577+
pred_masks = pred_masks.reshape([-1, cfg.MODEL.NUM_CLASSES, M, M])
578+
else:
579+
pred_masks = pred_masks.reshape([-1, 1, M, M])
643580
cls_boxes = [np.empty(list(classids).count(i)) for i in range(cfg.MODEL.NUM_CLASSES)]
644581
cls_segms = test.segm_results(cls_boxes, pred_masks, boxes, im.shape[0], im.shape[1])
582+
else:
583+
logger.info('Mask blob is not defined')
645584

646585
boxes = np.column_stack((boxes, scores))
647586

@@ -669,7 +608,7 @@ def _ornone(res):
669608
return ret
670609

671610

672-
def verify_model(args, models_pb, test_img_file):
611+
def verify_model(args, net, init_net, test_img_file):
673612
check_blobs = ['result_boxes', 'result_classids']
674613

675614
if cfg.MODEL.MASK_ON:
@@ -686,14 +625,14 @@ def _run_cfg_func(im, blobs):
686625
return run_model_cfg(args, im, check_blobs)
687626

688627
def _run_pb_func(im, blobs):
689-
return run_model_pb(args, models_pb, im, check_blobs)
628+
return run_model_pb(args, net, init_net, im, check_blobs)
690629

691630
print('Checking models...')
692631
assert mutils.compare_model(
693632
_run_cfg_func, _run_pb_func, test_img, check_blobs)
694633

695634

696-
def convert_to_pb(args, net, blobs, part_name='net', input_blobs=[]):
635+
def convert_to_pb(args, net, blobs, input_blobs):
697636
pb_net = core.Net('')
698637
pb_net.Proto().op.extend(copy.deepcopy(net.op))
699638

@@ -708,7 +647,7 @@ def convert_to_pb(args, net, blobs, part_name='net', input_blobs=[]):
708647
convert_net(args, pb_net.Proto(), blobs)
709648

710649
# add operators for bbox
711-
add_bbox_ops(args, pb_net, blobs)
650+
add_bbox_ops(args, pb_net.Proto(), blobs)
712651

713652
if args.fuse_af:
714653
print('Fusing affine channel...')
@@ -723,8 +662,8 @@ def convert_to_pb(args, net, blobs, part_name='net', input_blobs=[]):
723662
if args.device == 'gpu':
724663
[pb_net, pb_init_net] = convert_model_gpu(args, pb_net, pb_init_net)
725664

726-
pb_net.Proto().name = args.net_name + '_' + part_name
727-
pb_init_net.Proto().name = args.net_name + '_' + part_name + '_init'
665+
pb_net.Proto().name = args.net_name + '_net'
666+
pb_init_net.Proto().name = args.net_name + '_net_init'
728667

729668
return pb_net, pb_init_net
730669

@@ -743,25 +682,55 @@ def main():
743682
logger.info('Converting model with config:')
744683
logger.info(pprint.pformat(cfg))
745684

746-
models_pb = {}
747-
748685
# load model from cfg
749686
model, blobs = load_model(args)
750687

751688
input_net = ['data', 'im_info']
752-
models_pb['net'] = convert_to_pb(args, model.net.Proto(), blobs, input_blobs=input_net)
689+
690+
if cfg.MODEL.KEYPOINTS_ON:
691+
model_kps = model.keypoint_net.Proto()
692+
693+
# Connect rois blobs
694+
for op in model_kps.op:
695+
for i, input_name in enumerate(op.input):
696+
op.input[i] = input_name.replace("keypoint_rois", "rois")
697+
698+
# Remove external input defined in main net
699+
kps_external_input = []
700+
for i in model_kps.external_input:
701+
if not model.net.BlobIsDefined(i) and \
702+
not "keypoint_rois" in i:
703+
kps_external_input.append(i)
704+
705+
model.net.Proto().op.extend(model_kps.op)
706+
model.net.Proto().external_output.extend(model_kps.external_output)
707+
model.net.Proto().external_input.extend(kps_external_input)
753708

754709
if cfg.MODEL.MASK_ON:
755-
models_pb['mask_net'] = convert_to_pb(args, model.mask_net.Proto(), blobs, part_name='mask_net')
710+
model_mask = model.mask_net.Proto()
756711

757-
if cfg.MODEL.KEYPOINTS_ON:
758-
models_pb['keypoint_net'] = convert_to_pb(args, model.keypoint_net.Proto(), blobs, part_name='keypoint_net')
712+
# Connect rois blobs
713+
for op in model_mask.op:
714+
for i, input_name in enumerate(op.input):
715+
op.input[i] = input_name.replace("mask_rois", "rois")
716+
717+
# Remove external input defined in main net
718+
mask_external_input = []
719+
for i in model_mask.external_input:
720+
if not model.net.BlobIsDefined(i) and \
721+
not "mask_rois" in i:
722+
mask_external_input.append(i)
723+
724+
model.net.Proto().op.extend(model_mask.op)
725+
model.net.Proto().external_output.extend(model_mask.external_output)
726+
model.net.Proto().external_input.extend(mask_external_input)
727+
728+
net, init_net = convert_to_pb(args, model.net.Proto(), blobs, input_net)
759729

760-
for (pb_net, pb_init_net) in models_pb.values():
761-
_save_models(pb_net, pb_init_net, args)
730+
_save_models(net, init_net, args)
762731

763732
if args.test_img is not None:
764-
verify_model(args, models_pb, args.test_img)
733+
verify_model(args, net, init_net, args.test_img)
765734

766735
if __name__ == '__main__':
767736
main()

0 commit comments

Comments
 (0)