Skip to content

Commit d9bada9

Browse files
committed
dnn: EfficientDet
1 parent d5e8792 commit d9bada9

File tree

5 files changed

+302
-8
lines changed

5 files changed

+302
-8
lines changed

modules/dnn/perf/perf_net.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,17 @@ PERF_TEST_P_(DNNTestNetwork, Inception_v2_Faster_RCNN)
235235
Mat(cv::Size(800, 600), CV_32FC3));
236236
}
237237

238+
PERF_TEST_P_(DNNTestNetwork, EfficientDet)
239+
{
240+
if (backend == DNN_BACKEND_HALIDE || target != DNN_TARGET_CPU)
241+
throw SkipTestException("");
242+
Mat sample = imread(findDataFile("dnn/dog416.png"));
243+
resize(sample, sample, Size(512, 512));
244+
Mat inp;
245+
sample.convertTo(inp, CV_32FC3, 1.0/255);
246+
processNet("dnn/efficientdet-d0.pb", "dnn/efficientdet-d0.pbtxt", "", inp);
247+
}
248+
238249
INSTANTIATE_TEST_CASE_P(/*nothing*/, DNNTestNetwork, dnnBackendsAndTargets());
239250

240251
} // namespace

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1542,22 +1542,32 @@ void TFImporter::populateNet(Net dstNet)
15421542

15431543
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
15441544
}
1545-
else if (type == "Mul")
1545+
else if (type == "Mul" || type == "RealDiv")
15461546
{
1547-
bool haveConst = false;
1548-
for(int ii = 0; !haveConst && ii < layer.input_size(); ++ii)
1547+
int constId = -1;
1548+
for(int ii = 0; ii < layer.input_size(); ++ii)
15491549
{
15501550
Pin input = parsePin(layer.input(ii));
1551-
haveConst = value_id.find(input.name) != value_id.end();
1551+
if (value_id.find(input.name) != value_id.end())
1552+
{
1553+
constId = ii;
1554+
break;
1555+
}
15521556
}
1553-
CV_Assert(!haveConst || layer.input_size() == 2);
1557+
CV_Assert((constId != -1) || (layer.input_size() == 2));
15541558

1555-
if (haveConst)
1559+
if (constId != -1)
15561560
{
15571561
// Multiplication by constant.
15581562
CV_Assert(layer.input_size() == 2);
15591563
Mat scaleMat = getTensorContent(getConstBlob(layer, value_id));
15601564
CV_Assert(scaleMat.type() == CV_32FC1);
1565+
if (type == "RealDiv")
1566+
{
1567+
if (constId == 0)
1568+
CV_Error(Error::StsNotImplemented, "Division of constant over variable");
1569+
scaleMat = 1.0f / scaleMat;
1570+
}
15611571

15621572
int id;
15631573
if (scaleMat.total() == 1) // is a scalar.
@@ -1659,11 +1669,15 @@ void TFImporter::populateNet(Net dstNet)
16591669
int id;
16601670
if (equalInpShapes || netInputShapes.empty())
16611671
{
1662-
layerParams.set("operation", "prod");
1672+
layerParams.set("operation", type == "RealDiv" ? "div" : "prod");
16631673
id = dstNet.addLayer(name, "Eltwise", layerParams);
16641674
}
16651675
else
1676+
{
1677+
if (type == "RealDiv")
1678+
CV_Error(Error::StsNotImplemented, "Division of non equal tensors");
16661679
id = dstNet.addLayer(name, "Scale", layerParams);
1680+
}
16671681

16681682
layer_id[name] = id;
16691683

modules/dnn/test/test_tf_importer.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,4 +1123,37 @@ TEST_P(Test_TensorFlow_nets, Mask_RCNN)
11231123
expectNoFallbacks(net);
11241124
}
11251125

1126+
TEST_P(Test_TensorFlow_nets, EfficientDet)
1127+
{
1128+
if (target != DNN_TARGET_CPU)
1129+
{
1130+
if (target == DNN_TARGET_OPENCL_FP16) applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
1131+
if (target == DNN_TARGET_OPENCL) applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL);
1132+
if (target == DNN_TARGET_MYRIAD) applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD);
1133+
}
1134+
checkBackend();
1135+
std::string proto = findDataFile("dnn/efficientdet-d0.pbtxt");
1136+
std::string model = findDataFile("dnn/efficientdet-d0.pb");
1137+
1138+
Net net = readNetFromTensorflow(model, proto);
1139+
Mat img = imread(findDataFile("dnn/dog416.png"));
1140+
Mat blob = blobFromImage(img, 1.0/255, Size(512, 512), Scalar(123.675, 116.28, 103.53));
1141+
1142+
net.setPreferableBackend(backend);
1143+
net.setPreferableTarget(target);
1144+
net.setInput(blob);
1145+
// Output has shape 1x1xNx7 where N - number of detections.
1146+
// An every detection is a vector of values [id, classId, confidence, left, top, right, bottom]
1147+
Mat out = net.forward();
1148+
1149+
// References are from test for TensorFlow model.
1150+
Mat ref = (Mat_<float>(3, 7) << 0, 1, 0.8437444, 0.153996080160141, 0.20534580945968628, 0.7463544607162476, 0.7414066195487976,
1151+
0, 17, 0.8245924, 0.16657517850399017, 0.3996818959712982, 0.4111558794975281, 0.9306337833404541,
1152+
0, 7, 0.8039304, 0.6118435263633728, 0.13175517320632935, 0.9065558314323425, 0.2943994700908661);
1153+
double scoreDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 4e-3 : 1e-5;
1154+
double iouDiff = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 2e-3 : 1e-4;
1155+
normAssertDetections(ref, out, "", 0.5, scoreDiff, iouDiff);
1156+
expectNoFallbacksFromIE(net);
1157+
}
1158+
11261159
}

