Skip to content

Commit 38a49f9

Browse files
committed
Added shared weights for MatMul
1 parent 0f968e3 commit 38a49f9

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

modules/dnn/src/tensorflow/tf_importer.cpp

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1228,8 +1228,18 @@ void TFImporter::parseNode(const tensorflow::NodeDef& layer_)
12281228

12291229
int kernel_blob_index = -1;
12301230
const tensorflow::TensorProto& kernelTensor = getConstBlob(layer, value_id, -1, &kernel_blob_index);
1231-
blobFromTensor(kernelTensor, layerParams.blobs[0]);
1232-
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
1231+
const String kernelTensorName = layer.input(kernel_blob_index);
1232+
std::map<String, Mat>::iterator sharedWeightsIt = sharedWeights.find(kernelTensorName);
1233+
if (sharedWeightsIt == sharedWeights.end())
1234+
{
1235+
blobFromTensor(kernelTensor, layerParams.blobs[0]);
1236+
releaseTensor(const_cast<tensorflow::TensorProto*>(&kernelTensor));
1237+
sharedWeights[kernelTensorName] = layerParams.blobs[0];
1238+
}
1239+
else
1240+
{
1241+
layerParams.blobs[0] = sharedWeightsIt->second;
1242+
}
12331243

12341244
if (kernel_blob_index == 1) { // In this case output is computed by x*W formula - W should be transposed
12351245
Mat data = layerParams.blobs[0].t();

0 commit comments

Comments
 (0)