Skip to content

Commit 339b963

Browse files
committed
Fix MatMul and Add axes
1 parent f3cebb3 commit 339b963

File tree

2 files changed

+27
-0
lines changed

2 files changed

+27
-0
lines changed

modules/dnn/src/onnx/onnx_importer.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -641,6 +641,17 @@ void ONNXImporter::populateNet(Net dstNet)
641641
{
642642
layerParams.type = "Scale";
643643
layerParams.set("bias_term", true);
644+
int axis = 1;
645+
for (int i = 0; i < graph_proto.initializer_size(); i++)
646+
{
647+
opencv_onnx::TensorProto tensor_proto = graph_proto.initializer(i);
648+
if (tensor_proto.name() == node_proto.input(const_blob_id))
649+
{
650+
axis = inpShape.size() - tensor_proto.dims_size();
651+
break;
652+
}
653+
}
654+
layerParams.set("axis", axis);
644655
blob = blob.reshape(1, 1);
645656
layerParams.blobs.push_back((isSub ? -1 : 1) * blob);
646657
}
@@ -911,13 +922,20 @@ void ONNXImporter::populateNet(Net dstNet)
911922
CV_Assert(node_proto.input_size() == 2);
912923
layerParams.type = "InnerProduct";
913924
layerParams.set("bias_term", false);
925+
CV_Assert(constBlobs.find(node_proto.input(0)) == constBlobs.end());
926+
int firstInpDims = outShapes[node_proto.input(0)].size();
927+
int secondInpDims;
914928

915929
if (constBlobs.find(node_proto.input(1)) != constBlobs.end())
916930
{
917931
Mat blob = getBlob(node_proto, constBlobs, 1);
932+
secondInpDims = blob.dims;
918933
layerParams.blobs.push_back(blob.t());
919934
layerParams.set("num_output", layerParams.blobs[0].size[0]);
935+
} else {
936+
secondInpDims = outShapes[node_proto.input(1)].size();
920937
}
938+
layerParams.set("axis", firstInpDims - secondInpDims + 1);
921939
}
922940
else if (layer_type == "Mul" || layer_type == "Div")
923941
{

modules/dnn/test/test_onnx_importer.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,15 @@ TEST_P(Test_ONNX_layers, MatMul)
404404
testONNXModels("matmul_4d");
405405
}
406406

407+
TEST_P(Test_ONNX_layers, MatMulAdd)
408+
{
409+
if (backend == DNN_BACKEND_INFERENCE_ENGINE_NN_BUILDER_2019)
410+
applyTestTag(CV_TEST_TAG_DNN_SKIP_IE_NN_BUILDER);
411+
if (backend == DNN_BACKEND_OPENCV && target == DNN_TARGET_OPENCL_FP16)
412+
applyTestTag(CV_TEST_TAG_DNN_SKIP_OPENCL_FP16);
413+
testONNXModels("matmul_add");
414+
}
415+
407416
TEST_P(Test_ONNX_layers, Expand)
408417
{
409418
testONNXModels("expand_batch");

0 commit comments

Comments
 (0)