Skip to content

Commit 7894cd3

Browse files
author
Anastasia Murzova
committed
Aligned TF Reshape layer behaviour
1 parent 2a808ae commit 7894cd3

File tree

2 files changed

+66
-14
lines changed

2 files changed

+66
-14
lines changed

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 56 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,22 @@ DataLayout getDataLayout(
295295
return it != data_layouts.end() ? it->second : DATA_LAYOUT_UNKNOWN;
296296
}
297297

298+
static
299+
bool hasAllOnes(const Mat &inputs, int startPos, int endPos)
300+
{
301+
CV_CheckLE(inputs.dims, 2, "");
302+
CV_CheckGE(startPos, 0, "");
303+
CV_CheckLE(startPos, endPos, "");
304+
CV_CheckLT((size_t)endPos, inputs.total(), "");
305+
306+
for (int i = startPos; i < endPos; i++)
307+
{
308+
if (inputs.at<int>(i) != 1 || inputs.at<int>(i)!= -1)
309+
return false;
310+
}
311+
return true;
312+
}
313+
298314
void setStrides(LayerParams &layerParams, const tensorflow::NodeDef &layer)
299315
{
300316
if (hasLayerAttr(layer, "strides"))
@@ -490,6 +506,9 @@ class TFImporter
490506
std::map<String, Mat> sharedWeights;
491507

492508
std::map<String, int> layer_id;
509+
510+
private:
511+
void addPermuteLayer(const int* order, const std::string& permName, Pin& inpId);
493512
};
494513

495514
TFImporter::TFImporter(Net& net, const char *model, const char *config)
@@ -895,6 +914,17 @@ void TFImporter::populateNet()
895914
CV_LOG_DEBUG(NULL, "DNN/TF: ===================== Import completed =====================");
896915
}
897916

917+
void TFImporter::addPermuteLayer(const int* order, const std::string& permName, Pin& inpId)
918+
{
919+
LayerParams permLP;
920+
permLP.set("order", DictValue::arrayInt<const int*>(order, 4));
921+
CV_Assert(layer_id.find(permName) == layer_id.end());
922+
int permId = dstNet.addLayer(permName, "Permute", permLP);
923+
layer_id[permName] = permId;
924+
connect(layer_id, dstNet, inpId, permId, 0);
925+
inpId = Pin(permName);
926+
}
927+
898928
void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
899929
{
900930
tensorflow::NodeDef layer = layer_;
@@ -1276,37 +1306,49 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
12761306
if (value_id.find(layer.input(1)) != value_id.end())
12771307
{
12781308
Mat newShape = getTensorContent(getConstBlob(layer, value_id, 1));
1279-
if (newShape.total() == 4)
1309+
int newShapeSize = newShape.total();
1310+
bool hasSwap = false;
1311+
if (newShapeSize == 4 && hasAllOnes(newShape, 0, 2))
12801312
{
12811313
// NHWC->NCHW
12821314
std::swap(*newShape.ptr<int32_t>(0, 2), *newShape.ptr<int32_t>(0, 3));
12831315
std::swap(*newShape.ptr<int32_t>(0, 1), *newShape.ptr<int32_t>(0, 2));
1316+
hasSwap = true;
12841317
}
12851318
if (inpLayout == DATA_LAYOUT_NHWC)
12861319
{
1287-
if (newShape.total() != 4 || newShape.at<int>(1) == 1)
1320+
if (newShapeSize >= 2 || newShape.at<int>(1) == 1)
12881321
{
1289-
LayerParams permLP;
12901322
int order[] = {0, 2, 3, 1}; // From OpenCV's NCHW to NHWC.
1291-
permLP.set("order", DictValue::arrayInt<int*>(order, 4));
1292-
1293-
std::string permName = name + "/nchw";
1294-
CV_Assert(layer_id.find(permName) == layer_id.end());
1295-
int permId = dstNet.addLayer(permName, "Permute", permLP);
1296-
layer_id[permName] = permId;
1297-
connect(layer_id, dstNet, inpId, permId, 0);
1298-
inpId = Pin(permName);
1299-
inpLayout = DATA_LAYOUT_NCHW;
1323+
addPermuteLayer(order, name + "/nhwc", inpId);
1324+
if (newShapeSize < 4)
1325+
{
1326+
inpLayout = DATA_LAYOUT_NCHW;
1327+
}
1328+
else
1329+
{
1330+
inpLayout = DATA_LAYOUT_NHWC;
1331+
}
13001332
}
13011333
}
1302-
layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShape.total()));
1334+
layerParams.set("dim", DictValue::arrayInt<int*>(newShape.ptr<int>(), newShapeSize));
13031335

13041336
int id = dstNet.addLayer(name, "Reshape", layerParams);
13051337
layer_id[name] = id;
13061338

13071339
// one input only
13081340
connect(layer_id, dstNet, inpId, id, 0);
1309-
data_layouts[name] = newShape.total() == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
1341+
inpId = Pin(name);
1342+
1343+
if ((inpLayout == DATA_LAYOUT_NHWC || inpLayout == DATA_LAYOUT_UNKNOWN || inpLayout == DATA_LAYOUT_PLANAR) &&
1344+
newShapeSize == 4 && !hasSwap)
1345+
{
1346+
int order[] = {0, 3, 1, 2}; // Transform back to OpenCV's NCHW.
1347+
addPermuteLayer(order, name + "/nchw", inpId);
1348+
inpLayout = DATA_LAYOUT_NCHW;
1349+
}
1350+
1351+
data_layouts[name] = newShapeSize == 2 ? DATA_LAYOUT_PLANAR : inpLayout;
13101352
}
13111353
else
13121354
{

modules/dnn/test/test_tf_importer.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,16 @@ TEST_P(Test_TensorFlow_layers, unfused_flatten)
457457
runTensorFlowNet("unfused_flatten_unknown_batch");
458458
}
459459

460+
TEST_P(Test_TensorFlow_layers, reshape_layer)
461+
{
462+
runTensorFlowNet("reshape_layer");
463+
}
464+
465+
TEST_P(Test_TensorFlow_layers, reshape_nchw)
466+
{
467+
runTensorFlowNet("reshape_nchw");
468+
}
469+
460470
TEST_P(Test_TensorFlow_layers, leaky_relu)
461471
{
462472
#if defined(INF_ENGINE_RELEASE) && INF_ENGINE_VER_MAJOR_EQ(2018050000)

0 commit comments

Comments
 (0)