@@ -338,8 +338,8 @@ def add_bbox_ops(args, net, blobs):
338338 new_ops .extend ([op_nms ])
339339 new_external_outputs .extend (['score_nms' , 'bbox_nms' , 'class_nms' ])
340340
341- net .Proto (). op .extend (new_ops )
342- net .Proto (). external_output .extend (new_external_outputs )
341+ net .op .extend (new_ops )
342+ net .external_output .extend (new_external_outputs )
343343
344344
345345def convert_model_gpu (args , net , init_net ):
@@ -522,9 +522,8 @@ def _prepare_blobs(
522522 return blobs
523523
524524
525- def run_model_pb (args , models_pb , im , check_blobs ):
525+ def run_model_pb (args , net , init_net , im , check_blobs ):
526526 workspace .ResetWorkspace ()
527- net , init_net = models_pb ['net' ]
528527 workspace .RunNetOnce (init_net )
529528 mutils .create_input_blobs_for_net (net .Proto ())
530529 workspace .CreateNet (net )
@@ -551,97 +550,37 @@ def run_model_pb(args, models_pb, im, check_blobs):
551550 classids = workspace .FetchBlob (core .ScopedName ('class_nms' ))
552551 boxes = workspace .FetchBlob (core .ScopedName ('bbox_nms' ))
553552 except Exception as e :
554- print ('Running pb model failed.\n {}' .format (e ))
553+ logger . warn ('Running pb model failed.\n {}' .format (e ))
555554 R = 0
556555 scores = np .zeros ((R ,), dtype = np .float32 )
557556 boxes = np .zeros ((R , 4 ), dtype = np .float32 )
558557 classids = np .zeros ((R ,), dtype = np .float32 )
559558
560- cls_keyps , cls_segms = None , None
561-
562- if 'keypoint_net' in models_pb :
563- keypoint_net , init_keypoint_net = models_pb ['keypoint_net' ]
564- workspace .RunNetOnce (init_keypoint_net )
565- mutils .create_input_blobs_for_net (keypoint_net .Proto ())
566- keypoint_net .Proto ().external_input .extend (['rpn_rois' , 'bbox_pred' , 'im_info' , 'cls_prob' ])
567- workspace .CreateNet (keypoint_net )
568-
569- im_scale = input_blobs ['im_info' ][0 ][2 ]
570- input_blobs = {'keypoint_rois' : test ._get_rois_blob (boxes , im_scale )}
571-
572- # Add multi-level rois for FPN
573- if cfg .FPN .MULTILEVEL_ROIS :
574- test ._add_multilevel_rois_for_test (input_blobs , 'keypoint_rois' )
575-
576- gpu_blobs = []
577- if args .device == 'gpu' :
578- gpu_blobs = ['data' ]
579- for k , v in list (input_blobs .items ()):
580- workspace .FeedBlob (
581- core .ScopedName (k ),
582- v ,
583- mutils .get_device_option_cuda () if k in gpu_blobs else
584- mutils .get_device_option_cpu ()
585- )
586-
587- try :
588- workspace .RunNet (keypoint_net )
589- pred_heatmaps = workspace .FetchBlob (core .ScopedName ('kps_score' )).squeeze ()
590- # In case of 1
591- if pred_heatmaps .ndim == 3 :
592- pred_heatmaps = np .expand_dims (pred_heatmaps , axis = 0 )
593- except Exception as e :
594- print ('Running pb model failed.\n {}' .format (e ))
595- R , M = 0 , cfg .KRCNN .HEATMAP_SIZE
596- pred_heatmaps = np .zeros ((R , cfg .KRCNN .NUM_KEYPOINTS , M , M ), np .float32 )
559+ cls_segms , cls_keyps = None , None
597560
561+ if net .BlobIsDefined (core .ScopedName ('kps_score' )):
562+ pred_heatmaps = workspace .FetchBlob (core .ScopedName ('kps_score' )).squeeze ()
563+ # In case of 1
564+ if pred_heatmaps .ndim == 3 :
565+ pred_heatmaps = np .expand_dims (pred_heatmaps , axis = 0 )
598566 xy_preds = keypoint_utils .heatmaps_to_keypoints (pred_heatmaps , boxes )
599567 cls_keyps = [[] for _ in range (cfg .MODEL .NUM_CLASSES )]
600568 cls_keyps [1 ] = [xy_preds [i ] for i in range (xy_preds .shape [0 ])]
569+ else :
570+ logger .info ('Keypoint blob is not defined' )
601571
602- if 'mask_net' in models_pb :
603- mask_net , init_mask_net = models_pb ['mask_net' ]
604- workspace .RunNetOnce (init_mask_net )
605- mutils .create_input_blobs_for_net (mask_net .Proto ())
606- mask_net .Proto ().external_input .extend (['rpn_rois' , 'bbox_pred' , 'im_info' , 'cls_prob' ])
607- workspace .CreateNet (mask_net )
608-
609- im_scale = input_blobs ['im_info' ][0 ][2 ]
610- input_blobs = {'mask_rois' : test ._get_rois_blob (boxes , im_scale )}
611-
612- # Add multi-level rois for FPN
613- if cfg .FPN .MULTILEVEL_ROIS :
614- test ._add_multilevel_rois_for_test (input_blobs , 'mask_rois' )
615-
616- gpu_blobs = []
617- if args .device == 'gpu' :
618- gpu_blobs = ['data' ]
619- for k , v in list (input_blobs .items ()):
620- workspace .FeedBlob (
621- core .ScopedName (k ),
622- v ,
623- mutils .get_device_option_cuda () if k in gpu_blobs else
624- mutils .get_device_option_cpu ()
625- )
572+ if net .BlobIsDefined (core .ScopedName ('mask_fcn_probs' )):
573+ # Fetch masks
574+ pred_masks = workspace .FetchBlob (core .ScopedName ('mask_fcn_probs' )).squeeze ()
626575 M = cfg .MRCNN .RESOLUTION
627- try :
628- workspace .RunNet (mask_net )
629- # Fetch masks
630- pred_masks = workspace .FetchBlob (core .ScopedName ('mask_fcn_probs' )).squeeze ()
631- if cfg .MRCNN .CLS_SPECIFIC_MASK :
632- pred_masks = pred_masks .reshape ([- 1 , cfg .MODEL .NUM_CLASSES , M , M ])
633- else :
634- pred_masks = pred_masks .reshape ([- 1 , 1 , M , M ])
635- except Exception as e :
636- print ('Running pb model failed.\n {}' .format (e ))
637- R = 0
638- if cfg .MRCNN .CLS_SPECIFIC_MASK :
639- pred_masks = np .zeros ((R , cfg .MODEL .NUM_CLASSES , M , M ), dtype = np .float32 )
640- else :
641- pred_masks = np .zeros ((R , 1 , M , M ), dtype = np .float32 )
642-
576+ if cfg .MRCNN .CLS_SPECIFIC_MASK :
577+ pred_masks = pred_masks .reshape ([- 1 , cfg .MODEL .NUM_CLASSES , M , M ])
578+ else :
579+ pred_masks = pred_masks .reshape ([- 1 , 1 , M , M ])
643580 cls_boxes = [np .empty (list (classids ).count (i )) for i in range (cfg .MODEL .NUM_CLASSES )]
644581 cls_segms = test .segm_results (cls_boxes , pred_masks , boxes , im .shape [0 ], im .shape [1 ])
582+ else :
583+ logger .info ('Mask blob is not defined' )
645584
646585 boxes = np .column_stack ((boxes , scores ))
647586
@@ -669,7 +608,7 @@ def _ornone(res):
669608 return ret
670609
671610
672- def verify_model (args , models_pb , test_img_file ):
611+ def verify_model (args , net , init_net , test_img_file ):
673612 check_blobs = ['result_boxes' , 'result_classids' ]
674613
675614 if cfg .MODEL .MASK_ON :
@@ -686,14 +625,14 @@ def _run_cfg_func(im, blobs):
686625 return run_model_cfg (args , im , check_blobs )
687626
688627 def _run_pb_func (im , blobs ):
689- return run_model_pb (args , models_pb , im , check_blobs )
628+ return run_model_pb (args , net , init_net , im , check_blobs )
690629
691630 print ('Checking models...' )
692631 assert mutils .compare_model (
693632 _run_cfg_func , _run_pb_func , test_img , check_blobs )
694633
695634
696- def convert_to_pb (args , net , blobs , part_name = 'net' , input_blobs = [] ):
635+ def convert_to_pb (args , net , blobs , input_blobs ):
697636 pb_net = core .Net ('' )
698637 pb_net .Proto ().op .extend (copy .deepcopy (net .op ))
699638
@@ -708,7 +647,7 @@ def convert_to_pb(args, net, blobs, part_name='net', input_blobs=[]):
708647 convert_net (args , pb_net .Proto (), blobs )
709648
710649 # add operators for bbox
711- add_bbox_ops (args , pb_net , blobs )
650+ add_bbox_ops (args , pb_net . Proto () , blobs )
712651
713652 if args .fuse_af :
714653 print ('Fusing affine channel...' )
@@ -723,8 +662,8 @@ def convert_to_pb(args, net, blobs, part_name='net', input_blobs=[]):
723662 if args .device == 'gpu' :
724663 [pb_net , pb_init_net ] = convert_model_gpu (args , pb_net , pb_init_net )
725664
726- pb_net .Proto ().name = args .net_name + '_' + part_name
727- pb_init_net .Proto ().name = args .net_name + '_' + part_name + '_init '
665+ pb_net .Proto ().name = args .net_name + '_net'
666+ pb_init_net .Proto ().name = args .net_name + '_net_init '
728667
729668 return pb_net , pb_init_net
730669
@@ -743,25 +682,55 @@ def main():
743682 logger .info ('Converting model with config:' )
744683 logger .info (pprint .pformat (cfg ))
745684
746- models_pb = {}
747-
748685 # load model from cfg
749686 model , blobs = load_model (args )
750687
751688 input_net = ['data' , 'im_info' ]
752- models_pb ['net' ] = convert_to_pb (args , model .net .Proto (), blobs , input_blobs = input_net )
689+
690+ if cfg .MODEL .KEYPOINTS_ON :
691+ model_kps = model .keypoint_net .Proto ()
692+
693+ # Connect rois blobs
694+ for op in model_kps .op :
695+ for i , input_name in enumerate (op .input ):
696+ op .input [i ] = input_name .replace ("keypoint_rois" , "rois" )
697+
698+ # Remove external input defined in main net
699+ kps_external_input = []
700+ for i in model_kps .external_input :
701+ if not model .net .BlobIsDefined (i ) and \
702+ not "keypoint_rois" in i :
703+ kps_external_input .append (i )
704+
705+ model .net .Proto ().op .extend (model_kps .op )
706+ model .net .Proto ().external_output .extend (model_kps .external_output )
707+ model .net .Proto ().external_input .extend (kps_external_input )
753708
754709 if cfg .MODEL .MASK_ON :
755- models_pb [ 'mask_net' ] = convert_to_pb ( args , model .mask_net .Proto (), blobs , part_name = 'mask_net' )
710+ model_mask = model .mask_net .Proto ()
756711
757- if cfg .MODEL .KEYPOINTS_ON :
758- models_pb ['keypoint_net' ] = convert_to_pb (args , model .keypoint_net .Proto (), blobs , part_name = 'keypoint_net' )
712+ # Connect rois blobs
713+ for op in model_mask .op :
714+ for i , input_name in enumerate (op .input ):
715+ op .input [i ] = input_name .replace ("mask_rois" , "rois" )
716+
717+ # Remove external input defined in main net
718+ mask_external_input = []
719+ for i in model_mask .external_input :
720+ if not model .net .BlobIsDefined (i ) and \
721+ not "mask_rois" in i :
722+ mask_external_input .append (i )
723+
724+ model .net .Proto ().op .extend (model_mask .op )
725+ model .net .Proto ().external_output .extend (model_mask .external_output )
726+ model .net .Proto ().external_input .extend (mask_external_input )
727+
728+ net , init_net = convert_to_pb (args , model .net .Proto (), blobs , input_net )
759729
760- for (pb_net , pb_init_net ) in models_pb .values ():
761- _save_models (pb_net , pb_init_net , args )
730+ _save_models (net , init_net , args )
762731
763732 if args .test_img is not None :
764- verify_model (args , models_pb , args .test_img )
733+ verify_model (args , net , init_net , args .test_img )
765734
766735if __name__ == '__main__' :
767736 main ()
0 commit comments