samples/dnn/tf_text_graph_common.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ def parseTextGraph(filePath):
269269
def removeIdentity(graph_def):
270270
identities = {}
271271
for node in graph_def.node:
272-
if node.op == 'Identity':
272+
if node.op == 'Identity' or node.op == 'IdentityN':
273273
identities[node.name] = node.input[0]
274274
graph_def.node.remove(node)
275275

Lines changed: 236 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,236 @@
1+
# This file is a part of OpenCV project.
2+
# It is a subject to the license terms in the LICENSE file found in the top-level directory
3+
# of this distribution and at http://opencv.org/license.html.
4+
#
5+
# Copyright (C) 2020, Intel Corporation, all rights reserved.
6+
# Third party copyrights are property of their respective owners.
7+
#
8+
# Use this script to get the text graph representation (.pbtxt) of EfficientDet
9+
# deep learning network trained in https://github.com/google/automl.
10+
# Then you can import it with a binary frozen graph (.pb) using readNetFromTensorflow() function.
11+
# See details and examples on the following wiki page: https://github.com/opencv/opencv/wiki/TensorFlow-Object-Detection-API
12+
import argparse
13+
import re
14+
from math import sqrt
15+
from tf_text_graph_common import *
16+
17+
18+
class AnchorGenerator:
19+
def __init__(self, min_level, aspect_ratios, num_scales, anchor_scale):
20+
self.min_level = min_level
21+
self.aspect_ratios = aspect_ratios
22+
self.anchor_scale = anchor_scale
23+
self.scales = [2**(float(s) / num_scales) for s in range(num_scales)]
24+
25+
def get(self, layer_id):
26+
widths = []
27+
heights = []
28+
for s in self.scales:
29+
for a in self.aspect_ratios:
30+
base_anchor_size = 2**(self.min_level + layer_id) * self.anchor_scale
31+
heights.append(base_anchor_size * s * a[1])
32+
widths.append(base_anchor_size * s * a[0])
33+
return widths, heights
34+
35+
36+
def createGraph(modelPath, outputPath, min_level, aspect_ratios, num_scales,
37+
anchor_scale, num_classes, image_width, image_height):
38+
print('Min level: %d' % min_level)
39+
print('Anchor scale: %f' % anchor_scale)
40+
print('Num scales: %d' % num_scales)
41+
print('Aspect ratios: %s' % str(aspect_ratios))
42+
print('Number of classes: %d' % num_classes)
43+
print('Input image size: %dx%d' % (image_width, image_height))
44+
45+
# Read the graph.
46+
_inpNames = ['image_arrays']
47+
outNames = ['detections']
48+
49+
writeTextGraph(modelPath, outputPath, outNames)
50+
graph_def = parseTextGraph(outputPath)
51+
52+
def getUnconnectedNodes():
53+
unconnected = []
54+
for node in graph_def.node:
55+
if node.op == 'Const':
56+
continue
57+
unconnected.append(node.name)
58+
for inp in node.input:
59+
if inp in unconnected:
60+
unconnected.remove(inp)
61+
return unconnected
62+
63+
64+
nodesToKeep = ['truediv'] # Keep preprocessing nodes
65+
66+
removeIdentity(graph_def)
67+
68+
scopesToKeep = ('image_arrays', 'efficientnet', 'resample_p6', 'resample_p7',
69+
'fpn_cells', 'class_net', 'box_net', 'Reshape', 'concat')
70+
71+
addConstNode('scale_w', [2.0], graph_def)
72+
addConstNode('scale_h', [2.0], graph_def)
73+
nodesToKeep += ['scale_w', 'scale_h']
74+
75+
for node in graph_def.node:
76+
if re.match('efficientnet-(.*)/blocks_\d+/se/mul_1', node.name):
77+
node.input[0], node.input[1] = node.input[1], node.input[0]
78+
79+
if re.match('fpn_cells/cell_\d+/fnode\d+/resample(.*)/nearest_upsampling/Reshape_1$', node.name):
80+
node.op = 'ResizeNearestNeighbor'
81+
node.input[1] = 'scale_w'
82+
node.input.append('scale_h')
83+
84+
for inpNode in graph_def.node:
85+
if inpNode.name == node.name[:node.name.rfind('_')]:
86+
node.input[0] = inpNode.input[0]
87+
88+
if re.match('box_net/box-predict(_\d)*/separable_conv2d$', node.name):
89+
node.addAttr('loc_pred_transposed', True)
90+
91+
# Replace RealDiv to Mul with inversed scale for compatibility
92+
if node.op == 'RealDiv':
93+
for inpNode in graph_def.node:
94+
if inpNode.name != node.input[1] or not 'value' in inpNode.attr:
95+
continue
96+
97+
tensor = inpNode.attr['value']['tensor'][0]
98+
if not 'float_val' in tensor:
99+
continue
100+
scale = float(inpNode.attr['value']['tensor'][0]['float_val'][0])
101+
102+
addConstNode(inpNode.name + '/inv', [1.0 / scale], graph_def)
103+
nodesToKeep.append(inpNode.name + '/inv')
104+
node.input[1] = inpNode.name + '/inv'
105+
node.op = 'Mul'
106+
break
107+
108+
109+
def to_remove(name, op):
110+
if name in nodesToKeep:
111+
return False
112+
return op == 'Const' or not name.startswith(scopesToKeep)
113+
114+
removeUnusedNodesAndAttrs(to_remove, graph_def)
115+
116+
# Attach unconnected preprocessing
117+
assert(graph_def.node[1].name == 'truediv' and graph_def.node[1].op == 'RealDiv')
118+
graph_def.node[1].input.insert(0, 'image_arrays')
119+
graph_def.node[2].input.insert(0, 'truediv')
120+
121+
priors_generator = AnchorGenerator(min_level, aspect_ratios, num_scales, anchor_scale)
122+
priorBoxes = []
123+
for i in range(5):
124+
inpName = ''
125+
for node in graph_def.node:
126+
if node.name == 'Reshape_%d' % (i * 2 + 1):
127+
inpName = node.input[0]
128+
break
129+
130+
priorBox = NodeDef()
131+
priorBox.name = 'PriorBox_%d' % i
132+
priorBox.op = 'PriorBox'
133+
priorBox.input.append(inpName)
134+
priorBox.input.append(graph_def.node[0].name) # image_tensor
135+
136+
priorBox.addAttr('flip', False)
137+
priorBox.addAttr('clip', False)
138+
139+
widths, heights = priors_generator.get(i)
140+
141+
priorBox.addAttr('width', widths)
142+
priorBox.addAttr('height', heights)
143+
priorBox.addAttr('variance', [1.0, 1.0, 1.0, 1.0])
144+
145+
graph_def.node.extend([priorBox])
146+
priorBoxes.append(priorBox.name)
147+
148+
addConstNode('concat/axis_flatten', [-1], graph_def)
149+
150+
def addConcatNode(name, inputs, axisNodeName):
151+
concat = NodeDef()
152+
concat.name = name
153+
concat.op = 'ConcatV2'
154+
for inp in inputs:
155+
concat.input.append(inp)
156+
concat.input.append(axisNodeName)
157+
graph_def.node.extend([concat])
158+
159+
addConcatNode('PriorBox/concat', priorBoxes, 'concat/axis_flatten')
160+
161+
sigmoid = NodeDef()
162+
sigmoid.name = 'concat/sigmoid'
163+
sigmoid.op = 'Sigmoid'
164+
sigmoid.input.append('concat')
165+
graph_def.node.extend([sigmoid])
166+
167+
addFlatten(sigmoid.name, sigmoid.name + '/Flatten', graph_def)
168+
addFlatten('concat_1', 'concat_1/Flatten', graph_def)
169+
170+
detectionOut = NodeDef()
171+
detectionOut.name = 'detection_out'
172+
detectionOut.op = 'DetectionOutput'
173+
174+
detectionOut.input.append('concat_1/Flatten')
175+
detectionOut.input.append(sigmoid.name + '/Flatten')
176+
detectionOut.input.append('PriorBox/concat')
177+
178+
detectionOut.addAttr('num_classes', num_classes)
179+
detectionOut.addAttr('share_location', True)
180+
detectionOut.addAttr('background_label_id', num_classes + 1)
181+
detectionOut.addAttr('nms_threshold', 0.6)
182+
detectionOut.addAttr('confidence_threshold', 0.2)
183+
detectionOut.addAttr('top_k', 100)
184+
detectionOut.addAttr('keep_top_k', 100)
185+
detectionOut.addAttr('code_type', "CENTER_SIZE")
186+
graph_def.node.extend([detectionOut])
187+
188+
graph_def.node[0].attr['shape'] = {
189+
'shape': {
190+
'dim': [
191+
{'size': -1},
192+
{'size': image_height},
193+
{'size': image_width},
194+
{'size': 3}
195+
]
196+
}
197+
}
198+
199+
while True:
200+
unconnectedNodes = getUnconnectedNodes()
201+
unconnectedNodes.remove(detectionOut.name)
202+
if not unconnectedNodes:
203+
break
204+
205+
for name in unconnectedNodes:
206+
for i in range(len(graph_def.node)):
207+
if graph_def.node[i].name == name:
208+
del graph_def.node[i]
209+
break
210+
211+
# Save as text
212+
graph_def.save(outputPath)
213+
214+
215+
if __name__ == "__main__":
216+
parser = argparse.ArgumentParser(description='Run this script to get a text graph of '
217+
'SSD model from TensorFlow Object Detection API. '
218+
'Then pass it with .pb file to cv::dnn::readNetFromTensorflow function.')
219+
parser.add_argument('--input', required=True, help='Path to frozen TensorFlow graph.')
220+
parser.add_argument('--output', required=True, help='Path to output text graph.')
221+
parser.add_argument('--min_level', default=3, type=int, help='Parameter from training config')
222+
parser.add_argument('--num_scales', default=3, type=int, help='Parameter from training config')
223+
parser.add_argument('--anchor_scale', default=4.0, type=float, help='Parameter from training config')
224+
parser.add_argument('--aspect_ratios', default=[1.0, 1.0, 1.4, 0.7, 0.7, 1.4],
225+
nargs='+', type=float, help='Parameter from training config')
226+
parser.add_argument('--num_classes', default=90, type=int, help='Number of classes to detect')
227+
parser.add_argument('--width', default=512, type=int, help='Network input width')
228+
parser.add_argument('--height', default=512, type=int, help='Network input height')
229+
args = parser.parse_args()
230+
231+
ar = args.aspect_ratios
232+
assert(len(ar) % 2 == 0)
233+
ar = list(zip(ar[::2], ar[1::2]))
234+
235+
createGraph(args.input, args.output, args.min_level, ar, args.num_scales,
236+
args.anchor_scale, args.num_classes, args.width, args.height)

0 commit comments

Comments
 (0)