Skip to content

Commit 8380781

Browse files
committed
Merge pull request opencv#18299 from l-bat:onnx_reduce_max
2 parents 6b67470 + b542a18 commit 8380781

File tree

2 files changed

+45
-9
lines changed

2 files changed

+45
-9
lines changed

modules/dnn/src/onnx/onnx_importer.cpp

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -392,24 +392,21 @@ void ONNXImporter::populateNet(Net dstNet)
392392
layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
393393
}
394394
else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
395-
layer_type == "ReduceMean" || layer_type == "ReduceSum")
395+
layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax")
396396
{
397397
CV_Assert(node_proto.input_size() == 1);
398398
layerParams.type = "Pooling";
399399
String pool;
400-
if (layer_type == "GlobalMaxPool")
400+
if (layer_type == "GlobalMaxPool" || layer_type == "ReduceMax")
401401
pool = "MAX";
402402
else if (layer_type == "ReduceSum")
403403
pool = "SUM";
404404
else
405405
pool = "AVE";
406406
layerParams.set("pool", pool);
407-
layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
408-
if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
407+
layerParams.set("global_pooling", !layerParams.has("axes"));
408+
if (layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
409409
{
410-
if (!layerParams.has("axes"))
411-
CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");
412-
413410
MatShape inpShape = outShapes[node_proto.input(0)];
414411
DictValue axes = layerParams.get("axes");
415412
bool keepdims = layerParams.get<int>("keepdims");
@@ -487,6 +484,36 @@ void ONNXImporter::populateNet(Net dstNet)
487484
layerParams.type = "Reshape";
488485
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
489486

487+
node_proto.set_input(0, node_proto.output(0));
488+
node_proto.set_output(0, layerParams.name);
489+
}
490+
else if (!layerParams.has("axes") && (layer_type == "ReduceMean" || layer_type == "ReduceSum" || layer_type == "ReduceMax"))
491+
{
492+
CV_CheckEQ(layerParams.get<int>("keepdims"), 0, (layer_type + " layer only supports keepdims = false").c_str());
493+
LayerParams reshapeLp;
494+
reshapeLp.name = layerParams.name + "/reshape";
495+
reshapeLp.type = "Reshape";
496+
CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
497+
int newShape[] = {1, 1, 1, -1};
498+
reshapeLp.set("dim", DictValue::arrayInt(&newShape[0], 4));
499+
500+
opencv_onnx::NodeProto proto;
501+
proto.add_input(node_proto.input(0));
502+
proto.add_output(reshapeLp.name);
503+
addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
504+
505+
LayerParams poolLp = layerParams;
506+
poolLp.name = layerParams.name + "/pool";
507+
CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
508+
509+
node_proto.set_input(0, reshapeLp.name);
510+
node_proto.set_output(0, poolLp.name);
511+
addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
512+
513+
layerParams.type = "Reshape";
514+
int targetShape[] = {1};
515+
layerParams.set("dim", DictValue::arrayInt(&targetShape[0], 1));
516+
490517
node_proto.set_input(0, node_proto.output(0));
491518
node_proto.set_output(0, layerParams.name);
492519
}
@@ -1427,8 +1454,10 @@ void ONNXImporter::populateNet(Net dstNet)
14271454
case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break;
14281455
default: type = blob.type();
14291456
}
1430-
blob.convertTo(blob, type);
1431-
addConstant(layerParams.name, blob, constBlobs, outShapes);
1457+
Mat dst;
1458+
blob.convertTo(dst, type);
1459+
dst.dims = blob.dims;
1460+
addConstant(layerParams.name, dst, constBlobs, outShapes);
14321461
continue;
14331462
}
14341463
else
@@ -1477,6 +1506,8 @@ void ONNXImporter::populateNet(Net dstNet)
14771506
{
14781507
outShape.erase(outShape.begin() + axis);
14791508
out.reshape(0, outShape);
1509+
} else {
1510+
out.dims = 1;
14801511
}
14811512
addConstant(layerParams.name, out, constBlobs, outShapes);
14821513
continue;

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,11 @@ TEST_P(Test_ONNX_layers, ReduceSum)
262262
testONNXModels("reduce_sum");
263263
}
264264

265+
TEST_P(Test_ONNX_layers, ReduceMaxGlobal)
266+
{
267+
testONNXModels("reduce_max");
268+
}
269+
265270
TEST_P(Test_ONNX_layers, ReduceMean3D)
266271
{
267272
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target != DNN_TARGET_CPU)

0 commit comments

Comments
 (0)