diff --git a/exir/memory_planning.py b/exir/memory_planning.py
index 83598940882..98aa269ee01 100644
--- a/exir/memory_planning.py
+++ b/exir/memory_planning.py
@@ -10,10 +10,20 @@
 import itertools
 import logging
 import operator
-import typing
 from collections import defaultdict
 from dataclasses import dataclass, field
-from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    cast,
+    Dict,
+    Iterable,
+    List,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import torch
 from executorch.exir import memory
@@ -949,7 +959,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
     bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
     if bufsizes is None:
         bufsizes = [0, 0]
-    bufsizes = typing.cast(List[int], bufsizes)
+    bufsizes = cast(List[int], bufsizes)
 
     for spec in specs:
         spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
@@ -1051,50 +1061,150 @@ def insert_calls_to_free(
     graph_module.recompile()
 
 
+def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
+    """Combine two buffer size lists."""
+    if len(bufsizes) < len(new_bufsizes):
+        bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
+    for i in range(len(new_bufsizes)):
+        bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
+    return bufsizes
+
+
+def _handle_submodule(
+    algo: Callable[..., list[int]],
+    parent_graph_module: torch.fx.GraphModule,
+    alignment: int,
+    submodule_node: torch.fx.Node,
+    graph_signature: Optional[ExportGraphSignature] = None,
+    alloc_graph_input: bool = False,
+) -> list[int]:
+    """Apply algo to nodes in a submodule of the graph module."""
+    assert submodule_node.op == "get_attr"
+    submodule = getattr(parent_graph_module, submodule_node.target)
+
+    logging.debug(f"Planning memory for submodule {submodule_node.name}...")
+    bufsizes = apply_algo(
+        algo,
+        submodule,
+        alignment,
+        graph_signature,
+        alloc_graph_input=alloc_graph_input,
+        alloc_graph_output=True,
+    )
+    submodule.meta.update({"non_const_buffer_sizes": bufsizes})
+    logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
+    return bufsizes
+
+
+def _apply_algo_to_submodules(
+    algo: Callable[..., list[int]],
+    graph_module: torch.fx.GraphModule,
+    alignment: int,
+    graph_signature: Optional[ExportGraphSignature] = None,
+) -> list[int]:
+    """Apply algo to map/cond/while nodes in the graph module.
+
+    This method will populate graph_module.meta["non_const_buffer_sizes"] for
+    all submodules and return a bufsizes list that is the maximum size of all
+    buffers.
+    """
+
+    # Bufsizes for submodules.
+    bufsizes: list[int] = []
+
+    def _handle(
+        submodule_node: torch.fx.Node,
+        alloc_graph_input: bool = False,
+    ) -> None:
+        current_bufsizes = _handle_submodule(
+            algo,
+            graph_module,
+            alignment,
+            submodule_node,
+            graph_signature,
+            alloc_graph_input=alloc_graph_input,
+        )
+        nonlocal bufsizes
+        _merge_bufsizes(bufsizes, current_bufsizes)
+
+    for cond_node in get_cond_nodes(graph_module):
+        _handle(cast(torch.fx.Node, cond_node.args[1]))
+        _handle(cast(torch.fx.Node, cond_node.args[2]))
+
+    for while_node in get_while_nodes(graph_module):
+        _handle(cast(torch.fx.Node, while_node.args[0]))
+        _handle(cast(torch.fx.Node, while_node.args[1]))
+
+    for map_node in get_map_nodes(graph_module):
+        _handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)
+
+    # TODO: We can handle delegates the same way as map/cond/while.
+    # Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.
+
+    return bufsizes
+
+
 def apply_algo(
-    algo: Callable[
-        ...,
-        List[int],
-    ],
+    algo: Callable[..., list[int]],
     graph_module: torch.fx.GraphModule,
     alignment: int,
     graph_signature: Optional[ExportGraphSignature] = None,
     alloc_graph_input: bool = True,
     alloc_graph_output: bool = True,
     alloc_mutable_buffers: bool = True,
-) -> List[int]:
+) -> list[int]:
     """
     Recursively apply algo to graph_module and its submodules for control flow.
 
-    Quite naively right now since it does not take the following optimizations
-    into considerating:
-    1. for conditional structure, true branch and false true does not overlap
-       in lifetime and can share tensor storage
-    2. tensors inside a submodule (e.g. true branch) has opportunities to share
-       storage with tensors in the outer module.
-    TODO: make these optimizations once we have some baseline working.
+    Algo implementation should handle one of two meta entries for submodules:
+    1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
+       `algo` should start at the offset specified by this list;
+       OR
+    2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule.
+       `algo` should reserve the space specified by this list for the lifetime
+       of the submodule node (e.g. cond, while, map).
+
+    TODO: Missing optimizations:
+    1. To handle maps, we set `alloc_graph_input=True`, which allocates
+       appropriate space for mapped arg but ends up allocating extra space for
+       `operand` arg. The memory for operands is unused.
     """
     # Extract the nodes and their lifespans from the graph_module
     # Difficult to just filter the list of specs returned by this due to
     # how we flag trainable weights.
     _ = update_all_tensors_lifetime(graph_module, graph_signature)
