Skip to content

Commit ad63d24

Browse files
authored
Merge pull request opencv#18096 from l-bat:update_onnx_importer
* Added ReduceSum to ONNX importer * Fix comments * Fix Mul
1 parent 3b5813c commit ad63d24

File tree

4 files changed

+124
-34
lines changed

4 files changed

+124
-34
lines changed

modules/dnn/src/layers/fully_connected_layer.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,6 @@ class FullyConnectedLayerImpl CV_FINAL : public InnerProductLayer
116116
CV_CheckEQ(inputs.size(), (size_t)2, "");
117117
numOutput = inputs[1].back();
118118
cAxis = inputs[0].size() - 1;
119-
CV_CheckEQ(numOutput, inputs[0][cAxis - 1], "");
120119
int dims = inputs[0].size();
121120
CV_CheckEQ(inputs[1].size(), (size_t)dims, "");
122121
CV_CheckGE(dims, 2, "");

modules/dnn/src/onnx/onnx_graph_simplifier.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,24 @@ class GatherCastSubgraph : public Subgraph
262262
}
263263
};
264264

265+
class ExpandSubgraph : public Subgraph
266+
{
267+
public:
268+
ExpandSubgraph()
269+
{
270+
int input = addNodeToMatch("");
271+
int values = addNodeToMatch("");
272+
int init = addNodeToMatch("ConstantOfShape", values);
273+
int coeff = addNodeToMatch("Constant");
274+
int mul = addNodeToMatch("Mul", init, coeff);
275+
int shape = addNodeToMatch("Constant");
276+
int condition = addNodeToMatch("Equal", shape, mul);
277+
int where = addNodeToMatch("Where", condition, init, addNodeToMatch("Constant"));
278+
addNodeToMatch("Expand", input, where);
279+
setFusedNode("Expand", input, shape);
280+
}
281+
};
282+
265283
class MulCastSubgraph : public Subgraph
266284
{
267285
public:
@@ -459,6 +477,7 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
459477
subgraphs.push_back(makePtr<NormalizeSubgraph3>());
460478
subgraphs.push_back(makePtr<BatchNormalizationSubgraph1>());
461479
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
480+
subgraphs.push_back(makePtr<ExpandSubgraph>());
462481

463482
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
464483
}

modules/dnn/src/onnx/onnx_importer.cpp

