diff --git a/detectron/utils/model_convert_utils.py b/detectron/utils/model_convert_utils.py index 17752db2b..e7f624d9e 100644 --- a/detectron/utils/model_convert_utils.py +++ b/detectron/utils/model_convert_utils.py @@ -315,11 +315,11 @@ def gen_init_net_from_blobs(blobs, blobs_to_use=None, excluded_blobs=None): blobs_to_use = [x for x in blobs_to_use if x not in excluded_blobs] for name in blobs_to_use: blob = blobs[name] - if isinstance(blob, str): + if isinstance(blob, np.ndarray): + add_tensor(ret, name, blob) + else: print('Blob {} with type {} is not supported in generating init net,' ' skipped.'.format(name, type(blob))) - continue - add_tensor(ret, name, blob) return ret diff --git a/tools/convert_pkl_to_pb.py b/tools/convert_pkl_to_pb.py index a553444bc..7eb19ac28 100644 --- a/tools/convert_pkl_to_pb.py +++ b/tools/convert_pkl_to_pb.py @@ -52,9 +52,13 @@ from detectron.utils.model_convert_utils import op_filter import detectron.utils.blob as blob_utils import detectron.core.test_engine as test_engine +import detectron.core.test as test import detectron.utils.c2 as c2_utils import detectron.utils.model_convert_utils as mutils import detectron.utils.vis as vis_utils +import detectron.utils.blob as blob_utils +import detectron.utils.keypoints as keypoint_utils +import pycocotools.mask as mask_utils c2_utils.import_contrib_ops() c2_utils.import_detectron_ops() @@ -334,8 +338,8 @@ def add_bbox_ops(args, net, blobs): new_ops.extend([op_nms]) new_external_outputs.extend(['score_nms', 'bbox_nms', 'class_nms']) - net.Proto().op.extend(new_ops) - net.Proto().external_output.extend(new_external_outputs) + net.op.extend(new_ops) + net.external_output.extend(new_external_outputs) def convert_model_gpu(args, net, init_net): @@ -392,23 +396,23 @@ def gen_init_net(net, blobs, empty_blobs): def _save_image_graphs(args, all_net, all_init_net): print('Saving model graph...') mutils.save_graph( - all_net.Proto(), os.path.join(args.out_dir, "model_def.png"), + all_net.Proto(), os.path.join(args.out_dir, all_net.Proto().name + '.png'), op_only=False) print('Model def image saved to {}.'.format(args.out_dir)) def _save_models(all_net, all_init_net, args): print('Writing converted model to {}...'.format(args.out_dir)) - fname = "model" + fname = all_net.Proto().name if not os.path.exists(args.out_dir): os.makedirs(args.out_dir) - with open(os.path.join(args.out_dir, fname + '.pb'), 'w') as f: + with open(os.path.join(args.out_dir, fname + '.pb'), 'wb') as f: f.write(all_net.Proto().SerializeToString()) with open(os.path.join(args.out_dir, fname + '.pbtxt'), 'w') as f: f.write(str(all_net.Proto())) - with open(os.path.join(args.out_dir, fname + '_init.pb'), 'w') as f: + with open(os.path.join(args.out_dir, fname + '_init.pb'), 'wb') as f: f.write(all_init_net.Proto().SerializeToString()) _save_image_graphs(args, all_net, all_init_net) @@ -457,13 +461,14 @@ def run_model_cfg(args, im, check_blobs): cls_boxes, cls_segms, cls_keyps = test_engine.im_detect_all( model, im, None, None, ) - - boxes, segms, keypoints, classes = vis_utils.convert_from_cls_format( + boxes, segms, keypoints, classids = vis_utils.convert_from_cls_format( cls_boxes, cls_segms, cls_keyps) + segms = mask_utils.decode(segms) if segms else None + # sort the results based on score for comparision - boxes, segms, keypoints, classes = _sort_results( - boxes, segms, keypoints, classes) + boxes, segms, keypoints, classids = _sort_results( + boxes, segms, keypoints, classids) # write final results back to workspace def _ornone(res): @@ -472,12 +477,16 @@ def _ornone(res): workspace.FeedBlob(core.ScopedName('result_boxes'), _ornone(boxes)) workspace.FeedBlob(core.ScopedName('result_segms'), _ornone(segms)) workspace.FeedBlob(core.ScopedName('result_keypoints'), _ornone(keypoints)) - workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classes)) + workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classids)) # get result blobs with c2_utils.NamedCudaScope(0): ret = _get_result_blobs(check_blobs) + print('result_boxes', _ornone(boxes)) + print('result_segms', _ornone(segms)) + print('result_keypoints', _ornone(keypoints)) + print('result_classids', _ornone(classids)) return ret @@ -519,7 +528,6 @@ def run_model_pb(args, net, init_net, im, check_blobs): mutils.create_input_blobs_for_net(net.Proto()) workspace.CreateNet(net) - # input_blobs, _ = core_test._get_blobs(im, None) input_blobs = _prepare_blobs( im, cfg.PIXEL_MEANS, @@ -538,36 +546,76 @@ def run_model_pb(args, net, init_net, im, check_blobs): try: workspace.RunNet(net) - scores = workspace.FetchBlob('score_nms') - classids = workspace.FetchBlob('class_nms') - boxes = workspace.FetchBlob('bbox_nms') + scores = workspace.FetchBlob(core.ScopedName('score_nms')) + classids = workspace.FetchBlob(core.ScopedName('class_nms')) + boxes = workspace.FetchBlob(core.ScopedName('bbox_nms')) except Exception as e: - print('Running pb model failed.\n{}'.format(e)) - # may not detect anything at all + logger.warn('Running pb model failed.\n{}'.format(e)) R = 0 scores = np.zeros((R,), dtype=np.float32) boxes = np.zeros((R, 4), dtype=np.float32) classids = np.zeros((R,), dtype=np.float32) + cls_segms, cls_keyps = None, None + + if net.BlobIsDefined(core.ScopedName('kps_score')): + pred_heatmaps = workspace.FetchBlob(core.ScopedName('kps_score')).squeeze() + # In case of 1 + if pred_heatmaps.ndim == 3: + pred_heatmaps = np.expand_dims(pred_heatmaps, axis=0) + xy_preds = keypoint_utils.heatmaps_to_keypoints(pred_heatmaps, boxes) + cls_keyps = [[] for _ in range(cfg.MODEL.NUM_CLASSES)] + cls_keyps[1] = [xy_preds[i] for i in range(xy_preds.shape[0])] + else: + logger.info('Keypoint blob is not defined') + + if net.BlobIsDefined(core.ScopedName('mask_fcn_probs')): + # Fetch masks + pred_masks = workspace.FetchBlob(core.ScopedName('mask_fcn_probs')).squeeze() + M = cfg.MRCNN.RESOLUTION + if cfg.MRCNN.CLS_SPECIFIC_MASK: + pred_masks = pred_masks.reshape([-1, cfg.MODEL.NUM_CLASSES, M, M]) + else: + pred_masks = pred_masks.reshape([-1, 1, M, M]) + cls_boxes = [np.empty(list(classids).count(i)) for i in range(cfg.MODEL.NUM_CLASSES)] + cls_segms = test.segm_results(cls_boxes, pred_masks, boxes, im.shape[0], im.shape[1]) + else: + logger.info('Mask blob is not defined') + boxes = np.column_stack((boxes, scores)) + _, segms, keypoints, _ = vis_utils.convert_from_cls_format([], cls_segms, cls_keyps) + segms = mask_utils.decode(segms) if segms else None + # sort the results based on score for comparision - boxes, _, _, classids = _sort_results( - boxes, None, None, classids) + boxes, segms, keypoints, classids = _sort_results( + boxes, segms, keypoints, classids) # write final result back to workspace - workspace.FeedBlob('result_boxes', boxes) - workspace.FeedBlob('result_classids', classids) + def _ornone(res): + return np.array(res) if res is not None else np.array([], dtype=np.float32) + workspace.FeedBlob(core.ScopedName('result_boxes'), _ornone(boxes)) + workspace.FeedBlob(core.ScopedName('result_classids'), _ornone(classids)) + workspace.FeedBlob(core.ScopedName('result_segms'), _ornone(segms)) + workspace.FeedBlob(core.ScopedName('result_keypoints'), _ornone(keypoints)) ret = _get_result_blobs(check_blobs) + print('result_boxes', _ornone(boxes)) + print('result_segms', _ornone(segms)) + print('result_keypoints', _ornone(keypoints)) + print('result_classids', _ornone(classids)) return ret -def verify_model(args, model_pb, test_img_file): - check_blobs = [ - "result_boxes", "result_classids", # result - ] +def verify_model(args, net, init_net, test_img_file): + check_blobs = ['result_boxes', 'result_classids'] + + if cfg.MODEL.MASK_ON: + check_blobs.append('result_segms') + + if cfg.MODEL.KEYPOINTS_ON: + check_blobs.append('result_keypoints') print('Loading test file {}...'.format(test_img_file)) test_img = cv2.imread(test_img_file) @@ -577,13 +625,49 @@ def _run_cfg_func(im, blobs): return run_model_cfg(args, im, check_blobs) def _run_pb_func(im, blobs): - return run_model_pb(args, model_pb[0], model_pb[1], im, check_blobs) + return run_model_pb(args, net, init_net, im, check_blobs) print('Checking models...') assert mutils.compare_model( _run_cfg_func, _run_pb_func, test_img, check_blobs) +def convert_to_pb(args, net, blobs, input_blobs): + pb_net = core.Net('') + pb_net.Proto().op.extend(copy.deepcopy(net.op)) + + pb_net.Proto().external_input.extend( + copy.deepcopy(net.external_input)) + pb_net.Proto().external_output.extend( + copy.deepcopy(net.external_output)) + pb_net.Proto().type = args.net_execution_type + pb_net.Proto().num_workers = 1 if args.net_execution_type == 'simple' else 4 + + # Reset the device_option, change to unscope name and replace python operators + convert_net(args, pb_net.Proto(), blobs) + + # add operators for bbox + add_bbox_ops(args, pb_net.Proto(), blobs) + + if args.fuse_af: + print('Fusing affine channel...') + pb_net, blobs = mutils.fuse_net_affine(pb_net, blobs) + + if args.use_nnpack: + mutils.update_mobile_engines(pb_net.Proto()) + + # generate init net + pb_init_net = gen_init_net(pb_net, blobs, input_blobs) + + if args.device == 'gpu': + [pb_net, pb_init_net] = convert_model_gpu(args, pb_net, pb_init_net) + + pb_net.Proto().name = args.net_name + '_net' + pb_init_net.Proto().name = args.net_name + '_net_init' + + return pb_net, pb_init_net + + def main(): workspace.GlobalInit(['caffe2', '--caffe2_log_level=0']) args = parse_args() @@ -598,55 +682,55 @@ def main(): logger.info('Converting model with config:') logger.info(pprint.pformat(cfg)) - # script will stop when it can't find an operator rather - # than stopping based on these flags - # - # assert not cfg.MODEL.KEYPOINTS_ON, "Keypoint model not supported." - # assert not cfg.MODEL.MASK_ON, "Mask model not supported." - # assert not cfg.FPN.FPN_ON, "FPN not supported." - # assert not cfg.RETINANET.RETINANET_ON, "RetinaNet model not supported." - # load model from cfg model, blobs = load_model(args) - net = core.Net('') - net.Proto().op.extend(copy.deepcopy(model.net.Proto().op)) - net.Proto().external_input.extend( - copy.deepcopy(model.net.Proto().external_input)) - net.Proto().external_output.extend( - copy.deepcopy(model.net.Proto().external_output)) - net.Proto().type = args.net_execution_type - net.Proto().num_workers = 1 if args.net_execution_type == 'simple' else 4 + input_net = ['data', 'im_info'] - # Reset the device_option, change to unscope name and replace python operators - convert_net(args, net.Proto(), blobs) + if cfg.MODEL.KEYPOINTS_ON: + model_kps = model.keypoint_net.Proto() - # add operators for bbox - add_bbox_ops(args, net, blobs) + # Connect rois blobs + for op in model_kps.op: + for i, input_name in enumerate(op.input): + op.input[i] = input_name.replace("keypoint_rois", "rois") - if args.fuse_af: - print('Fusing affine channel...') - net, blobs = mutils.fuse_net_affine( - net, blobs) + # Remove external input defined in main net + kps_external_input = [] + for i in model_kps.external_input: + if not model.net.BlobIsDefined(i) and \ + not "keypoint_rois" in i: + kps_external_input.append(i) - if args.use_nnpack: - mutils.update_mobile_engines(net.Proto()) + model.net.Proto().op.extend(model_kps.op) + model.net.Proto().external_output.extend(model_kps.external_output) + model.net.Proto().external_input.extend(kps_external_input) - # generate init net - empty_blobs = ['data', 'im_info'] - init_net = gen_init_net(net, blobs, empty_blobs) + if cfg.MODEL.MASK_ON: + model_mask = model.mask_net.Proto() - if args.device == 'gpu': - [net, init_net] = convert_model_gpu(args, net, init_net) + # Connect rois blobs + for op in model_mask.op: + for i, input_name in enumerate(op.input): + op.input[i] = input_name.replace("mask_rois", "rois") - net.Proto().name = args.net_name - init_net.Proto().name = args.net_name + "_init" + # Remove external input defined in main net + mask_external_input = [] + for i in model_mask.external_input: + if not model.net.BlobIsDefined(i) and \ + not "mask_rois" in i: + mask_external_input.append(i) - if args.test_img is not None: - verify_model(args, [net, init_net], args.test_img) + model.net.Proto().op.extend(model_mask.op) + model.net.Proto().external_output.extend(model_mask.external_output) + model.net.Proto().external_input.extend(mask_external_input) + + net, init_net = convert_to_pb(args, model.net.Proto(), blobs, input_net) _save_models(net, init_net, args) + if args.test_img is not None: + verify_model(args, net, init_net, args.test_img) if __name__ == '__main__': main()