Skip to content
This repository was archived by the owner on Nov 21, 2023. It is now read-only.

Commit 8c49e43

Browse files
committed
Support export of CollectAndDistributeFpnRpnProposalsOp
1 parent 22636d4 commit 8c49e43

File tree

1 file changed

+83
-26
lines changed

1 file changed

+83
-26
lines changed

tools/convert_pkl_to_pb.py

Lines changed: 83 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,41 @@ def unscope_name(name):
119119

120120

121121
def reset_names(names):
    """Replace every entry of *names* with its unscoped form, in place."""
    for idx, scoped in enumerate(names):
        names[idx] = unscope_name(scoped)
124124

125125

126+
def convert_collect_and_distribute(
    op, blobs,
    roi_canonical_scale,
    roi_canonical_level,
    roi_max_level,
    roi_min_level,
    rpn_max_level,
    rpn_min_level,
    rpn_post_nms_topN,
):
    """Replace a Python CollectAndDistributeFpnRpnProposalsOp with the
    equivalent C++ operator.

    The inputs and outputs of *op* are carried over unchanged; the FPN/RPN
    level bounds and post-NMS top-N settings are attached as operator
    arguments.  Returns the new OperatorDef.  (*blobs* is accepted for
    signature parity with the other converters but is not used here.)
    """
    print('Converting CollectAndDistributeFpnRpnProposals'
          ' Python -> C++:\n{}'.format(op))
    # Refuse to convert anything that is not the expected Python op.
    assert op.name.startswith('CollectAndDistributeFpnRpnProposalsOp'), \
        'Not valid CollectAndDistributeFpnRpnProposalsOp'

    return core.CreateOperator(
        'CollectAndDistributeFpnRpnProposals',
        list(op.input),
        list(op.output),
        roi_canonical_scale=roi_canonical_scale,
        roi_canonical_level=roi_canonical_level,
        roi_max_level=roi_max_level,
        roi_min_level=roi_min_level,
        rpn_max_level=rpn_max_level,
        rpn_min_level=rpn_min_level,
        rpn_post_nms_topN=rpn_post_nms_topN,
    )
