diff --git a/exir/emit/_emitter.py b/exir/emit/_emitter.py
index 15e0b23d36f..da3314d3510 100644
--- a/exir/emit/_emitter.py
+++ b/exir/emit/_emitter.py
@@ -466,6 +466,15 @@ def _tensor_spec_to_evalue(
             and spec.extra_tensor_info.location == TensorDataLocation.EXTERNAL
         ):
             buffer_idx = self.program_state.external_constant_hash.get(hashed, -1)
+            if buffer_idx != -1:
+                # Save the constant tag for the external tensor
+                if constant_tag not in self.program_state.external_constant_map:
+                    # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
+                    self.program_state.external_constant_map[constant_tag] = {}
+                # pyre-ignore Undefined attribute [16]: `Optional` has no attribute `fully_qualified_name`.
+                self.program_state.external_constant_map[constant_tag][
+                    spec.extra_tensor_info.fully_qualified_name
+                ] = buffer_idx
         else:
             buffer_idx = self.program_state.cached_spec_hash_values.get(hashed, -1)
 
diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py
index dcc3544875a..4fc3b4f307e 100644
--- a/exir/emit/test/test_emit.py
+++ b/exir/emit/test/test_emit.py
@@ -1719,6 +1719,45 @@ def forward(self, x):
         self.assertEqual(external_map["linear.weight"], 0)
         self.assertEqual(external_map["linear.bias"], 1)
 
+    def test_constant_tagged_tensor_dedup(self) -> None:
+        class ConstantModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                constant_value = torch.tensor([1.0, 2.0, 3.0])
+
+                # Register the same value with two different names as persistent buffers
+                self.register_buffer(
+                    "constant_a", constant_value.clone(), persistent=True
+                )
+                self.register_buffer(
+                    "constant_b", constant_value.clone(), persistent=True
+                )
+
+            def forward(self, x):
+                return x + self.constant_a + self.constant_b
+
+        model = to_edge(
+            export(ConstantModule(), (torch.ones(1, 3),), strict=True)
+        ).to_executorch(
+            config=ExecutorchBackendConfig(
+                external_constants=True,
+            )
+        )
+        emitter_output = model._emitter_output
+        # Check that constant_buffer is empty besides the non-constant placeholder 0.
+        self.assertEqual(len(emitter_output.program.constant_buffer), 1)
+        # Check that constant weights are in the external constant buffer.
+        self.assertEqual(len(emitter_output.external_constant_buffer), 1)
+        # Setting external_constants=True, saves all constants to the key
+        # '_default_external_constant'.
+        external_map = emitter_output.external_constant_map[
+            "_default_external_constant"
+        ]
+        self.assertEqual(len(external_map), 2)
+        # Confirm that the same tensor is used for both constants.
+        self.assertEqual(external_map["constant_a"], 0)
+        self.assertEqual(external_map["constant_b"], 0)
+
     def test_delegate_deduplicate(self) -> None:
         class SharedModule(torch.nn.Module):
             def __init__(self):