@@ -392,24 +392,21 @@ void ONNXImporter::populateNet(Net dstNet)
392
392
layerParams.set (" ave_pool_padded_area" , framework_name == " pytorch" );
393
393
}
394
394
else if (layer_type == " GlobalAveragePool" || layer_type == " GlobalMaxPool" ||
395
- layer_type == " ReduceMean" || layer_type == " ReduceSum" )
395
+ layer_type == " ReduceMean" || layer_type == " ReduceSum" || layer_type == " ReduceMax " )
396
396
{
397
397
CV_Assert (node_proto.input_size () == 1 );
398
398
layerParams.type = " Pooling" ;
399
399
String pool;
400
- if (layer_type == " GlobalMaxPool" )
400
+ if (layer_type == " GlobalMaxPool" || layer_type == " ReduceMax " )
401
401
pool = " MAX" ;
402
402
else if (layer_type == " ReduceSum" )
403
403
pool = " SUM" ;
404
404
else
405
405
pool = " AVE" ;
406
406
layerParams.set (" pool" , pool);
407
- layerParams.set (" global_pooling" , layer_type == " GlobalAveragePool " || layer_type == " GlobalMaxPool " );
408
- if (layer_type == " ReduceMean" || layer_type == " ReduceSum" )
407
+ layerParams.set (" global_pooling" , !layerParams. has ( " axes " ) );
408
+ if (layerParams. has ( " axes " ) && ( layer_type == " ReduceMean" || layer_type == " ReduceSum" || layer_type == " ReduceMax " ) )
409
409
{
410
- if (!layerParams.has (" axes" ))
411
- CV_Error (Error::StsNotImplemented, " Unsupported mode of " + layer_type + " operation." );
412
-
413
410
MatShape inpShape = outShapes[node_proto.input (0 )];
414
411
DictValue axes = layerParams.get (" axes" );
415
412
bool keepdims = layerParams.get <int >(" keepdims" );
@@ -487,6 +484,36 @@ void ONNXImporter::populateNet(Net dstNet)
487
484
layerParams.type = " Reshape" ;
488
485
layerParams.set (" dim" , DictValue::arrayInt (&targetShape[0 ], targetShape.size ()));
489
486
487
+ node_proto.set_input (0 , node_proto.output (0 ));
488
+ node_proto.set_output (0 , layerParams.name );
489
+ }
490
+ else if (!layerParams.has (" axes" ) && (layer_type == " ReduceMean" || layer_type == " ReduceSum" || layer_type == " ReduceMax" ))
491
+ {
492
+ CV_CheckEQ (layerParams.get <int >(" keepdims" ), 0 , (layer_type + " layer only supports keepdims = false" ).c_str ());
493
+ LayerParams reshapeLp;
494
+ reshapeLp.name = layerParams.name + " /reshape" ;
495
+ reshapeLp.type = " Reshape" ;
496
+ CV_Assert (layer_id.find (reshapeLp.name ) == layer_id.end ());
497
+ int newShape[] = {1 , 1 , 1 , -1 };
498
+ reshapeLp.set (" dim" , DictValue::arrayInt (&newShape[0 ], 4 ));
499
+
500
+ opencv_onnx::NodeProto proto;
501
+ proto.add_input (node_proto.input (0 ));
502
+ proto.add_output (reshapeLp.name );
503
+ addLayer (dstNet, reshapeLp, proto, layer_id, outShapes);
504
+
505
+ LayerParams poolLp = layerParams;
506
+ poolLp.name = layerParams.name + " /pool" ;
507
+ CV_Assert (layer_id.find (poolLp.name ) == layer_id.end ());
508
+
509
+ node_proto.set_input (0 , reshapeLp.name );
510
+ node_proto.set_output (0 , poolLp.name );
511
+ addLayer (dstNet, poolLp, node_proto, layer_id, outShapes);
512
+
513
+ layerParams.type = " Reshape" ;
514
+ int targetShape[] = {1 };
515
+ layerParams.set (" dim" , DictValue::arrayInt (&targetShape[0 ], 1 ));
516
+
490
517
node_proto.set_input (0 , node_proto.output (0 ));
491
518
node_proto.set_output (0 , layerParams.name );
492
519
}
@@ -1427,8 +1454,10 @@ void ONNXImporter::populateNet(Net dstNet)
1427
1454
case opencv_onnx::TensorProto_DataType_INT64: type = CV_32S; break ;
1428
1455
default : type = blob.type ();
1429
1456
}
1430
- blob.convertTo (blob, type);
1431
- addConstant (layerParams.name , blob, constBlobs, outShapes);
1457
+ Mat dst;
1458
+ blob.convertTo (dst, type);
1459
+ dst.dims = blob.dims ;
1460
+ addConstant (layerParams.name , dst, constBlobs, outShapes);
1432
1461
continue ;
1433
1462
}
1434
1463
else
@@ -1477,6 +1506,8 @@ void ONNXImporter::populateNet(Net dstNet)
1477
1506
{
1478
1507
outShape.erase (outShape.begin () + axis);
1479
1508
out.reshape (0 , outShape);
1509
+ } else {
1510
+ out.dims = 1 ;
1480
1511
}
1481
1512
addConstant (layerParams.name , out, constBlobs, outShapes);
1482
1513
continue ;
0 commit comments