@@ -295,6 +295,22 @@ DataLayout getDataLayout(
295
295
return it != data_layouts.end () ? it->second : DATA_LAYOUT_UNKNOWN;
296
296
}
297
297
298
+ static
299
+ bool hasAllOnes (const Mat &inputs, int startPos, int endPos)
300
+ {
301
+ CV_CheckLE (inputs.dims , 2 , " " );
302
+ CV_CheckGE (startPos, 0 , " " );
303
+ CV_CheckLE (startPos, endPos, " " );
304
+ CV_CheckLT ((size_t )endPos, inputs.total (), " " );
305
+
306
+ for (int i = startPos; i < endPos; i++)
307
+ {
308
+ if (inputs.at <int >(i) != 1 || inputs.at <int >(i)!= -1 )
309
+ return false ;
310
+ }
311
+ return true ;
312
+ }
313
+
298
314
void setStrides (LayerParams &layerParams, const tensorflow::NodeDef &layer)
299
315
{
300
316
if (hasLayerAttr (layer, " strides" ))
@@ -490,6 +506,9 @@ class TFImporter
490
506
std::map<String, Mat> sharedWeights;
491
507
492
508
std::map<String, int > layer_id;
509
+
510
+ private:
511
+ void addPermuteLayer (const int * order, const std::string& permName, Pin& inpId);
493
512
};
494
513
495
514
TFImporter::TFImporter (Net& net, const char *model, const char *config)
@@ -895,6 +914,17 @@ void TFImporter::populateNet()
895
914
CV_LOG_DEBUG (NULL , " DNN/TF: ===================== Import completed =====================" );
896
915
}
897
916
917
+ void TFImporter::addPermuteLayer (const int * order, const std::string& permName, Pin& inpId)
918
+ {
919
+ LayerParams permLP;
920
+ permLP.set (" order" , DictValue::arrayInt<const int *>(order, 4 ));
921
+ CV_Assert (layer_id.find (permName) == layer_id.end ());
922
+ int permId = dstNet.addLayer (permName, " Permute" , permLP);
923
+ layer_id[permName] = permId;
924
+ connect (layer_id, dstNet, inpId, permId, 0 );
925
+ inpId = Pin (permName);
926
+ }
927
+
898
928
void TFImporter::parseNode (const tensorflow::NodeDef& layer_)
899
929
{
900
930
tensorflow::NodeDef layer = layer_;
@@ -1276,37 +1306,49 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
1276
1306
if (value_id.find (layer.input (1 )) != value_id.end ())
1277
1307
{
1278
1308
Mat newShape = getTensorContent (getConstBlob (layer, value_id, 1 ));
1279
- if (newShape.total () == 4 )
1309
+ int newShapeSize = newShape.total ();
1310
+ bool hasSwap = false ;
1311
+ if (newShapeSize == 4 && hasAllOnes (newShape, 0 , 2 ))
1280
1312
{
1281
1313
// NHWC->NCHW
1282
1314
std::swap (*newShape.ptr <int32_t >(0 , 2 ), *newShape.ptr <int32_t >(0 , 3 ));
1283
1315
std::swap (*newShape.ptr <int32_t >(0 , 1 ), *newShape.ptr <int32_t >(0 , 2 ));
1316
+ hasSwap = true ;
1284
1317
}
1285
1318
if (inpLayout == DATA_LAYOUT_NHWC)
1286
1319
{
1287
- if (newShape. total () != 4 || newShape.at <int >(1 ) == 1 )
1320
+ if (newShapeSize >= 2 || newShape.at <int >(1 ) == 1 )
1288
1321
{
1289
- LayerParams permLP;
1290
1322
int order[] = {0 , 2 , 3 , 1 }; // From OpenCV's NCHW to NHWC.
1291
- permLP. set ( " order" , DictValue::arrayInt< int *>(order, 4 ) );
1292
-
1293
- std::string permName = name + " /nchw " ;
1294
- CV_Assert (layer_id. find (permName) == layer_id. end ()) ;
1295
- int permId = dstNet. addLayer (permName, " Permute " , permLP);
1296
- layer_id[permName] = permId;
1297
- connect (layer_id, dstNet, inpId, permId, 0 );
1298
- inpId = Pin (permName) ;
1299
- inpLayout = DATA_LAYOUT_NCHW;
1323
+ addPermuteLayer ( order, name + " /nhwc " , inpId );
1324
+ if (newShapeSize < 4 )
1325
+ {
1326
+ inpLayout = DATA_LAYOUT_NCHW ;
1327
+ }
1328
+ else
1329
+ {
1330
+ inpLayout = DATA_LAYOUT_NHWC ;
1331
+ }
1300
1332
}
1301
1333
}
1302
- layerParams.set (" dim" , DictValue::arrayInt<int *>(newShape.ptr <int >(), newShape. total () ));
1334
+ layerParams.set (" dim" , DictValue::arrayInt<int *>(newShape.ptr <int >(), newShapeSize ));
1303
1335
1304
1336
int id = dstNet.addLayer (name, " Reshape" , layerParams);
1305
1337
layer_id[name] = id;
1306
1338
1307
1339
// one input only
1308
1340
connect (layer_id, dstNet, inpId, id, 0 );
1309
- data_layouts[name] = newShape.total () == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
1341
+ inpId = Pin (name);
1342
+
1343
+ if ((inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_UNKNOWN || inpLayout == DATA_LAYOUT_PLANAR) &&
1344
+ newShapeSize == 4 && !hasSwap)
1345
+ {
1346
+ int order[] = {0 , 3 , 1 , 2 }; // Transform back to OpenCV's NCHW.
1347
+ addPermuteLayer (order, name + " /nchw" , inpId);
1348
+ inpLayout = DATA_LAYOUT_NCHW;
1349
+ }
1350
+
1351
+ data_layouts[name] = newShapeSize == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
1310
1352
}
1311
1353
else
1312
1354
{
0 commit comments