@@ -1342,32 +1342,64 @@ void ONNXImporter::populateNet(Net dstNet)
1342
1342
else if (layer_type == " Gather" )
1343
1343
{
1344
1344
CV_Assert (node_proto.input_size () == 2 );
1345
- Mat input = getBlob (node_proto, constBlobs, 0 );
1346
1345
Mat indexMat = getBlob (node_proto, constBlobs, 1 );
1347
1346
CV_Assert_N (indexMat.type () == CV_32S, indexMat.total () == 1 );
1348
1347
int index = indexMat.at <int >(0 );
1348
+ int axis = layerParams.get <int >(" axis" , 0 );
1349
1349
1350
- Mat out;
1351
- if (layerParams.has (" axis" ))
1350
+ if ((constBlobs.find (node_proto.input (0 )) != constBlobs.end ()))
1352
1351
{
1353
- int axis = layerParams. get < int >( " axis " );
1354
-
1352
+ Mat input = getBlob (node_proto, constBlobs, 0 );
1353
+ Mat out;
1355
1354
std::vector<cv::Range> ranges (input.dims , Range::all ());
1356
1355
ranges[axis] = Range (index, index + 1 );
1357
1356
1358
1357
out = input (ranges);
1358
+ MatShape outShape = shape (out);
1359
+ if (outShape.size () > 1 )
1360
+ {
1361
+ outShape.erase (outShape.begin () + axis);
1362
+ out.reshape (0 , outShape);
1363
+ }
1364
+ addConstant (layerParams.name , out, constBlobs, outShapes);
1365
+ continue ;
1359
1366
}
1360
1367
else
1361
1368
{
1362
- CV_Assert (index < input.total ());
1363
- const int dims = input.dims ;
1364
- input = input.reshape (1 , 1 );
1365
- input.dims = 2 ;
1366
- out = input.reshape (1 , 1 ).colRange (index, index + 1 );
1367
- out.dims = dims;
1369
+ shapeIt = outShapes.find (node_proto.input (0 ));
1370
+ CV_Assert (shapeIt != outShapes.end ());
1371
+ MatShape inpShape = shapeIt->second ;
1372
+
1373
+ LayerParams sliceLp;
1374
+ sliceLp.type = " Slice" ;
1375
+ sliceLp.name = inpShape.size () > 1 ? layerParams.name + " /slice" : layerParams.name ;
1376
+ std::vector<int > begin (inpShape.size (), 0 );
1377
+ std::vector<int > end (inpShape.size (), -1 );
1378
+ begin[axis] = index;
1379
+ end[axis] = index + 1 ;
1380
+
1381
+ cv::dnn::DictValue paramBegin = cv::dnn::DictValue::arrayInt (begin.data (), begin.size ());
1382
+ cv::dnn::DictValue paramEnd = cv::dnn::DictValue::arrayInt (end.data (), end.size ());
1383
+ sliceLp.set (" begin" , paramBegin);
1384
+ sliceLp.set (" end" , paramEnd);
1385
+
1386
+ if (inpShape.size () > 1 )
1387
+ {
1388
+ opencv_onnx::NodeProto proto;
1389
+ proto.add_input (node_proto.input (0 ));
1390
+ proto.add_output (sliceLp.name );
1391
+ addLayer (dstNet, sliceLp, proto, layer_id, outShapes);
1392
+
1393
+ inpShape.erase (inpShape.begin () + axis);
1394
+ layerParams.type = " Reshape" ;
1395
+ layerParams.set (" dim" , DictValue::arrayInt (&inpShape[0 ], inpShape.size ()));
1396
+ node_proto.set_input (0 , sliceLp.name );
1397
+ }
1398
+ else
1399
+ {
1400
+ layerParams = sliceLp;
1401
+ }
1368
1402
}
1369
- addConstant (layerParams.name , out, constBlobs, outShapes);
1370
- continue ;
1371
1403
}
1372
1404
else if (layer_type == " Concat" )
1373
1405
{
0 commit comments