@@ -387,26 +387,42 @@ void ONNXImporter::populateNet(Net dstNet)
             layerParams.set("ceil_mode", layerParams.has("pad_mode"));
             layerParams.set("ave_pool_padded_area", framework_name == "pytorch");
         }
-        else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" || layer_type == "ReduceMean")
+        else if (layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool" ||
+                 layer_type == "ReduceMean" || layer_type == "ReduceSum")
         {
             CV_Assert(node_proto.input_size() == 1);
             layerParams.type = "Pooling";
-            layerParams.set("pool", layer_type == "GlobalMaxPool" ? "MAX" : "AVE");
+            String pool;
+            if (layer_type == "GlobalMaxPool")
+                pool = "MAX";
+            else if (layer_type == "ReduceSum")
+                pool = "SUM";
+            else
+                pool = "AVE";
+            layerParams.set("pool", pool);
             layerParams.set("global_pooling", layer_type == "GlobalAveragePool" || layer_type == "GlobalMaxPool");
-
-            if (layer_type == "ReduceMean")
+            if (layer_type == "ReduceMean" || layer_type == "ReduceSum")
             {
-                if (layerParams.get<int>("keepdims") == 0 || !layerParams.has("axes"))
-                    CV_Error(Error::StsNotImplemented, "Unsupported mode of ReduceMean operation.");
+                if (!layerParams.has("axes"))
+                    CV_Error(Error::StsNotImplemented, "Unsupported mode of " + layer_type + " operation.");

                 MatShape inpShape = outShapes[node_proto.input(0)];
                 DictValue axes = layerParams.get("axes");
+                bool keepdims = layerParams.get<int>("keepdims");
+                MatShape targetShape = inpShape;
+                for (int i = 0; i < axes.size(); i++) {
+                    int axis = clamp(axes.get<int>(i), inpShape.size());
+                    if (keepdims) {
+                        targetShape[axis] = 1;
+                    } else {
+                        targetShape.erase(targetShape.begin() + axis);
+                    }
+                }
+
                 if (inpShape.size() == 3 && axes.size() <= 2)
                 {
-                    int axis = axes.get<int>(0);
+                    int axis = clamp(axes.get<int>(0), inpShape.size());
                     CV_CheckNE(axis, 0, "");
-                    outShapes[layerParams.name] = inpShape;
-                    outShapes[layerParams.name][axis] = 1;

                     LayerParams reshapeLp;
                     reshapeLp.name = layerParams.name + "/reshape";
@@ -426,13 +442,12 @@ void ONNXImporter::populateNet(Net dstNet)
                     avgLp.name = layerParams.name + "/avg";
                     avgLp.type = "Pooling";
                     CV_Assert(layer_id.find(avgLp.name) == layer_id.end());
-                    avgLp.set("pool", "ave");
+                    avgLp.set("pool", pool);
                     if (axes.size() == 2)
                     {
-                        CV_CheckEQ(axes.get<int>(0), 1, "Unsupported ReduceMean mode");
-                        CV_CheckEQ(axes.get<int>(1), 2, "Unsupported ReduceMean mode");
+                        CV_CheckEQ(clamp(axes.get<int>(0), inpShape.size()), 1, ("Unsupported " + layer_type + " mode").c_str());
+                        CV_CheckEQ(clamp(axes.get<int>(1), inpShape.size()), 2, ("Unsupported " + layer_type + " mode").c_str());
                         avgLp.set("global_pooling", true);
-                        outShapes[layerParams.name][axes.get<int>(1)] = 1;
                     }
                     else
                     {
@@ -443,28 +458,33 @@ void ONNXImporter::populateNet(Net dstNet)
                     node_proto.set_input(0, reshapeLp.name);
                     node_proto.set_output(0, avgLp.name);
                     addLayer(dstNet, avgLp, node_proto, layer_id, outShapes);
-
-                    layerParams.type = "Flatten";
-                    layerParams.set("axis", 0);
-                    layerParams.set("end_axis", 1);
-
-                    node_proto.set_input(0, avgLp.name);
-                    node_proto.set_output(0, layerParams.name);
                 }
                 else
                 {
                     if (inpShape.size() != 4 && inpShape.size() != 5)
-                        CV_Error(Error::StsNotImplemented, "Unsupported input shape of reduce_mean operation.");
+                        CV_Error(Error::StsNotImplemented, "Unsupported input shape of " + layer_type + " operation.");

                     CV_Assert(axes.size() <= inpShape.size() - 2);
                     std::vector<int> kernel_size(inpShape.size() - 2, 1);
                     for (int i = 0; i < axes.size(); i++) {
-                        int axis = axes.get<int>(i);
+                        int axis = clamp(axes.get<int>(i), inpShape.size());
                         CV_Assert_N(axis >= 2 + i, axis < inpShape.size());
                         kernel_size[axis - 2] = inpShape[axis];
                     }
-                    layerParams.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
+                    LayerParams poolLp = layerParams;
+                    poolLp.name = layerParams.name + "/avg";
+                    CV_Assert(layer_id.find(poolLp.name) == layer_id.end());
+                    poolLp.set("kernel_size", DictValue::arrayInt(&kernel_size[0], kernel_size.size()));
+
+                    node_proto.set_output(0, poolLp.name);
+                    addLayer(dstNet, poolLp, node_proto, layer_id, outShapes);
                 }
+
+                layerParams.type = "Reshape";
+                layerParams.set("dim", DictValue::arrayInt(&targetShape[0], targetShape.size()));
+
+                node_proto.set_input(0, node_proto.output(0));
+                node_proto.set_output(0, layerParams.name);
             }
         }
         else if (layer_type == "Slice")
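Note (not part of the patch): the three hunks above lower ReduceMean and ReduceSum to a Pooling layer (pool = AVE or SUM) followed by a Reshape to targetShape, which is why keepdims == 0 no longer has to be rejected. Below is a minimal standalone sketch of the target-shape computation; the helper names clampAxis and reducedShape are hypothetical, and the axes are erased back-to-front rather than in the patch's order so the erase indices stay valid.

#include <algorithm>
#include <vector>

// Hypothetical helper mirroring the importer's clamp(): wrap negative axes.
static int clampAxis(int axis, int dims) { return axis < 0 ? axis + dims : axis; }

// Output shape of a ReduceMean/ReduceSum over `axes` (illustration only).
static std::vector<int> reducedShape(std::vector<int> shape, std::vector<int> axes, bool keepdims)
{
    for (size_t i = 0; i < axes.size(); ++i)
        axes[i] = clampAxis(axes[i], (int)shape.size());
    if (keepdims)
    {
        for (size_t i = 0; i < axes.size(); ++i)
            shape[axes[i]] = 1;                   // reduced axes stay, with extent 1
    }
    else
    {
        std::sort(axes.rbegin(), axes.rend());    // erase from the back so indices stay valid
        for (size_t i = 0; i < axes.size(); ++i)
            shape.erase(shape.begin() + axes[i]);
    }
    return shape;
}
// reducedShape({1, 64, 7, 7}, {2, 3}, true)  -> {1, 64, 1, 1}
// reducedShape({1, 64, 7, 7}, {2, 3}, false) -> {1, 64}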
@@ -1001,15 +1021,10 @@ void ONNXImporter::populateNet(Net dstNet)
             {
                 Mat inp0 = getBlob(node_proto, constBlobs, 0);
                 Mat inp1 = getBlob(node_proto, constBlobs, 1);
-                if (inp0.size != inp1.size)
+                if (inp0.size != inp1.size && inp1.total() != 1)
                     CV_Error(Error::StsNotImplemented, "Constant multiply with different shapes");

-                Mat out;
-                if (isDiv)
-                    divide(inp0, inp1, out);
-                else
-                    multiply(inp0, inp1, out);
-
+                Mat out = isDiv ? inp0 / inp1 : inp0.mul(inp1);
                 out = out.reshape(1, inp0.dims, inp0.size);
                 out.dims = inp0.dims;  // to workaround dims == 1
                 addConstant(layerParams.name, out, constBlobs, outShapes);
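Note (not part of the patch): the relaxed shape check above now admits a second constant with a single element, relying on OpenCV's element-wise arithmetic treating a one-element Mat as a scalar. A small sketch of what the new `isDiv ? inp0 / inp1 : inp0.mul(inp1)` expression does in that case, assuming that broadcasting behaviour:

#include <opencv2/core.hpp>
#include <iostream>

int main()
{
    // A 2x3 constant and a single-element constant, the case the relaxed check now admits.
    cv::Mat inp0 = (cv::Mat_<float>(2, 3) << 1, 2, 3, 4, 5, 6);
    cv::Mat inp1 = (cv::Mat_<float>(1, 1) << 2);

    cv::Mat prod = inp0.mul(inp1);  // Mul branch: inp1 is applied as a scalar
    cv::Mat quot = inp0 / inp1;     // Div branch: same scalar-style broadcasting

    std::cout << prod << "\n" << quot << std::endl;
    return 0;
}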
@@ -1180,9 +1195,45 @@ void ONNXImporter::populateNet(Net dstNet)
             Mat newShapeMat = getBlob(node_proto, constBlobs, 1);
             MatShape targetShape(newShapeMat.ptr<int>(), newShapeMat.ptr<int>() + newShapeMat.total());

-            shapeIt = outShapes.find(node_proto.input(0));
-            CV_Assert(shapeIt != outShapes.end());
-            MatShape inpShape = shapeIt->second;
+            MatShape inpShape;
+            bool haveVariables = constBlobs.find(node_proto.input(0)) == constBlobs.end();
+            if (haveVariables)
+            {
+                shapeIt = outShapes.find(node_proto.input(0));
+                CV_Assert(shapeIt != outShapes.end());
+                inpShape = shapeIt->second;
+            }
+            else
+            {
+                inpShape = shape(getBlob(node_proto, constBlobs, 0));
+            }
+
+            String srcName = node_proto.input(0);
+            // Unsqueeze and repeat along new axis
+            if (targetShape.size() == inpShape.size() + 1)
+            {
+                for (int i = 0; i < targetShape.size(); i++)
+                {
+                    if (targetShape[i] == -1 && i < inpShape.size())
+                        targetShape[i] = inpShape[i];
+                    else if (i < inpShape.size() && targetShape[i] != inpShape[i])
+                        inpShape.insert(inpShape.begin() + i, 1);
+                }
+                if (haveVariables)
+                {
+                    LayerParams reshapeLp;
+                    reshapeLp.name = layerParams.name + "/reshape";
+                    reshapeLp.type = "Reshape";
+                    CV_Assert(layer_id.find(reshapeLp.name) == layer_id.end());
+                    reshapeLp.set("dim", DictValue::arrayInt(&inpShape[0], inpShape.size()));
+
+                    opencv_onnx::NodeProto proto;
+                    proto.add_input(node_proto.input(0));
+                    proto.add_output(reshapeLp.name);
+                    addLayer(dstNet, reshapeLp, proto, layer_id, outShapes);
+                    srcName = reshapeLp.name;
+                }
+            }
             CV_CheckEQ(inpShape.size(), targetShape.size(), "Unsupported Expand op with different dims");

             std::vector<int> broadcast_axes;
@@ -1197,6 +1248,19 @@ void ONNXImporter::populateNet(Net dstNet)
                 }
             }

+            if (!haveVariables)
+            {
+                if (broadcast_axes.size() != 1)
+                    CV_Error(Error::StsNotImplemented, "Expand op doesn't support multiple axes for constant input");
+
+                Mat input = getBlob(node_proto, constBlobs, 0);
+                input = input.reshape(0, total(inpShape, 0, broadcast_axes[0]));
+                Mat output = cv::repeat(input, 1, targetShape[broadcast_axes[0]]);
+                output = output.reshape(0, targetShape);
+                addConstant(layerParams.name, output, constBlobs, outShapes);
+                continue;
+            }
+
             if (broadcast_axes.size() == 2 &&
                 broadcast_axes[0] == broadcast_axes[1] - 1 && broadcast_axes[1] == inpShape.size() - 1)
             {
@@ -1231,6 +1295,7 @@ void ONNXImporter::populateNet(Net dstNet)
                     CV_Assert(layer_id.find(copyLP.name) == layer_id.end());
                     input_names.push_back(copyLP.name);

+                    node_proto.set_input(0, srcName);
                     node_proto.set_output(0, copyLP.name);
                     addLayer(dstNet, copyLP, node_proto, layer_id, outShapes);
                 }
@@ -1241,6 +1306,7 @@ void ONNXImporter::populateNet(Net dstNet)
                 }
                 layerParams.set("axis", broadcast_axes[0]);
                 layerParams.type = "Concat";
+                node_proto.set_output(0, layerParams.name);
             }
             else
                 CV_Error(Error::StsNotImplemented, "Unsupported Expand op");