155+
156+
126157
def convert_gen_proposals(
127158
op, blobs,
128159
rpn_pre_nms_topN,
@@ -131,19 +162,19 @@ def convert_gen_proposals(
131162
rpn_min_size,
132163
):
133164
print('Converting GenerateProposals Python -> C++:\n{}'.format(op))
134-
assert op.name.startswith("GenerateProposalsOp"), "Not valid GenerateProposalsOp"
165+
assert op.name.startswith('GenerateProposalsOp'), 'Not valid GenerateProposalsOp'
135166

136-
spatial_scale = mutils.get_op_arg_valf(op, "spatial_scale", None)
167+
spatial_scale = mutils.get_op_arg_valf(op, 'spatial_scale', None)
137168
assert spatial_scale is not None
138169

139170
inputs = [x for x in op.input]
140-
anchor_name = "anchor"
171+
anchor_name = 'anchor'
141172
inputs.append(anchor_name)
142173
blobs[anchor_name] = get_anchors(spatial_scale)
143174
print('anchors {}'.format(blobs[anchor_name]))
144175

145176
ret = core.CreateOperator(
146-
"GenerateProposals",
177+
'GenerateProposals',
147178
inputs,
148179
list(op.output),
149180
spatial_scale=spatial_scale,
@@ -153,7 +184,6 @@ def convert_gen_proposals(
153184
min_size=rpn_min_size,
154185
correct_transform_coords=True,
155186
)
156-
157187
return ret, anchor_name
158188

159189

@@ -183,25 +213,48 @@ def convert_op_name(op):
183213
reset_names(op.output)
184214
return [op]
185215

186-
@op_filter(type="Python", inputs=['rpn_cls_probs', 'rpn_bbox_pred', 'im_info'])
187-
def convert_gen_proposal(op_in):
188-
gen_proposals_op, ext_input = convert_gen_proposals(
189-
op_in, blobs,
190-
rpn_min_size=float(cfg.TEST.RPN_MIN_SIZE),
191-
rpn_post_nms_topN=cfg.TEST.RPN_POST_NMS_TOP_N,
192-
rpn_pre_nms_topN=cfg.TEST.RPN_PRE_NMS_TOP_N,
193-
rpn_nms_thres=cfg.TEST.RPN_NMS_THRESH,
194-
)
195-
net.external_input.extend([ext_input])
196-
return [gen_proposals_op]
216+
@op_filter()
def convert_python(op):
    """Swap known Python ops for their C++ counterparts.

    Non-Python ops pass through unchanged; an unrecognized Python op is a
    hard error so the export fails loudly instead of emitting a net that
    cannot run outside the Python runtime.
    """
    if op.type != 'Python':
        return [op]

    if op.name.startswith('GenerateProposalsOp'):
        gen_proposals_op, ext_input = convert_gen_proposals(
            op, blobs,
            rpn_min_size=float(cfg.TEST.RPN_MIN_SIZE),
            rpn_post_nms_topN=cfg.TEST.RPN_POST_NMS_TOP_N,
            rpn_pre_nms_topN=cfg.TEST.RPN_PRE_NMS_TOP_N,
            rpn_nms_thres=cfg.TEST.RPN_NMS_THRESH,
        )
        # The C++ GenerateProposals op reads anchors from a new blob that
        # must be declared as an external input of the net.
        net.external_input.extend([ext_input])
        return [gen_proposals_op]

    if op.name.startswith('CollectAndDistributeFpnRpnProposalsOp'):
        collect_dist_op = convert_collect_and_distribute(
            op, blobs,
            roi_canonical_scale=cfg.FPN.ROI_CANONICAL_SCALE,
            roi_canonical_level=cfg.FPN.ROI_CANONICAL_LEVEL,
            roi_max_level=cfg.FPN.ROI_MAX_LEVEL,
            roi_min_level=cfg.FPN.ROI_MIN_LEVEL,
            rpn_max_level=cfg.FPN.RPN_MAX_LEVEL,
            rpn_min_level=cfg.FPN.RPN_MIN_LEVEL,
            rpn_post_nms_topN=cfg.TEST.RPN_POST_NMS_TOP_N,
        )
        return [collect_dist_op]

    raise ValueError('Failed to convert Python op {}'.format(op.name))
197245

198-
@op_filter()
def convert_rpn_rois(op):
    """Rename the blob 'rois' to 'rpn_rois' wherever it appears among the
    op's inputs and outputs, logging each rename.  Returns [op]."""
    for j, blob_name in enumerate(op.input):
        if blob_name == 'rois':
            print('Converting op {} input name: rois -> rpn_rois:\n{}'.format(
                op.type, op))
            op.input[j] = 'rpn_rois'
    for j, blob_name in enumerate(op.output):
        if blob_name == 'rois':
            print('Converting op {} output name: rois -> rpn_rois:\n{}'.format(
                op.type, op))
            op.output[j] = 'rpn_rois'
    return [op]
206259

207260
@op_filter(type_in=['StopGradient', 'Alias'])
@@ -211,7 +264,7 @@ def convert_remove_op(op):
211264

212265
convert_op_in_proto(net, convert_op_name)
213266
convert_op_in_proto(net, [
214-
convert_gen_proposal, convert_rpn_rois, convert_remove_op
267+
convert_python, convert_rpn_rois, convert_remove_op
215268
])
216269

217270
reset_names(net.external_input)
@@ -267,6 +320,7 @@ def convert_model_gpu(args, net, init_net):
267320
cdo_cpu = mutils.get_device_option_cpu()
268321

269322
CPU_OPS = [
323+
["CollectAndDistributeFpnRpnProposals", None],
270324
["GenerateProposals", None],
271325
["BBoxTransform", None],
272326
["BoxWithNMSLimit", None],
@@ -457,7 +511,7 @@ def run_model_pb(args, net, init_net, im, check_blobs):
457511
)
458512

459513
try:
460-
workspace.RunNet(net.Proto().name)
514+
workspace.RunNet(net)
461515
scores = workspace.FetchBlob('score_nms')
462516
classids = workspace.FetchBlob('class_nms')
463517
boxes = workspace.FetchBlob('bbox_nms')
@@ -515,13 +569,16 @@ def main():
515569
merge_cfg_from_list(args.opts)
516570
cfg.NUM_GPUS = 1
517571
assert_and_infer_cfg()
518-
logger.info('Conerting model with config:')
572+
logger.info('Converting model with config:')
519573
logger.info(pprint.pformat(cfg))
520574

521-
assert not cfg.MODEL.KEYPOINTS_ON, "Keypoint model not supported."
522-
assert not cfg.MODEL.MASK_ON, "Mask model not supported."
523-
assert not cfg.FPN.FPN_ON, "FPN not supported."
524-
assert not cfg.RETINANET.RETINANET_ON, "RetinaNet model not supported."
575+
# script will stop when it can't find an operator rather
576+
# than stopping based on these flags
577+
#
578+
# assert not cfg.MODEL.KEYPOINTS_ON, "Keypoint model not supported."
579+
# assert not cfg.MODEL.MASK_ON, "Mask model not supported."
580+
# assert not cfg.FPN.FPN_ON, "FPN not supported."
581+
# assert not cfg.RETINANET.RETINANET_ON, "RetinaNet model not supported."
525582

526583
# load model from cfg
527584
model, blobs = load_model(args)

0 commit comments

Comments (0)