Skip to content

Commit 5da4bb7

Browse files
committed
Merge pull request opencv#16983 from dkurt:dnn_tf_prelu
2 parents dc1b1f2 + 25ec4ce commit 5da4bb7

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

modules/dnn/src/tensorflow/tf_graph_simplifier.cpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,26 @@ class FlattenShapeSubgraph : public Subgraph
223223
}
224224
};
225225

226+
class FlattenProdSubgraph : public Subgraph
227+
{
228+
public:
229+
FlattenProdSubgraph()
230+
{
231+
int input = addNodeToMatch("");
232+
int shape = addNodeToMatch("Shape", input);
233+
int stack = addNodeToMatch("Const");
234+
int stack_1 = addNodeToMatch("Const");
235+
int stack_2 = addNodeToMatch("Const");
236+
int strided_slice = addNodeToMatch("StridedSlice", shape, stack, stack_1, stack_2);
237+
int prod = addNodeToMatch("Prod", strided_slice, addNodeToMatch("Const"));
238+
int shape_pack = addNodeToMatch("Const");
239+
int pack = addNodeToMatch("Pack", shape_pack, prod);
240+
addNodeToMatch("Reshape", input, pack);
241+
242+
setFusedNode("Flatten", input);
243+
}
244+
};
245+
226246
// K.layers.Softmax
227247
class SoftMaxKerasSubgraph : public Subgraph
228248
{
@@ -629,6 +649,36 @@ class KerasMVNSubgraph : public TFSubgraph
629649
}
630650
};
631651

652+
class PReLUSubgraph : public TFSubgraph
653+
{
654+
public:
655+
PReLUSubgraph(bool negativeScales_) : negativeScales(negativeScales_)
656+
{
657+
int input = addNodeToMatch("");
658+
int scales = addNodeToMatch("Const");
659+
int neg = addNodeToMatch("Neg", input);
660+
int relu_neg = addNodeToMatch("Relu", neg);
661+
int finalScales = negativeScales ? addNodeToMatch("Neg", scales) : scales;
662+
int mul = addNodeToMatch("Mul", finalScales, relu_neg);
663+
int relu_pos = addNodeToMatch("Relu", input);
664+
addNodeToMatch("Add", relu_pos, mul);
665+
setFusedNode("PReLU", input, scales);
666+
}
667+
668+
virtual void finalize(tensorflow::GraphDef&, tensorflow::NodeDef* fusedNode,
669+
std::vector<tensorflow::NodeDef*>& inputNodes) CV_OVERRIDE
670+
{
671+
if (!negativeScales)
672+
{
673+
Mat scales = getTensorContent(inputNodes[1]->attr().at("value").tensor(), /*copy*/false);
674+
scales *= -1;
675+
}
676+
}
677+
678+
private:
679+
bool negativeScales;
680+
};
681+
632682
void simplifySubgraphs(tensorflow::GraphDef& net)
633683
{
634684
std::vector<Ptr<Subgraph> > subgraphs;
@@ -649,6 +699,16 @@ void simplifySubgraphs(tensorflow::GraphDef& net)
649699
subgraphs.push_back(Ptr<Subgraph>(new SoftMaxSlimV2Subgraph()));
650700
subgraphs.push_back(Ptr<Subgraph>(new ReshapeAsShapeSubgraph()));
651701
subgraphs.push_back(Ptr<Subgraph>(new KerasMVNSubgraph()));
702+
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(true)));
703+
subgraphs.push_back(Ptr<Subgraph>(new PReLUSubgraph(false)));
704+
subgraphs.push_back(Ptr<Subgraph>(new FlattenProdSubgraph()));
705+
706+
for (int i = 0; i < net.node_size(); ++i)
707+
{
708+
tensorflow::NodeDef* layer = net.mutable_node(i);
709+
if (layer->op() == "AddV2")
710+
layer->set_op("Add");
711+
}
652712

653713
simplifySubgraphs(Ptr<ImportGraphWrapper>(new TFGraphWrapper(net)), subgraphs);
654714
}

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,7 @@ void TFImporter::populateNet(Net dstNet)
12311231
// Only NHWC <-> NCHW permutations are allowed. OpenCV is always
12321232
// keep NCHW layout this way.
12331233
int inpLayout = getDataLayout(layer.input(0), data_layouts);
1234+
std::string type = "Identity";
12341235
if (inpLayout == DATA_LAYOUT_NHWC)
12351236
{
12361237
if (permData[0] == 0 && permData[1] == 3 && permData[2] == 1 && permData[3] == 2)
@@ -1245,6 +1246,15 @@ void TFImporter::populateNet(Net dstNet)
12451246
// in OpenCV: NCHW->NCHW
12461247
data_layouts[name] = DATA_LAYOUT_NHWC;
12471248
}
1249+
else if (permData[0] == 0 && permData[1] == 3 && permData[2] == 2 && permData[3] == 1)
1250+
{
1251+
// in TensorFlow: NHWC->NCWH
1252+
// in OpenCV: NCHW->NCWH
1253+
int permData[] = {0, 1, 3, 2};
1254+
layerParams.set("order", DictValue::arrayInt<int*>(permData, perm.total()));
1255+
data_layouts[name] = DATA_LAYOUT_NCHW; // we keep track NCHW because channels position only matters
1256+
type = "Permute";
1257+
}
12481258
else
12491259
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
12501260
}
@@ -1265,7 +1275,7 @@ void TFImporter::populateNet(Net dstNet)
12651275
else
12661276
CV_Error(Error::StsParseError, "Only NHWC <-> NCHW permutations are allowed.");
12671277
}
1268-
int id = dstNet.addLayer(name, "Identity", layerParams);
1278+
int id = dstNet.addLayer(name, type, layerParams);
12691279
layer_id[name] = id;
12701280
connect(layer_id, dstNet, parsePin(layer.input(0)), id, 0);
12711281
}

modules/dnn/test/test_tf_importer.cpp

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -956,11 +956,25 @@ TEST_P(Test_TensorFlow_layers, resize_bilinear)
956956
runTensorFlowNet("resize_bilinear_factor");
957957
}
958958

959-
TEST_P(Test_TensorFlow_layers, tf2_keras)
959+
TEST_P(Test_TensorFlow_layers, tf2_dense)
960960
{
961961
runTensorFlowNet("tf2_dense");
962962
}
963963

964+
TEST_P(Test_TensorFlow_layers, tf2_prelu)
965+
{
966+
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
967+
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
968+
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NGRAPH)
969+
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NGRAPH);
970+
runTensorFlowNet("tf2_prelu");
971+
}
972+
973+
TEST_P(Test_TensorFlow_layers, tf2_permute_nhwc_ncwh)
974+
{
975+
runTensorFlowNet("tf2_permute_nhwc_ncwh");
976+
}
977+
964978
TEST_P(Test_TensorFlow_layers, squeeze)
965979
{
966980
#if defined(INF_ENGINE_RELEASE)

0 commit comments

Comments
 (0)