Improve memory planning for submodule hierarchies. #11860

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 1 commit into base: main
192 changes: 134 additions & 58 deletions exir/memory_planning.py
@@ -10,10 +10,20 @@
import itertools
import logging
import operator
import typing
from collections import defaultdict
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
from typing import (
Any,
Callable,
cast,
Dict,
Iterable,
List,
Optional,
Set,
Tuple,
Union,
)

import torch
from executorch.exir import memory
@@ -949,7 +959,7 @@ def _allocate_buf(bufsizes: List[int], mem_id: int, allocated: int) -> int:
bufsizes = getattr(graph_module, "input_mem_buffer_sizes", None)
if bufsizes is None:
bufsizes = [0, 0]
bufsizes = typing.cast(List[int], bufsizes)
bufsizes = cast(List[int], bufsizes)

for spec in specs:
spec_alloc_result = naive_result.spec_dict.get(spec, SpecAllocResult(0, 0, 0))
@@ -1051,92 +1061,158 @@ def insert_calls_to_free(
graph_module.recompile()


def _merge_bufsizes(bufsizes: list[int], new_bufsizes: list[int]) -> list[int]:
"""Combine two buffer size lists."""
if len(bufsizes) < len(new_bufsizes):
bufsizes.extend([0] * (len(new_bufsizes) - len(bufsizes)))
for i in range(len(new_bufsizes)):
bufsizes[i] = max(bufsizes[i], new_bufsizes[i])
return bufsizes
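# Illustrative example (hypothetical sizes): the merge keeps the element-wise
# maximum and pads the shorter list with zeros, e.g.
#   _merge_bufsizes([16, 0, 32], [8, 64]) == [16, 64, 32]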


def _handle_submodule(
algo: Callable[..., list[int]],
parent_graph_module: torch.fx.GraphModule,
alignment: int,
submodule_node: torch.fx.Node,
graph_signature: Optional[ExportGraphSignature] = None,
alloc_graph_input: bool = False,
) -> list[int]:
"""Apply algo to nodes in a submodule of the graph module."""
assert submodule_node.op == "get_attr"
submodule = getattr(parent_graph_module, submodule_node.target)

logging.debug(f"Planning memory for submodule {submodule_node.name}...")
bufsizes = apply_algo(
algo,
submodule,
alignment,
graph_signature,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=True,
)
submodule.meta.update({"non_const_buffer_sizes": bufsizes})
logging.debug(f"Buffer sizes for submodule {submodule_node.name}: {bufsizes}")
return bufsizes


def _apply_algo_to_submodules(
algo: Callable[..., list[int]],
graph_module: torch.fx.GraphModule,
alignment: int,
graph_signature: Optional[ExportGraphSignature] = None,
) -> list[int]:
"""Apply algo to map/cond/while nodes in the graph module.

This function populates meta["non_const_buffer_sizes"] on each submodule and
returns a bufsizes list that is the element-wise maximum across all submodule
buffer sizes.
"""

# Bufsizes for submodules.
bufsizes: list[int] = []

def _handle(
submodule_node: torch.fx.Node,
alloc_graph_input: bool = False,
) -> None:
current_bufsizes = _handle_submodule(
algo,
graph_module,
alignment,
submodule_node,
graph_signature,
alloc_graph_input=alloc_graph_input,
)
nonlocal bufsizes
_merge_bufsizes(bufsizes, current_bufsizes)

for cond_node in get_cond_nodes(graph_module):
_handle(cast(torch.fx.Node, cond_node.args[1]))
_handle(cast(torch.fx.Node, cond_node.args[2]))

for while_node in get_while_nodes(graph_module):
_handle(cast(torch.fx.Node, while_node.args[0]))
_handle(cast(torch.fx.Node, while_node.args[1]))

for map_node in get_map_nodes(graph_module):
_handle(cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True)

# TODO: We can handle delegates the same way as map/cond/while.
# Maybe populate the graph_module.meta["non_const_buffer_sizes"] for delegates.

return bufsizes


def apply_algo(
algo: Callable[
...,
List[int],
],
algo: Callable[..., list[int]],
graph_module: torch.fx.GraphModule,
alignment: int,
graph_signature: Optional[ExportGraphSignature] = None,
alloc_graph_input: bool = True,
alloc_graph_output: bool = True,
alloc_mutable_buffers: bool = True,
) -> List[int]:
) -> list[int]:
"""
Recursively apply algo to graph_module and its submodules for control flow.

Quite naively right now since it does not take the following optimizations
into considerating:
1. for conditional structure, true branch and false true does not overlap
in lifetime and can share tensor storage
2. tensors inside a submodule (e.g. true branch) has opportunities to share
storage with tensors in the outer module.
TODO: make these optimizations once we have some baseline working.
Algo implementation should handle one of two meta entries for submodules:
1. input_mem_buffer_sizes: List of int offset bytes. Memory allocated by
`algo` should start at the offset specified by this list;
OR
2. non_const_buffer_sizes: List of bufsizes for planned memory in submodule.
`algo` should reserve the space specified by this list for the lifetime
of the submodule node (e.g. cond, while, map).

TODO: Missing optimizations:
1. To handle maps, we set `alloc_graph_input=True`, which allocates
appropriate space for the mapped args but ends up allocating extra space for
the `operand` args. The memory for the operands is unused.
"""
# Extract the nodes and their lifespans from the graph_module
# Difficult to just filter the list of specs returned by this due to
# how we flag trainable weights.
_ = update_all_tensors_lifetime(graph_module, graph_signature)

# Filter specs based on alloc_graph_input and alloc_graph_output
specs = collect_specs_from_nodes(
graph_module.graph.nodes,
graph_signature,
do_assertion=False,
ignore_graph_input=not alloc_graph_input,
ignore_graph_output=not alloc_graph_output,
ignore_mutable_buffers=not alloc_mutable_buffers,
specs = set(
collect_specs_from_nodes(
graph_module.graph.nodes,
graph_signature,
do_assertion=False,
ignore_graph_input=not alloc_graph_input,
ignore_graph_output=not alloc_graph_output,
ignore_mutable_buffers=not alloc_mutable_buffers,
)
)

# Get temporary specs for submodules to set aside space during execution
# of submodules.
submodule_bufsizes = _apply_algo_to_submodules(
algo, graph_module, alignment, graph_signature
)

# Update `input_mem_buffer_sizes` in graph_module. This will allow existing
# algos to work using `input_mem_buffer_sizes` or use
# `non_const_buffer_sizes` directly.
# pyre-ignore[16]: `torch.fx.GraphModule` has no attribute `input_mem_buffer_sizes`.
graph_module.input_mem_buffer_sizes = submodule_bufsizes

# Get extra padding for XNNPACK if needed
extra_padding = 0
if _contains_xnnpack_delegate(graph_module):
extra_padding = 64

# Pass the filtered specs to the algorithm
bufsizes: List[int] = algo(
bufsizes: list[int] = algo(
alignment,
specs,
graph_module,
graph_signature,
extra_padding,
)

insert_calls_to_free(graph_module, set(specs))

def handle_submodule(
submodule_nd: torch.fx.Node, alloc_graph_input: bool = False
) -> None:
nonlocal bufsizes
assert submodule_nd.op == "get_attr"
submodule = getattr(graph_module, submodule_nd.target)
# memory planning for submodule need to be aware of the amount of
# buffer already allocated.
submodule.input_mem_buffer_sizes = bufsizes

bufsizes = apply_algo(
algo,
submodule,
alignment,
graph_signature,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=True,
)
submodule.meta.update({"non_const_buffer_sizes": bufsizes})

for cond_node in get_cond_nodes(graph_module):
handle_submodule(typing.cast(torch.fx.Node, cond_node.args[1]))
handle_submodule(typing.cast(torch.fx.Node, cond_node.args[2]))

for while_node in get_while_nodes(graph_module):
handle_submodule(typing.cast(torch.fx.Node, while_node.args[0]))
handle_submodule(typing.cast(torch.fx.Node, while_node.args[1]))
# TODO: Add test coverage for map operator once dynamo tracing is
# fully supported for this. T142287208
for map_node in get_map_nodes(graph_module):
handle_submodule(
typing.cast(torch.fx.Node, map_node.args[0]), alloc_graph_input=True
)
insert_calls_to_free(graph_module, specs)

graph_module.meta.update({"non_const_buffer_sizes": bufsizes})
return bufsizes
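For reference, the sketch below illustrates option 1 of the contract described in the new `apply_algo` docstring: an algorithm can treat `input_mem_buffer_sizes` as already-reserved space and start outer-module allocations after it, so outer tensors never collide with submodule buffers. This is a minimal, self-contained toy (hypothetical names, single mem_id, simple bump allocation), not the ExecuTorch algorithm interface.

from typing import List


def plan_mem_id_zero(
    spec_sizes: List[int],
    input_mem_buffer_sizes: List[int],
    alignment: int = 16,
) -> List[int]:
    # Start after the space already reserved for submodules of cond/while/map,
    # then bump-allocate each outer-module tensor with alignment.
    offset = input_mem_buffer_sizes[0] if input_mem_buffer_sizes else 0
    for size in spec_sizes:
        offset = (offset + alignment - 1) // alignment * alignment
        offset += size
    return [offset]


if __name__ == "__main__":
    # Submodules were planned first and need 128 bytes; outer tensors need 64 and 32.
    print(plan_mem_id_zero([64, 32], [128]))  # -> [224]

Option 2 (`non_const_buffer_sizes`) would instead have the algorithm reserve that amount for the lifetime of the cond/while/map node itself.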
100 changes: 98 additions & 2 deletions exir/tests/test_memory_planning.py
@@ -32,6 +32,8 @@
ToOutVarPass,
)
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
from executorch.exir.tensor import TensorSpec
from functorch.experimental.control_flow import map as torch_map
from parameterized import parameterized

from torch import nn
@@ -55,6 +57,7 @@
from torch.export.exported_program import ExportGraphSignature
from torch.fx import Graph, GraphModule, Node
from torch.nn import functional as F
from torch.utils import _pytree as pytree

torch.ops.load_library("//executorch/kernels/portable:custom_ops_generated_lib")

@@ -420,13 +423,13 @@ def test_graph_input_output(self) -> None:
alloc_graph_output,
alloc_mutable_buffers,
) in itertools.product([True, False], [True, False], [True, False]):
case = maketest(
test = maketest(
ModelWithDifferentTensorSizes,
alloc_graph_input=alloc_graph_input,
alloc_graph_output=alloc_graph_output,
alloc_mutable_buffer=alloc_mutable_buffers,
)
case(self)
test(self)


class TestVerifier(unittest.TestCase):
@@ -788,3 +791,96 @@ def forward(self, input, label):
.val.allocation_info, # pyright: ignore
None,
)

def _get_specs(gm: torch.fx.GraphModule) -> list[TensorSpec]:
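    """Collect the TensorSpec attached to each node via meta["spec"],
    flattening nested spec lists and dropping nodes without one."""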
return list(
filter(
None,
pytree.tree_flatten(
pytree.tree_map_only(
torch.fx.Node,
lambda n: n.meta.get("spec", None),
list(gm.graph.nodes),
)
)[0],
)
)

class TestMap(unittest.TestCase):
class MapModel(torch.nn.Module):
def __init__(self) -> None:
super().__init__()

def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
# Use actual torch.map function for memory planning testing
def add_fn(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
return a + b

# Use torch.map to apply function over first dimension
# pyre-ignore[6]: For 3rd argument expected `TypeVarTuple` but got `Tensor`.
map_output = torch_map(add_fn, x, y)

return map_output + y

def get_random_inputs(self) -> Tuple[torch.Tensor, ...]:
return (torch.randn(5, 3), torch.randn(3))

def test_map(self) -> None:
"""Test memory planning for torch.map operations."""

eager_module = self.MapModel().eval()
inputs = eager_module.get_random_inputs()

# Export and convert to edge
graph_module = (
to_edge(export(eager_module, inputs, strict=True))
.exported_program()
.graph_module
)

# Apply memory planning.
mem_algo = MemoryPlanningAlgorithmSuite(algo_list=[naive, greedy])
graph_module = PassManager(
passes=[
SpecPropPass(),
ToOutVarPass(),
],
)(graph_module).graph_module
mem_planning_pass = MemoryPlanningPass(
mem_algo,
alloc_graph_input=True,
alloc_graph_output=True,
alloc_mutable_buffers=True,
)
graph_module = mem_planning_pass.run(graph_module).graph_module

# Verify memory planning results
verifier = Verifier(
graph_module,
alloc_graph_input=True,
alloc_graph_output=True,
alloc_mutable_buffers=True,
)
verifier.verify_graph_input_output()
verifier.verify_storage_reuse(allow_lifetime_and_storage_overlap=False)

map_node = graph_module.graph.find_nodes(
op="call_function", target=torch.ops.higher_order.map_impl
)[0]
map_fn_node = map_node.args[0]
self.assertEqual(map_fn_node.op, "get_attr")
map_fn = getattr(graph_module, map_fn_node.target)


map_lifetime = map_node.meta.get("spec", None)[0].lifetime[0]

# Check that there is no storage overlap between nodes of the outer program and submodule of map.
for outer_spec in _get_specs(graph_module):
for inner_spec in _get_specs(map_fn):
self.assertFalse(
verifier.has_overlap(
outer_spec.lifetime, [map_lifetime, map_lifetime]
)
and (verifier.storage_overlap(outer_spec, inner_spec)),
f"Outer spec {outer_spec.shape=} {outer_spec.dtype=} {outer_spec.lifetime=} and inner spec {inner_spec} have storage overlap",
)
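The loop above checks the invariant that memory planning must preserve across the map boundary: a tensor in the outer program and a tensor in the map submodule may share storage only if their lifetimes do not overlap. A rough, self-contained sketch of that invariant (hypothetical helpers, not the Verifier API):

from typing import Tuple


def lifetimes_overlap(a: Tuple[int, int], b: Tuple[int, int]) -> bool:
    # Closed intervals of node indices overlap unless one ends before the other begins.
    return not (a[1] < b[0] or b[1] < a[0])


def storage_overlaps(a: Tuple[int, int], b: Tuple[int, int]) -> bool:
    # (offset, size) byte ranges within the same mem_id.
    return not (a[0] + a[1] <= b[0] or b[0] + b[1] <= a[0])


# Sharing storage is only legal when lifetimes are disjoint.
assert not (lifetimes_overlap((3, 9), (5, 7)) and storage_overlaps((0, 64), (128, 32)))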