Skip to content

Commit bd43e4f

Browse files
committed
Merge pull request opencv#14251 from dkurt:dnn_tf_manage_switch
2 parents 1fb93b6 + ec41a48 commit bd43e4f

File tree

4 files changed

+95
-1
lines changed

4 files changed

+95
-1
lines changed

modules/dnn/src/tensorflow/tf_graph_simplifier.cpp

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#ifdef HAVE_PROTOBUF
1111

1212
#include "tf_graph_simplifier.hpp"
13+
#include <queue>
1314

1415
namespace cv { namespace dnn {
1516
CV__DNN_EXPERIMENTAL_NS_BEGIN
@@ -883,7 +884,6 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
883884
nodesToAdd.pop_back();
884885

885886
permIds.push_back(nodeToAdd);
886-
// std::cout << net.node(nodeToAdd).name() << '\n';
887887

888888
for (int i = 0; i < edges[nodeToAdd].size(); ++i)
889889
{
@@ -902,6 +902,85 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
902902
permute(net.mutable_node(), permIds);
903903
}
904904

905+
// Remove training switches (Switch and Merge nodes and corresponding subgraphs).
906+
void removePhaseSwitches(tensorflow::GraphDef& net)
907+
{
908+
std::vector<int> nodesToRemove;
909+
std::map<std::string, int> nodesMap;
910+
std::map<std::string, int>::iterator nodesMapIt;
911+
std::queue<int> mergeOpSubgraphNodes;
912+
for (int i = 0; i < net.node_size(); ++i)
913+
{
914+
const tensorflow::NodeDef& node = net.node(i);
915+
nodesMap.insert(std::make_pair(node.name(), i));
916+
if (node.op() == "Switch" || node.op() == "Merge")
917+
{
918+
CV_Assert(node.input_size() > 0);
919+
// Replace consumers' inputs.
920+
for (int j = 0; j < net.node_size(); ++j)
921+
{
922+
tensorflow::NodeDef* consumer = net.mutable_node(j);
923+
for (int k = 0; k < consumer->input_size(); ++k)
924+
{
925+
std::string inpName = consumer->input(k);
926+
inpName = inpName.substr(0, inpName.rfind(':'));
927+
if (inpName == node.name())
928+
{
929+
consumer->set_input(k, node.input(0));
930+
}
931+
}
932+
}
933+
nodesToRemove.push_back(i);
934+
if (node.op() == "Merge")
935+
mergeOpSubgraphNodes.push(i);
936+
}
937+
}
938+
939+
std::vector<int> numConsumers(net.node_size(), 0);
940+
for (int i = 0; i < net.node_size(); ++i)
941+
{
942+
const tensorflow::NodeDef& node = net.node(i);
943+
for (int j = 0; j < node.input_size(); ++j)
944+
{
945+
std::string inpName = node.input(j);
946+
inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
947+
nodesMapIt = nodesMap.find(inpName);
948+
CV_Assert(nodesMapIt != nodesMap.end());
949+
numConsumers[nodesMapIt->second] += 1;
950+
}
951+
}
952+
953+
// Remove subgraphs of unused nodes which are terminated by Merge nodes.
954+
while (!mergeOpSubgraphNodes.empty())
955+
{
956+
const tensorflow::NodeDef& node = net.node(mergeOpSubgraphNodes.front());
957+
mergeOpSubgraphNodes.pop();
958+
for (int i = 0; i < node.input_size(); ++i)
959+
{
960+
std::string inpName = node.input(i);
961+
inpName = inpName.substr(1 + (int)inpName.find('^'), inpName.rfind(':'));
962+
nodesMapIt = nodesMap.find(inpName);
963+
CV_Assert(nodesMapIt != nodesMap.end());
964+
965+
int inpNodeId = nodesMapIt->second;
966+
if (numConsumers[inpNodeId] == 1)
967+
{
968+
mergeOpSubgraphNodes.push(inpNodeId);
969+
nodesToRemove.push_back(inpNodeId);
970+
}
971+
else if (numConsumers[inpNodeId] > 0)
972+
numConsumers[inpNodeId] -= 1;
973+
}
974+
}
975+
std::sort(nodesToRemove.begin(), nodesToRemove.end());
976+
for (int i = nodesToRemove.size() - 1; i >= 0; --i)
977+
{
978+
if (nodesToRemove[i] < net.node_size()) // Ids might be repeated.
979+
net.mutable_node()->DeleteSubrange(nodesToRemove[i], 1);
980+
}
981+
}
982+
983+
905984
CV__DNN_EXPERIMENTAL_NS_END
906985
}} // namespace dnn, namespace cv
907986

modules/dnn/src/tensorflow/tf_graph_simplifier.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ void releaseTensor(tensorflow::TensorProto* tensor);
2727

2828
void sortByExecutionOrder(tensorflow::GraphDef& net);
2929

30+
void removePhaseSwitches(tensorflow::GraphDef& net);
31+
3032
CV__DNN_EXPERIMENTAL_NS_END
3133
}} // namespace dnn, namespace cv
3234

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -657,6 +657,9 @@ static int predictOutputDataLayout(const tensorflow::GraphDef& net,
657657

658658
void TFImporter::populateNet(Net dstNet)
659659
{
660+
if (!netTxt.ByteSize())
661+
removePhaseSwitches(netBin);
662+
660663
RemoveIdentityOps(netBin);
661664
RemoveIdentityOps(netTxt);
662665

modules/dnn/test/test_tf_importer.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,16 @@ TEST_P(Test_TensorFlow_layers, batch_norm)
185185
runTensorFlowNet("mvn_batch_norm_1x1");
186186
}
187187

188+
TEST_P(Test_TensorFlow_layers, slim_batch_norm)
189+
{
190+
if (backend == DNN_BACKEND_INFERENCE_ENGINE)
191+
throw SkipTestException("Test is disabled for DLIE");
192+
// Output values range: [-40.0597, 207.827]
193+
double l1 = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.041 : default_l1;
194+
double lInf = (target == DNN_TARGET_OPENCL_FP16 || target == DNN_TARGET_MYRIAD) ? 0.33 : default_lInf;
195+
runTensorFlowNet("slim_batch_norm", false, l1, lInf);
196+
}
197+
188198
TEST_P(Test_TensorFlow_layers, pooling)
189199
{
190200
runTensorFlowNet("max_pool_even");

0 commit comments

Comments
 (0)