
Commit ba19c75

Properly set mutable buffer lifespans
Differential Revision: D77618047
Pull Request resolved: #12182
1 parent b342f83 commit ba19c75
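Context for the change: in the ExecuTorch memory planner, every TensorSpec carries a lifetime [start, end] over the indices of the nodes in the Fx graph, and tensors whose lifetimes do not overlap may share storage. A mutable buffer carries state across executions, so its storage must never be handed back for reuse. This commit pins the lifetime end of every mutable-buffer placeholder to the index of the last node in the graph, so the planner treats the buffer as live for the entire program.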

File tree

exir/memory_planning.py
exir/tests/test_memory_planning.py

2 files changed: +54 -4 lines changed

exir/memory_planning.py

Lines changed: 13 additions & 4 deletions
@@ -301,7 +301,11 @@ def _is_inplace_node(node: torch.fx.Node) -> bool:
 
 
 def update_tensor_lifetime(
-    node: torch.fx.Node, spec: TensorSpec, node_idx: int
+    node: torch.fx.Node,
+    spec: TensorSpec,
+    node_idx: int,
+    max_node_idx: int,
+    gs: Optional[ExportGraphSignature] = None,
 ) -> None:
     r"""
     Update the lifetime of the tensor to cover node_idx. A tensor's lifetime
@@ -317,7 +321,12 @@ def update_tensor_lifetime(
         start = 0
     else:
         start = node_idx if start is None or start > node_idx else start
-    end = node_idx if end is None or end < node_idx else end
+
+    if node.op == "placeholder" and _is_mutable_buffer(node, gs):
+        # mutable buffers are never freed
+        end = max_node_idx
+    else:
+        end = node_idx if end is None or end < node_idx else end
     spec.lifetime = [start, end]
 
 
@@ -497,7 +506,7 @@ def update_all_tensors_lifetime(
     Set the lifetime for all the tensors encountered in the Fx graph.
     """
     specs = set()
-
+    max_node_idx = len(graph_module.graph.nodes) - 1
     for node_idx, node in enumerate(graph_module.graph.nodes):
         for spec in collect_specs_from_nodes(
             filter_nodes(itertools.chain([node], node.args, node.kwargs.values())),
@@ -509,7 +518,7 @@ def update_all_tensors_lifetime(
             do_assertion=False,
             ignore_dynamic_unbound_tensor=False,
         ):
-            update_tensor_lifetime(node, spec, node_idx)
+            update_tensor_lifetime(node, spec, node_idx, max_node_idx, graph_signature)
             specs.add(spec)
     return specs
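For readers outside the codebase, here is a minimal, self-contained sketch of the interval logic the hunk above adds. Names like Spec, update_lifetime, and may_share_memory are illustrative stand-ins, not the ExecuTorch API; the real pass works on TensorSpec objects attached to Fx graph nodes.

# Hypothetical sketch (not the ExecuTorch API): lifetimes are [start, end]
# node-index intervals, and a mutable-buffer placeholder is pinned to the
# last node index so its storage is never reused.
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class Spec:
    name: str
    lifetime: Tuple[Optional[int], Optional[int]] = (None, None)


def update_lifetime(
    spec: Spec,
    node_idx: int,
    max_node_idx: int,
    is_mutable_buffer_placeholder: bool = False,
) -> None:
    start, end = spec.lifetime
    # Grow the interval to cover this use of the tensor.
    start = node_idx if start is None or start > node_idx else start
    if is_mutable_buffer_placeholder:
        # Mutable buffers carry state across executions: never free them.
        end = max_node_idx
    else:
        end = node_idx if end is None or end < node_idx else end
    spec.lifetime = (start, end)


def may_share_memory(a: Spec, b: Spec) -> bool:
    # Two tensors can reuse the same storage only if their lifetimes are disjoint.
    (s1, e1), (s2, e2) = a.lifetime, b.lifetime
    return e1 < s2 or e2 < s1


state = Spec("state")   # a mutable buffer, placeholder at node 0
tmp = Spec("tmp")       # a scratch tensor used at nodes 3-4
update_lifetime(state, node_idx=0, max_node_idx=5, is_mutable_buffer_placeholder=True)
update_lifetime(tmp, node_idx=3, max_node_idx=5)
update_lifetime(tmp, node_idx=4, max_node_idx=5)
assert state.lifetime == (0, 5)          # pinned to the last node
assert tmp.lifetime == (3, 4)
assert not may_share_memory(state, tmp)  # so "tmp" must get its own offset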

exir/tests/test_memory_planning.py

Lines changed: 41 additions & 0 deletions
@@ -664,6 +664,47 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             .val.allocation_info.memory_offset_high,
         )
 
+    def test_mutable_buffers_infinite_lifespan(self) -> None:
+        class Simple(torch.nn.Module):
+            def __init__(self) -> None:
+                super().__init__()
+                self.register_buffer("state", torch.zeros(1))
+
+            def forward(self, x: torch.Tensor) -> torch.Tensor:
+                self.state.index_put_(
+                    [
+                        torch.tensor([0]),
+                    ],
+                    x,
+                )
+                y = x + self.state
+                z = x * y
+                return z
+
+        model = Simple()
+        inputs = (torch.ones(1),)
+
+        et = to_edge(export(model, inputs, strict=True)).to_executorch(
+            ExecutorchBackendConfig(
+                emit_mutable_buffer_names=True, run_reinplace_pass=True
+            )
+        )
+
+        serialized_state = et.executorch_program.execution_plan[0].values[0].val
+        self.assertEqual(
+            serialized_state.extra_tensor_info.fully_qualified_name, "state"
+        )
+        memory_base = serialized_state.allocation_info.memory_offset_low
+        memory_size = memory_base + 4  # 4 bytes for a single float
+        for value in et.executorch_program.execution_plan[0].values[1:]:
+            val = value.val
+            if hasattr(val, "allocation_info") and val.allocation_info is not None:
+                not_overlapping = (
+                    val.allocation_info.memory_offset_low < memory_base
+                    or val.allocation_info.memory_offset_low >= memory_size
+                )
+                self.assertTrue(not_overlapping)
+
     def test_constants_not_memory_planned(self) -> None:
         class Simple(torch.nn.Module):
             def __init__(self) -> None:
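A note on what the test checks: the buffer "state" holds a single float, so its planned region is the 4 bytes starting at its memory_offset_low. With emit_mutable_buffer_names=True the planner serializes the buffer's fully qualified name, and the loop asserts that no other memory-planned value begins inside that 4-byte window, which is exactly what the extended lifetime guarantees even after the reinplace pass turns the mutation into an in-place op.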
