Skip to content

Commit 0a86ddc

Browse files
committed
Merge pull request opencv#19435 from l-bat:lb/onnx_normalize
2 parents 23734af + 68eb54d commit 0a86ddc

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

modules/dnn/src/onnx/onnx_graph_simplifier.cpp

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,40 @@ class NormalizeSubgraph3 : public NormalizeSubgraphBase
249249
}
250250
};
251251

252+
class NormalizeSubgraph4 : public NormalizeSubgraphBase
253+
{
254+
public:
255+
NormalizeSubgraph4() : NormalizeSubgraphBase(1)
256+
{
257+
int input = addNodeToMatch("");
258+
int mul = addNodeToMatch("Mul", input, input);
259+
int sum = addNodeToMatch("ReduceSum", mul);
260+
int eps = addNodeToMatch("");
261+
int max = addNodeToMatch("Max", sum, eps);
262+
int sqrt = addNodeToMatch("Sqrt", max);
263+
int reciprocal = addNodeToMatch("Reciprocal", sqrt);
264+
addNodeToMatch("Mul", input, reciprocal);
265+
setFusedNode("Normalize", input);
266+
}
267+
};
268+
269+
class NormalizeSubgraph5 : public NormalizeSubgraphBase
270+
{
271+
public:
272+
NormalizeSubgraph5() : NormalizeSubgraphBase(1)
273+
{
274+
int input = addNodeToMatch("");
275+
int mul = addNodeToMatch("Mul", input, input);
276+
int sum = addNodeToMatch("ReduceSum", mul);
277+
int clip = addNodeToMatch("Clip", sum);
278+
int sqrt = addNodeToMatch("Sqrt", clip);
279+
int one = addNodeToMatch("Constant");
280+
int div = addNodeToMatch("Div", one, sqrt);
281+
addNodeToMatch("Mul", input, div);
282+
setFusedNode("Normalize", input);
283+
}
284+
};
285+
252286
class GatherCastSubgraph : public Subgraph
253287
{
254288
public:
@@ -526,6 +560,8 @@ void simplifySubgraphs(opencv_onnx::GraphProto& net)
526560
subgraphs.push_back(makePtr<BatchNormalizationSubgraph2>());
527561
subgraphs.push_back(makePtr<ExpandSubgraph>());
528562
subgraphs.push_back(makePtr<MishSubgraph>());
563+
subgraphs.push_back(makePtr<NormalizeSubgraph4>());
564+
subgraphs.push_back(makePtr<NormalizeSubgraph5>());
529565

530566
simplifySubgraphs(Ptr<ImportGraphWrapper>(new ONNXGraphWrapper(net)), subgraphs);
531567
}

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,11 @@ TEST_P(Test_ONNX_layers, BatchNormalizationSubgraph)
403403
testONNXModels("batch_norm_subgraph");
404404
}
405405

406+
TEST_P(Test_ONNX_layers, NormalizeFusionSubgraph)
407+
{
408+
testONNXModels("normalize_fusion");
409+
}
410+
406411
TEST_P(Test_ONNX_layers, Transpose)
407412
{
408413
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)

0 commit comments

Comments
 (0)