Skip to content

Commit 5444a6b

Browse files
committed
Merge pull request opencv#17890 from l-bat:onnx_gather
2 parents 284d26d + a35d4f9 commit 5444a6b

File tree

2 files changed

+56
-13
lines changed

2 files changed

+56
-13
lines changed

modules/dnn/src/onnx/onnx_importer.cpp

Lines changed: 45 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1342,32 +1342,64 @@ void ONNXImporter::populateNet(Net dstNet)
13421342
else if (layer_type == "Gather")
13431343
{
13441344
CV_Assert(node_proto.input_size() == 2);
1345-
Mat input = getBlob(node_proto, constBlobs, 0);
13461345
Mat indexMat = getBlob(node_proto, constBlobs, 1);
13471346
CV_Assert_N(indexMat.type() == CV_32S, indexMat.total() == 1);
13481347
int index = indexMat.at<int>(0);
1348+
int axis = layerParams.get<int>("axis", 0);
13491349

1350-
Mat out;
1351-
if (layerParams.has("axis"))
1350+
if ((constBlobs.find(node_proto.input(0)) != constBlobs.end()))
13521351
{
1353-
int axis = layerParams.get<int>("axis");
1354-
1352+
Mat input = getBlob(node_proto, constBlobs, 0);
1353+
Mat out;
13551354
std::vector<cv::Range> ranges(input.dims, Range::all());
13561355
ranges[axis] = Range(index, index + 1);
13571356

13581357
out = input(ranges);
1358+
MatShape outShape = shape(out);
1359+
if (outShape.size() > 1)
1360+
{
1361+
outShape.erase(outShape.begin() + axis);
1362+
out.reshape(0, outShape);
1363+
}
1364+
addConstant(layerParams.name, out, constBlobs, outShapes);
1365+
continue;
13591366
}
13601367
else
13611368
{
1362-
CV_Assert(index < input.total());
1363-
const int dims = input.dims;
1364-
input = input.reshape(1, 1);
1365-
input.dims = 2;
1366-
out = input.reshape(1, 1).colRange(index, index + 1);
1367-
out.dims = dims;
1369+
shapeIt = outShapes.find(node_proto.input(0));
1370+
CV_Assert(shapeIt != outShapes.end());
1371+
MatShape inpShape = shapeIt->second;
1372+
1373+
LayerParams sliceLp;
1374+
sliceLp.type = "Slice";
1375+
sliceLp.name = inpShape.size() > 1 ? layerParams.name + "/slice" : layerParams.name;
1376+
std::vector<int> begin(inpShape.size(), 0);
1377+
std::vector<int> end(inpShape.size(), -1);
1378+
begin[axis] = index;
1379+
end[axis] = index + 1;
1380+
1381+
cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt(begin.data(), begin.size());
1382+
cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt(end.data(), end.size());
1383+
sliceLp.set("begin", paramBegin);
1384+
sliceLp.set("end", paramEnd);
1385+
1386+
if (inpShape.size() > 1)
1387+
{
1388+
opencv_onnx::NodeProto proto;
1389+
proto.add_input(node_proto.input(0));
1390+
proto.add_output(sliceLp.name);
1391+
addLayer(dstNet, sliceLp, proto, layer_id, outShapes);
1392+
1393+
inpShape.erase(inpShape.begin() + axis);
1394+
layerParams.type = "Reshape";
1395+
layerParams.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
1396+
node_proto.set_input(0, sliceLp.name);
1397+
}
1398+
else
1399+
{
1400+
layerParams = sliceLp;
1401+
}
13681402
}
1369-
addConstant(layerParams.name, out, constBlobs, outShapes);
1370-
continue;
13711403
}
13721404
else if (layer_type == "Concat")
13731405
{

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,17 @@ TEST_P(Test_ONNX_layers, Convolution)
111111
testONNXModels("convolution");
112112
}
113113

114+
TEST_P(Test_ONNX_layers, Gather)
115+
{
116+
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019 && target == DNN_TARGET_MYRIAD)
117+
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_MYRIAD, CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
118+
testONNXModels("gather");
119+
// GPU plugin unsupported slice for constant
120+
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH && (target == DNN_TARGET_OPENCL || target == DNN_TARGET_OPENCL_FP16))
121+
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_OPENCL, CV_TEST_TAG_DNN_SKIP_IE_OPENCL_FP16, CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
122+
testONNXModels("gather_scalar", npy, 0, 0, false, false);
123+
}
124+
114125
TEST_P(Test_ONNX_layers, Convolution3D)
115126
{
116127
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_LT(2019010000)

0 commit comments

Comments
 (0)