@@ -249,6 +249,40 @@ class NormalizeSubgraph3 : public NormalizeSubgraphBase
249
249
}
250
250
};
251
251
252
+ class NormalizeSubgraph4 : public NormalizeSubgraphBase
253
+ {
254
+ public:
255
+ NormalizeSubgraph4 () : NormalizeSubgraphBase(1 )
256
+ {
257
+ int input = addNodeToMatch (" " );
258
+ int mul = addNodeToMatch (" Mul" , input, input);
259
+ int sum = addNodeToMatch (" ReduceSum" , mul);
260
+ int eps = addNodeToMatch (" " );
261
+ int max = addNodeToMatch (" Max" , sum, eps);
262
+ int sqrt = addNodeToMatch (" Sqrt" , max);
263
+ int reciprocal = addNodeToMatch (" Reciprocal" , sqrt);
264
+ addNodeToMatch (" Mul" , input, reciprocal);
265
+ setFusedNode (" Normalize" , input);
266
+ }
267
+ };
268
+
269
+ class NormalizeSubgraph5 : public NormalizeSubgraphBase
270
+ {
271
+ public:
272
+ NormalizeSubgraph5 () : NormalizeSubgraphBase(1 )
273
+ {
274
+ int input = addNodeToMatch (" " );
275
+ int mul = addNodeToMatch (" Mul" , input, input);
276
+ int sum = addNodeToMatch (" ReduceSum" , mul);
277
+ int clip = addNodeToMatch (" Clip" , sum);
278
+ int sqrt = addNodeToMatch (" Sqrt" , clip);
279
+ int one = addNodeToMatch (" Constant" );
280
+ int div = addNodeToMatch (" Div" , one, sqrt);
281
+ addNodeToMatch (" Mul" , input, div);
282
+ setFusedNode (" Normalize" , input);
283
+ }
284
+ };
285
+
252
286
class GatherCastSubgraph : public Subgraph
253
287
{
254
288
public:
@@ -526,6 +560,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
526
560
subgraphs.push_back (makePtr<BatchNormalizationSubgraph2>());
527
561
subgraphs.push_back (makePtr<ExpandSubgraph>());
528
562
subgraphs.push_back (makePtr<MishSubgraph>());
563
+ subgraphs.push_back (makePtr<NormalizeSubgraph4>());
564
+ subgraphs.push_back (makePtr<NormalizeSubgraph5>());
529
565
530
566
simplifySubgraphs (Ptr<ImportGraphWrapper>(new ONNXGraphWrapper (net)), subgraphs);
531
567
}
0 commit comments