+
     # Filter specs based on alloc_graph_input and alloc_graph_output
-    specs = collect_specs_from_nodes(
-        graph_module.graph.nodes,
-        graph_signature,
-        do_assertion=False,
-        ignore_graph_input=not alloc_graph_input,
-        ignore_graph_output=not alloc_graph_output,
-        ignore_mutable_buffers=not alloc_mutable_buffers,
+    specs = set(
+        collect_specs_from_nodes(
+            graph_module.graph.nodes,
+            graph_signature,
+            do_assertion=False,
+            ignore_graph_input=not alloc_graph_input,
+            ignore_graph_output=not alloc_graph_output,
+            ignore_mutable_buffers=not alloc_mutable_buffers,
+        )
     )
 
+    # Get temporary specs for submodules to set aside space during execution
+    # of submodules.
+    submodule_bufsizes = _apply_algo_to_submodules(
+        algo, graph_module, alignment, graph_signature
+    )
+
+    # Update `input_mem_buffer_sizes` in graph_module. This will allow existing
+    # algos to work using `input_mem_buffer_sizes` or use
+    # `non_const_buffer_sizes` directly.
+    # pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
+    graph_module.input_mem_buffer_sizes = submodule_bufsizes
+
     # Get extra padding for XNNPACK if needed
     extra_padding = 0
     if _contains_xnnpack_delegate(graph_module):
         extra_padding = 64
 
     # Pass the filtered specs to the algorithm
-    bufsizes: List[int] = algo(
+    bufsizes: list[int] = algo(
         alignment,
         specs,
         graph_module,
@@ -1102,41 +1212,7 @@ def apply_algo(
         extra_padding,
     )
 
-    insert_calls_to_free(graph_module, set(specs))
-
-    def handle_submodule(
-        submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
-    ) -> None:
-        nonlocal bufsizes
-        assert submodule_nd.op == "get_attr"
-        submodule = getattr(graph_module, submodule_nd.target)
-        # memory planning for submodule need to be aware of the amount of
-        # buffer already allocated.
-        submodule.input_mem_buffer_sizes = bufsizes
-
-        bufsizes = apply_algo(
-            algo,
-            submodule,
-            alignment,
-            graph_signature,
-            alloc_graph_input=alloc_graph_input,
-            alloc_graph_output=True,
-        )
-        submodule.meta.update({"non_const_buffer_sizes": bufsizes})
-
-    for cond_node in get_cond_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
-        handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))
-
-    for while_node in get_while_nodes(graph_module):
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
-        handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
-    # TODO: Add test coverage for map operator once dynamo tracing is
-    # fully supported for this. T142287208
-    for map_node in get_map_nodes(graph_module):
-        handle_submodule(
-            typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
-        )
+    insert_calls_to_free(graph_module, specs)
 
     graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
     return bufsizes
diff --git a/exir/tests/test_memory_planning.py b/exir/tests/test_memory_planning.py
index b87ae2dfb58..da4afbaa7fc 100644
--- a/exir/tests/test_memory_planning.py
+++ b/exir/tests/test_memory_planning.py
@@ -32,6 +32,8 @@
     ToOutVarPass,
 )
 from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
