@@ -223,6 +223,26 @@ class FlattenShapeSubgraph : public Subgraph
223
223
}
224
224
};
225
225
226
+ class FlattenProdSubgraph : public Subgraph
227
+ {
228
+ public:
229
+ FlattenProdSubgraph ()
230
+ {
231
+ int input = addNodeToMatch (" " );
232
+ int shape = addNodeToMatch (" Shape" , input);
233
+ int stack = addNodeToMatch (" Const" );
234
+ int stack_1 = addNodeToMatch (" Const" );
235
+ int stack_2 = addNodeToMatch (" Const" );
236
+ int strided_slice = addNodeToMatch (" StridedSlice" , shape, stack, stack_1, stack_2);
237
+ int prod = addNodeToMatch (" Prod" , strided_slice, addNodeToMatch (" Const" ));
238
+ int shape_pack = addNodeToMatch (" Const" );
239
+ int pack = addNodeToMatch (" Pack" , shape_pack, prod);
240
+ addNodeToMatch (" Reshape" , input, pack);
241
+
242
+ setFusedNode (" Flatten" , input);
243
+ }
244
+ };
245
+
226
246
// K.layers.Softmax
227
247
class SoftMaxKerasSubgraph : public Subgraph
228
248
{
@@ -629,6 +649,36 @@ class KerasMVNSubgraph : public TFSubgraph
629
649
}
630
650
};
631
651
652
+ class PReLUSubgraph : public TFSubgraph
653
+ {
654
+ public:
655
+ PReLUSubgraph (bool negativeScales_) : negativeScales(negativeScales_)
656
+ {
657
+ int input = addNodeToMatch (" " );
658
+ int scales = addNodeToMatch (" Const" );
659
+ int neg = addNodeToMatch (" Neg" , input);
660
+ int relu_neg = addNodeToMatch (" Relu" , neg);
661
+ int finalScales = negativeScales ? addNodeToMatch (" Neg" , scales) : scales;
662
+ int mul = addNodeToMatch (" Mul" , finalScales, relu_neg);
663
+ int relu_pos = addNodeToMatch (" Relu" , input);
664
+ addNodeToMatch (" Add" , relu_pos, mul);
665
+ setFusedNode (" PReLU" , input, scales);
666
+ }
667
+
668
+ virtual void finalize (tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
669
+ std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
670
+ {
671
+ if (!negativeScales)
672
+ {
673
+ Mat scales = getTensorContent (inputNodes[1 ]->attr ().at (" value" ).tensor (), /* copy*/ false );
674
+ scales *= -1 ;
675
+ }
676
+ }
677
+
678
+ private:
679
+ bool negativeScales;
680
+ };
681
+
632
682
void simplifySubgraphs (tensorflow::GraphDef& net)
633
683
{
634
684
std::vector<Ptr<Subgraph> > subgraphs;
@@ -649,6 +699,16 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
649
699
subgraphs.push_back (Ptr<Subgraph>(new SoftMaxSlimV2Subgraph ()));
650
700
subgraphs.push_back (Ptr<Subgraph>(new ReshapeAsShapeSubgraph ()));
651
701
subgraphs.push_back (Ptr<Subgraph>(new KerasMVNSubgraph ()));
702
+ subgraphs.push_back (Ptr<Subgraph>(new PReLUSubgraph (true )));
703
+ subgraphs.push_back (Ptr<Subgraph>(new PReLUSubgraph (false )));
704
+ subgraphs.push_back (Ptr<Subgraph>(new FlattenProdSubgraph ()));
705
+
706
+ for (int i = 0 ; i < net.node_size (); ++i)
707
+ {
708
+ tensorflow::NodeDef* layer = net.mutable_node (i);
709
+ if (layer->op () == " AddV2" )
710
+ layer->set_op (" Add" );
711
+ }
652
712
653
713
simplifySubgraphs (Ptr<ImportGraphWrapper>(new TFGraphWrapper (net)), subgraphs);
654
714
}
0 commit comments