Lines changed: 99 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -387,26 +387,42 @@ void ONNXImporter::populateNet(Net dstNet)
387387
layerParams.set("ceil_mode", layerParams.has("pad_mode"));
388388
layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
389389
}
390-
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
390+
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
391+
layer_type == "ReduceMean" || layer_type == "ReduceSum")
391392
{
392393
CV_Assert(node_proto.input_size() == 1);
393394
layerParams.type = "Pooling";
394-
layerParams.set("pool", layer_type == "GlobalMaxPool"? "MAX" : "AVE");
395+
String pool;
396+
if (layer_type == "GlobalMaxPool")
397+
pool = "MAX";
398+
else if (layer_type == "ReduceSum")
399+
pool = "SUM";
400+
else
401+
pool = "AVE";
402+
layerParams.set("pool", pool);
395403
layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
396-
397-
if (layer_type == "ReduceMean")
404+
if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
398405
{
399-
if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
400-
CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");
406+
if (!layerParams.has("axes"))
407+
CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");
401408

402409
MatShape inpShape = outShapes[node_proto.input(0)];
403410
DictValue axes = layerParams.get("axes");
411+
bool keepdims = layerParams.get<int>("keepdims");
412+
MatShape targetShape = inpShape;
413+
for (int i = 0; i < axes.size(); i++) {
414+
int axis = clamp(axes.get<int>(i), inpShape.size());
415+
if (keepdims) {
416+
targetShape[axis] = 1;
417+
} else {
418+
targetShape.erase(targetShape.begin() + axis);
419+
}
420+
}
421+
404422
if (inpShape.size() == 3 && axes.size() <= 2)
405423
{
406-
int axis = axes.get<int>(0);
424+
int axis = clamp(axes.get<int>(0), inpShape.size());
407425
CV_CheckNE(axis, 0, "");
408-
outShapes[layerParams.name] = inpShape;
409-
outShapes[layerParams.name][axis] = 1;
410426

411427
LayerParams reshapeLp;
412428
reshapeLp.name = layerParams.name + "/reshape";
@@ -426,13 +442,12 @@ void ONNXImporter::populateNet(Net dstNet)
426442
avgLp.name = layerParams.name + "/avg";
427443
avgLp.type = "Pooling";
428444
CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
429-
avgLp.set("pool", "ave");
445+
avgLp.set("pool", pool);
430446
if (axes.size() == 2)
431447
{
432-
CV_CheckEQ(axes.get<int>(0), 1, "Unsupported ReduceMean mode");
433-
CV_CheckEQ(axes.get<int>(1), 2, "Unsupported ReduceMean mode");
448+
CV_CheckEQ(clamp(axes.get<int>(0), inpShape.size()), 1, ("Unsupported " + layer_type + " mode").c_str());
449+
CV_CheckEQ(clamp(axes.get<int>(1), inpShape.size()), 2, ("Unsupported " + layer_type + " mode").c_str());
434450
avgLp.set("global_pooling", true);
435-
outShapes[layerParams.name][axes.get<int>(1)] = 1;
436451
}
437452
else
438453
{
@@ -443,28 +458,33 @@ void ONNXImporter::populateNet(Net dstNet)
443458
node_proto.set_input(0, reshapeLp.name);
444459
node_proto.set_output(0, avgLp.name);
445460
addLayer(dstNet, avgLp, node_proto, layer_id, outShapes);
446-
447-
layerParams.type = "Flatten";
448-
layerParams.set("axis", 0);
449-
layerParams.set("end_axis", 1);
450-
451-
node_proto.set_input(0, avgLp.name);
452-
node_proto.set_output(0, layerParams.name);
453461
}
454462
else
455463
{
456464
if (inpShape.size() != 4 && inpShape.size() != 5)
457-
CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation.");
465+
CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation.");
458466

459467
CV_Assert(axes.size() <= inpShape.size() - 2);
460468
std::vector<int> kernel_size(inpShape.size() - 2, 1);
461469
for (int i = 0; i < axes.size(); i++) {
462-
int axis = axes.get<int>(i);
470+
int axis = clamp(axes.get<int>(i), inpShape.size());
463471
CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
464472
kernel_size[axis - 2] = inpShape[axis];
465473
}
466-
layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
474+
LayerParams poolLp = layerParams;
475+
poolLp.name = layerParams.name + "/avg";
476+
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
477+
poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
478+
479+
node_proto.set_output(0, poolLp.name);
480+
addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
467481
}
482+
483+
layerParams.type = "Reshape";
484+
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
485+
486+
node_proto.set_input(0, node_proto.output(0));
487+
node_proto.set_output(0, layerParams.name);
468488
}
469489
}
470490
else if (layer_type == "Slice")
@@ -1001,15 +1021,10 @@ void ONNXImporter::populateNet(Net dstNet)
10011021
{
10021022
Mat inp0 = getBlob(node_proto, constBlobs, 0);
10031023
Mat inp1 = getBlob(node_proto, constBlobs, 1);
1004-
if (inp0.size != inp1.size)
1024+
if (inp0.size != inp1.size && inp1.total() != 1)
10051025
CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");
10061026

1007-
Mat out;
1008-
if (isDiv)
1009-
divide(inp0, inp1, out);
1010-
else
1011-
multiply(inp0, inp1, out);
1012-
1027+
Mat out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
10131028
out = out.reshape(1, inp0.dims, inp0.size);
10141029
out.dims = inp0.dims; // to workaround dims == 1
10151030
addConstant(layerParams.name, out, constBlobs, outShapes);
@@ -1180,9 +1195,45 @@ void ONNXImporter::populateNet(Net dstNet)
11801195
Mat newShapeMat = getBlob(node_proto, constBlobs, 1);
11811196
MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());
11821197

1183-
shapeIt = outShapes.find(node_proto.input(0));
1184-
CV_Assert(shapeIt != outShapes.end());
1185-
MatShape inpShape = shapeIt->second;
1198+
MatShape inpShape;
1199+
bool haveVariables = constBlobs.find(node_proto.input(0)) == constBlobs.end();
1200+
if (haveVariables)
1201+
{
1202+
shapeIt = outShapes.find(node_proto.input(0));
1203+
CV_Assert(shapeIt != outShapes.end());
1204+
inpShape = shapeIt->second;
1205+
}
1206+
else
1207+
{
1208+
inpShape = shape(getBlob(node_proto, constBlobs, 0));
1209+
}
1210+
1211+
String srcName = node_proto.input(0);
1212+
// Unsqueeze and repeat along new axis
1213+
if (targetShape.size() == inpShape.size() + 1)
1214+
{
1215+
for (int i = 0; i < targetShape.size(); i++)
1216+
{
1217+
if (targetShape[i] == -1 && i < inpShape.size())
1218+
targetShape[i] = inpShape[i];
1219+
else if (i < inpShape.size() && targetShape[i] != inpShape[i])
1220+
inpShape.insert(inpShape.begin() + i, 1);
1221+
}
1222+
if (haveVariables)
1223+
{
1224+
LayerParams reshapeLp;
1225+
reshapeLp.name = layerParams.name + "/reshape";
1226+
reshapeLp.type = "Reshape";
1227+
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
1228+
reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
1229+
1230+
opencv_onnx::NodeProto proto;
1231+
proto.add_input(node_proto.input(0));
1232+
proto.add_output(reshapeLp.name);
1233+
addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
1234+
srcName = reshapeLp.name;
1235+
}
1236+
}
11861237
CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");
11871238

11881239
std::vector<int> broadcast_axes;
@@ -1197,6 +1248,19 @@ void ONNXImporter::populateNet(Net dstNet)
11971248
}
11981249
}
11991250

