10
10
#ifdef HAVE_PROTOBUF
11
11
12
12
#include " tf_graph_simplifier.hpp"
13
+ #include < queue>
13
14
14
15
namespace cv { namespace dnn {
15
16
CV__DNN_EXPERIMENTAL_NS_BEGIN
@@ -883,7 +884,6 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
883
884
nodesToAdd.pop_back ();
884
885
885
886
permIds.push_back (nodeToAdd);
886
- // std::cout << net.node(nodeToAdd).name() << '\n';
887
887
888
888
for (int i = 0 ; i < edges[nodeToAdd].size (); ++i)
889
889
{
@@ -902,6 +902,85 @@ void sortByExecutionOrder(tensorflow::GraphDef& net)
902
902
permute (net.mutable_node (), permIds);
903
903
}
904
904
905
+ // Remove training switches (Switch and Merge nodes and corresponding subgraphs).
906
+ void removePhaseSwitches (tensorflow::GraphDef& net)
907
+ {
908
+ std::vector<int > nodesToRemove;
909
+ std::map<std::string, int > nodesMap;
910
+ std::map<std::string, int >::iterator nodesMapIt;
911
+ std::queue<int > mergeOpSubgraphNodes;
912
+ for (int i = 0 ; i < net.node_size (); ++i)
913
+ {
914
+ const tensorflow::NodeDef& node = net.node (i);
915
+ nodesMap.insert (std::make_pair (node.name (), i));
916
+ if (node.op () == " Switch" || node.op () == " Merge" )
917
+ {
918
+ CV_Assert (node.input_size () > 0 );
919
+ // Replace consumers' inputs.
920
+ for (int j = 0 ; j < net.node_size (); ++j)
921
+ {
922
+ tensorflow::NodeDef* consumer = net.mutable_node (j);
923
+ for (int k = 0 ; k < consumer->input_size (); ++k)
924
+ {
925
+ std::string inpName = consumer->input (k);
926
+ inpName = inpName.substr (0 , inpName.rfind (' :' ));
927
+ if (inpName == node.name ())
928
+ {
929
+ consumer->set_input (k, node.input (0 ));
930
+ }
931
+ }
932
+ }
933
+ nodesToRemove.push_back (i);
934
+ if (node.op () == " Merge" )
935
+ mergeOpSubgraphNodes.push (i);
936
+ }
937
+ }
938
+
939
+ std::vector<int > numConsumers (net.node_size (), 0 );
940
+ for (int i = 0 ; i < net.node_size (); ++i)
941
+ {
942
+ const tensorflow::NodeDef& node = net.node (i);
943
+ for (int j = 0 ; j < node.input_size (); ++j)
944
+ {
945
+ std::string inpName = node.input (j);
946
+ inpName = inpName.substr (1 + (int )inpName.find (' ^' ), inpName.rfind (' :' ));
947
+ nodesMapIt = nodesMap.find (inpName);
948
+ CV_Assert (nodesMapIt != nodesMap.end ());
949
+ numConsumers[nodesMapIt->second ] += 1 ;
950
+ }
951
+ }
952
+
953
+ // Remove subgraphs of unused nodes which are terminated by Merge nodes.
954
+ while (!mergeOpSubgraphNodes.empty ())
955
+ {
956
+ const tensorflow::NodeDef& node = net.node (mergeOpSubgraphNodes.front ());
957
+ mergeOpSubgraphNodes.pop ();
958
+ for (int i = 0 ; i < node.input_size (); ++i)
959
+ {
960
+ std::string inpName = node.input (i);
961
+ inpName = inpName.substr (1 + (int )inpName.find (' ^' ), inpName.rfind (' :' ));
962
+ nodesMapIt = nodesMap.find (inpName);
963
+ CV_Assert (nodesMapIt != nodesMap.end ());
964
+
965
+ int inpNodeId = nodesMapIt->second ;
966
+ if (numConsumers[inpNodeId] == 1 )
967
+ {
968
+ mergeOpSubgraphNodes.push (inpNodeId);
969
+ nodesToRemove.push_back (inpNodeId);
970
+ }
971
+ else if (numConsumers[inpNodeId] > 0 )
972
+ numConsumers[inpNodeId] -= 1 ;
973
+ }
974
+ }
975
+ std::sort (nodesToRemove.begin (), nodesToRemove.end ());
976
+ for (int i = nodesToRemove.size () - 1 ; i >= 0 ; --i)
977
+ {
978
+ if (nodesToRemove[i] < net.node_size ()) // Ids might be repeated.
979
+ net.mutable_node ()->DeleteSubrange (nodesToRemove[i], 1 );
980
+ }
981
+ }
982
+
983
+
905
984
CV__DNN_EXPERIMENTAL_NS_END
906
985
}} // namespace dnn, namespace cv
907
986
0 commit comments