+from executorch.exir.tensor import TensorSpec
+from functorch.experimental.control_flow import map as torch_map
 from parameterized import parameterized
 from torch import nn
 
@@ -55,6 +57,7 @@
 from torch.export.exported_program import ExportGraphSignature
 from torch.fx import Graph, GraphModule, Node
 from torch.nn import functional as F
+from torch.utils import _pytree as pytree
 
 torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")
 
@@ -420,13 +423,13 @@ def test_graph_input_output(self) -> None:
             alloc_graph_output,
             alloc_mutable_buffers,
         ) in itertools.product([True, False], [True, False], [True, False]):
-            case = maketest(
+            test = maketest(
                 ModelWithDifferentTensorSizes,
                 alloc_graph_input=alloc_graph_input,
                 alloc_graph_output=alloc_graph_output,
                 alloc_mutable_buffer=alloc_mutable_buffers,
             )
-            case(self)
+            test(self)
 
 
 class TestVerifier(unittest.TestCase):
@@ -788,3 +791,96 @@ def forward(self, input, label):
             .val.allocation_info,  # pyright: ignore
             None,
         )
+
+def _get_specs(gm: torch.fx.GraphModule) -> list[TensorSpec]:
+    return list(
+        filter(
+            None,
+            pytree.tree_flatten(
+                pytree.tree_map_only(
+                    torch.fx.Node,
+                    lambda n: n.meta.get("spec", None),
+                    list(gm.graph.nodes),
+                )
+            )[0],
+        )
+    )
+
+class TestMap(unittest.TestCase):
+    class MapModel(torch.nn.Module):
+        def __init__(self) -> None:
+            super().__init__()
+
+        def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+            # Use actual torch.map function for memory planning testing
+            def add_fn(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
+                return a + b
+
+            # Use torch.map to apply function over first dimension
+            # pyre-ignore[6]: For 3rd argument expected `TypeVarTuple` but got `Tensor`.
+            map_output = torch_map(add_fn, x, y)
+
+            return map_output + y
+
+        def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
+            return (torch.randn(5, 3), torch.randn(3))
+
+    def test_map(self) -> None:
+        """Test memory planning for torch.map operations."""
+
+        eager_module = self.MapModel().eval()
+        inputs = eager_module.get_random_inputs()
+
+        # Export and convert to edge
+        graph_module = (
+            to_edge(export(eager_module, inputs, strict=True))
+            .exported_program()
+            .graph_module
+        )
+
+        # Apply memory planning.
+        mem_algo = MemoryPlanningAlgorithmSuite(algo_list=[naive, greedy])
+        graph_module = PassManager(
+            passes=[
+                SpecPropPass(),
+                ToOutVarPass(),
+            ],
+        )(graph_module).graph_module
+        mem_planning_pass = MemoryPlanningPass(
+            mem_algo,
+            alloc_graph_input=True,
+            alloc_graph_output=True,
+            alloc_mutable_buffers=True,
+        )
+        graph_module = mem_planning_pass.run(graph_module).graph_module
+
+        # Verify memory planning results
+        verifier = Verifier(
+            graph_module,
+            alloc_graph_input=True,
+            alloc_graph_output=True,
+            alloc_mutable_buffers=True,
+        )
+        verifier.verify_graph_input_output()
+        verifier.verify_storage_reuse(allow_lifetime_and_storage_overlap=False)
+
+        map_node = graph_module.graph.find_nodes(
+            op="call_function", target=torch.ops.higher_order.map_impl
+        )[0]
+        map_fn_node = map_node.args[0]
+        self.assertEqual(map_fn_node.op, "get_attr")
+        map_fn = getattr(graph_module, map_fn_node.target)
+
+        map_lifetime = map_node.meta.get("spec", None)[0].lifetime[0]
+
+        # Check that there is no storage overlap between nodes of the outer program and submodule of map.
+        for outer_spec in _get_specs(graph_module):
+            for inner_spec in _get_specs(map_fn):
+                self.assertFalse(
+                    verifier.has_overlap(
+                        outer_spec.lifetime, [map_lifetime, map_lifetime]
+                    )
+                    and (verifier.storage_overlap(outer_spec, inner_spec)),
+                    f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
+                )