Skip to content

Commit 573812d

Browse files
committed
Merge pull request opencv#19373 from l-bat:lb/tf_matmul_shared
2 parents c12930c + 38a49f9 commit 573812d

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,8 +1228,18 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
12281228

12291229
int kernel_blob_index = -1;
12301230
const tensorflow::TensorProto& kernelTensor = getConstBlob(layer, value_id, -1, &kernel_blob_index);
1231-
blobFromTensor(kernelTensor, layerParams.blobs[0]);
1232-
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
1231+
const String kernelTensorName = layer.input(kernel_blob_index);
1232+
std::map<String, Mat>::iterator sharedWeightsIt = sharedWeights.find(kernelTensorName);
1233+
if (sharedWeightsIt == sharedWeights.end())
1234+
{
1235+
blobFromTensor(kernelTensor, layerParams.blobs[0]);
1236+
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
1237+
sharedWeights[kernelTensorName] = layerParams.blobs[0];
1238+
}
1239+
else
1240+
{
1241+
layerParams.blobs[0] = sharedWeightsIt->second;
1242+
}
12331243

12341244
if (kernel_blob_index == 1) { // In this case output is computed by x*W formula - W should be transposed
12351245
Mat data = layerParams.blobs[0].t();

0 commit comments

Comments
 (0)