1251+
if (!haveVariables)
1252+
{
1253+
if (broadcast_axes.size() != 1)
1254+
CV_Error(Error::StsNotImplemented, "Expand op doesn't support multiple axes for constant input");
1255+
1256+
Mat input = getBlob(node_proto, constBlobs, 0);
1257+
input = input.reshape(0, total(inpShape, 0, broadcast_axes[0]));
1258+
Mat output = cv::repeat(input, 1, targetShape[broadcast_axes[0]]);
1259+
output = output.reshape(0, targetShape);
1260+
addConstant(layerParams.name, output, constBlobs, outShapes);
1261+
continue;
1262+
}
1263+
12001264
if (broadcast_axes.size() == 2 &&
12011265
broadcast_axes[0] == broadcast_axes[1] - 1 && broadcast_axes[1] == inpShape.size() - 1)
12021266
{
@@ -1231,6 +1295,7 @@ void ONNXImporter::populateNet(Net dstNet)
12311295
CV_Assert(layer_id.find(copyLP.name) == layer_id.end());
12321296
input_names.push_back(copyLP.name);
12331297

1298+
node_proto.set_input(0, srcName);
12341299
node_proto.set_output(0, copyLP.name);
12351300
addLayer(dstNet, copyLP, node_proto, layer_id, outShapes);
12361301
}
@@ -1241,6 +1306,7 @@ void ONNXImporter::populateNet(Net dstNet)
12411306
}
12421307
layerParams.set("axis", broadcast_axes[0]);
12431308
layerParams.type = "Concat";
1309+
node_proto.set_output(0, layerParams.name);
12441310
}
12451311
else
12461312
CV_Error(Error::StsNotImplemented, "Unsupported Expand op");

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,11 @@ TEST_P(Test_ONNX_layers, ReduceMean)
257257
testONNXModels("reduce_mean_axis2");
258258
}
259259

260+
TEST_P(Test_ONNX_layers, ReduceSum)
261+
{
262+
testONNXModels("reduce_sum");
263+
}
264+
260265
TEST_P(Test_ONNX_layers, ReduceMean3D)
261266
{
262267
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)
@@ -417,6 +422,7 @@ TEST_P(Test_ONNX_layers, Expand)
417422
{
418423
testONNXModels("expand_batch");
419424
testONNXModels("expand_channels");
425+
testONNXModels("expand_neg_batch");
420426
}
421427

422428
TEST_P(Test_ONNX_layers, ExpandHW)

0 commit comments

Comments
 (0)