From c610a30f010a276fca9d2041cb7e485162db3c79 Mon Sep 17 00:00:00 2001 From: Hardik Sharma Date: Mon, 23 Jun 2025 13:07:53 -0700 Subject: [PATCH] Create a MemoryPlanningAlgo class. (#11824) Summary: Refactor our memory planning algos into a `MemoryPlanningAlgo` base class + algo-specific implementation in derived class. This refactor separates common utility functions + constraint handling to the base class. This way, we can add support for hierarchical graphs (using maps), and add more types of constraints (like pinning a tensor to specific dtcm bank) without changing the algo itself. Reviewed By: zonglinpeng Differential Revision: D76954785 --- backends/cadence/aot/TARGETS | 17 + backends/cadence/aot/memory_planning.py | 371 ++++++++----------- backends/cadence/aot/memory_planning_algo.py | 162 ++++++++ 3 files changed, 326 insertions(+), 224 deletions(-) create mode 100644 backends/cadence/aot/memory_planning_algo.py diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index c3ca472147f..a85cc0ca925 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -448,6 +448,22 @@ python_unittest( ], ) +python_library( + name = "memory_planning_algo", + srcs = [ + "memory_planning_algo.py", + ], + deps = [ + ":memory_constraints", + ":pass_utils", + "//executorch/exir:lib", + "//executorch/exir:memory_planning", + "//executorch/exir:tensor", + "//executorch/exir/passes:lib", + "fbsource//third-party/pypi/tabulate:tabulate", + ], +) + python_library( name = "memory_planning", srcs = [ @@ -456,6 +472,7 @@ python_library( deps = [ "fbsource//third-party/pypi/tabulate:tabulate", ":memory_constraints", + ":memory_planning_algo", ":pass_utils", "//caffe2:torch", "//executorch/exir:lib", diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py index 3c6c518f16a..5a7f6e936fb 100644 --- a/backends/cadence/aot/memory_planning.py +++ b/backends/cadence/aot/memory_planning.py @@ -4,20 +4,19 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import collections import itertools import logging -import math -import typing -from functools import partial -from typing import Iterable, List, Optional, Set, Tuple +from typing import Callable, Iterable, List, Optional, Set, Tuple, TypeAlias import torch -from executorch.backends.cadence.aot.memory_constraints import ( - GenerateMemConstraints, - MemConstraints, +from executorch.backends.cadence.aot.memory_constraints import MemConstraints +from executorch.backends.cadence.aot.memory_planning_algo import ( + get_aligned_offset, + MemoryPlanningAlgo, + MemoryPlanningState, ) from executorch.backends.cadence.aot.utils import MemoryConfig @@ -30,26 +29,6 @@ from torch.fx.passes.infra.pass_base import PassResult -# get num memories indexed from 1..N, compatible with EXIR's spec.mem_id -def get_num_memories(memory_config: MemoryConfig) -> int: - return len(memory_config.memory_sizes) + 1 - - -# memory_space module provides num_memories indexed 0..num_memories-1. -def get_size(memory_config: MemoryConfig, exir_id: int) -> int: - return memory_config.memory_sizes[exir_id - 1] - - -def get_alignment(memory_config: MemoryConfig, exir_id: int) -> int: - # EXIR's spec.mem_id is indexed from 1..N. 
- assert memory_config.memory_alignments is not None - return memory_config.memory_alignments[exir_id - 1] - - -def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int: - return int(math.ceil(pre_aligned_offset / alignment) * alignment) - - def collect_specs_from_graph_module( graph_module: torch.fx.GraphModule, graph_signature: ExportGraphSignature, @@ -69,198 +48,127 @@ def collect_specs_from_graph_module( ) -# baseline tensor placement algorithm, that greedily tries to place the tensor in -# the fastest memory available -# flake8: noqa 'position_based_greedy_with_hierarchy' is too complex (13) -def position_based_greedy_with_hierarchy( - alignment: int, - specs: Set[TensorSpec], - graph_module: torch.fx.GraphModule, - graph_signature: ExportGraphSignature, - extra_padding: int = 0, - *, - memory_config: MemoryConfig, - mem_constraints: MemConstraints, - additional_constraint_gen_passes: Optional[ - List[ - typing.Callable[ - [MemConstraints], - typing.Callable[[torch.fx.GraphModule], Optional[PassResult]], - ] - ] - ] = None, -) -> List[int]: - # We do not use the `alignment` parameter and instead use the per-memory alignment - # constraints from `memory_config`. - del alignment - - num_memories = get_num_memories(memory_config) - bufsizes = [0] * num_memories - allocated_buffers: List[List[TensorSpec]] = [[] for _ in range(num_memories)] +class PositionBasedGreedyWithHierarchy(MemoryPlanningAlgo): + """Greedily place tensor in the fastest memory available.""" - # Generate the memory constraints - GenerateMemConstraints(mem_constraints, additional_constraint_gen_passes)( - graph_module - ) - - def overlap(spec: TensorSpec) -> Optional[TensorSpec]: - for allocated_spec in allocated_buffers[spec.mem_id]: - if Verifier.lifetime_overlap( - spec, allocated_spec - ) and Verifier.storage_overlap(spec, allocated_spec): - return allocated_spec - return None - - def memory_available(spec: TensorSpec) -> bool: - return get_aligned_offset( - spec.mem_offset + spec.allocated_memory, - get_alignment(memory_config, spec.mem_id), - ) <= get_size(memory_config, spec.mem_id) - - # Iterate over all the specs in sorted order - for spec in sorted( - specs, - key=lambda spec: spec.allocated_memory, - reverse=True, - ): - # Skip allocation memory to any tensor whose spec id is in skip list. - if mem_constraints.skipped_spec(spec): - continue - - for spec.mem_id in range(1, num_memories): - if mem_constraints.is_mem_id_in_blocklist(spec, spec.mem_id): - continue + def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: + """ + Greedily place the spec in the first memory that can fit it. + """ + for spec.mem_id in range(1, self.get_num_memories()): spec.mem_offset = 0 - while memory_available(spec) and (overlapped := overlap(spec)): + while self.is_valid_placement(spec) and ( + overlapped := state.get_overlapping_spec(spec) + ): + # Found an overlapping spec, so we need to adjust the offset = end of the overlapping spec + alignment. 
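+                # (i.e. the end of the overlapping spec rounded up to the bank's
+                # alignment; with hypothetical numbers: an overlapping spec at
+                # mem_offset=128 with allocated_memory=100 and alignment=16 gives
+                # get_aligned_offset(128 + 100, 16) == get_aligned_offset(228, 16) == 240.)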
spec.mem_offset = get_aligned_offset( overlapped.mem_offset + overlapped.allocated_memory, - get_alignment(memory_config, spec.mem_id), + self.get_alignment(spec.mem_id), ) - if memory_available(spec): - allocated_buffers[spec.mem_id].append(spec) - bufsizes[spec.mem_id] = max( - spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id] - ) - break - if ( - not allocated_buffers[spec.mem_id] - or allocated_buffers[spec.mem_id][-1] is not spec - ): - raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - - # And now honor the various memory location constraints (i.e., infer the memory - # location of tensors in skip_specs from the constraints) for this spec. - if mem_constraints.relative_loc_constraints_exist(): - mem_constraints.resolve_relative_loc_constraints(spec) - - # At the end, all the keys in relative_loc_constraints should have been visited - # and emptied. - assert not mem_constraints.relative_loc_constraints_exist() - - logging.debug( - f"position based greedy algorithm with hierarchy returns bufsizes: {bufsizes}" - ) - return bufsizes + if self.is_valid_placement(spec): + # Found a valid `spec.mem_offset` which is both valid and has no overlap. + state.place_spec(spec) + break -# Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf -def greedy_by_size_for_offset_calculation_with_hierarchy( - alignment: int, - specs: Set[TensorSpec], - graph_module: torch.fx.GraphModule, - graph_signature: ExportGraphSignature, - extra_padding: int = 0, - *, - memory_config: MemoryConfig, - mem_constraints: MemConstraints, - additional_constraint_gen_passes: Optional[ - List[ - typing.Callable[ - [MemConstraints], - typing.Callable[[torch.fx.GraphModule], Optional[PassResult]], - ] - ] - ] = None, -) -> List[int]: - # We do not use the `alignment` parameter and instead use the per-memory alignment - # constraints from `memory_config`. - del alignment + def plan( + self, + specs: Set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + prev_state: Optional[MemoryPlanningState] = None, + ) -> MemoryPlanningState: + state = prev_state or MemoryPlanningState(self.memory_config) + + # Iterate over all the specs in sorted order + for spec in sorted( + specs, + key=lambda spec: spec.allocated_memory, + reverse=True, + ): + self.plan_spec(spec, state) + if not state.is_placed(spec): + raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - num_memories = get_num_memories(memory_config) - bufsizes = [0] * num_memories - allocated_buffers = [[] for _ in range(num_memories)] + return state - # Generate the memory constraints - GenerateMemConstraints(mem_constraints, additional_constraint_gen_passes)( - graph_module - ) - # Iterate over all the specs in sorted order - for spec in sorted( - specs, - key=lambda spec: spec.allocated_memory, - reverse=True, - ): - # Skip allocation memory to any tensor whose spec id is in skip list. - if mem_constraints.skipped_spec(spec): - continue +class GreedyWithHeuristic(MemoryPlanningAlgo): + """Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf.""" - for spec.mem_id in range(1, num_memories): - if mem_constraints.is_mem_id_in_blocklist(spec, spec.mem_id): - continue + def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: + """ + Greedily place the spec in the first memory that can fit it. 
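+
+        Within each memory, placement follows the paper's best-fit heuristic:
+        among the gaps between already-placed specs with overlapping lifetimes,
+        take the smallest gap that fits, else append after the last such spec.
+        Across memories, the first memory that can fit the spec still wins.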
+ """ + for spec.mem_id in range(1, self.get_num_memories()): prev_offset, smallest_gap = 0, float("inf") - for allocated_spec in allocated_buffers[spec.mem_id]: - if Verifier.lifetime_overlap(spec, allocated_spec): - if ( - gap := allocated_spec.mem_offset - prev_offset - ) >= spec.allocated_memory and gap < smallest_gap: - smallest_gap = gap - spec.mem_offset = prev_offset - # Note that different from the paper, which updates prev_offset for all - # allocated tensors, we only update tensors with overlapping lifetime. - # Updating prev_offset outside the if statement will include tensors without - # overlapping lifetime, causing unnecessary waste of memory and make the - # calculation of gap incorrect. Moving it out will make the algorithm degenerate - # to the naive one, reusing 0 tensor. The paper may have a typo here. - prev_offset = max( - get_aligned_offset( - allocated_spec.mem_offset + allocated_spec.allocated_memory, - get_alignment(memory_config, spec.mem_id), - ), - prev_offset, - ) + for allocated_spec in state.allocated_buffers[spec.mem_id]: + if not Verifier.lifetime_overlap(spec, allocated_spec): + continue + + if ( + gap := allocated_spec.mem_offset - prev_offset + ) >= spec.allocated_memory and gap < smallest_gap: + smallest_gap = gap + spec.mem_offset = prev_offset + # Note that different from the paper, which updates prev_offset for all + # allocated tensors, we only update tensors with overlapping lifetime. + # Updating prev_offset outside the if statement will include tensors without + # overlapping lifetime, causing unnecessary waste of memory and make the + # calculation of gap incorrect. Moving it out will make the algorithm degenerate + # to the naive one, reusing 0 tensor. The paper may have a typo here. + prev_offset = max( + get_aligned_offset( + allocated_spec.mem_offset + allocated_spec.allocated_memory, + self.get_alignment(spec.mem_id), + ), + prev_offset, + ) if spec.mem_offset is None: if get_aligned_offset( prev_offset + spec.allocated_memory, - get_alignment(memory_config, spec.mem_id), - ) > get_size(memory_config, spec.mem_id): + self.get_alignment(spec.mem_id), + ) > self.get_size(spec.mem_id): continue else: spec.mem_offset = prev_offset - bufsizes[spec.mem_id] = max( - spec.mem_offset + spec.allocated_memory, bufsizes[spec.mem_id] - ) - allocated_buffers[spec.mem_id].append(spec) - allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset) + + state.place_spec(spec) # A data structure used for maintaining the tensor order # by offset, named ordered_allocated_ids in the paper + state.allocated_buffers[spec.mem_id].sort(key=lambda spec: spec.mem_offset) break - if spec not in allocated_buffers[spec.mem_id]: - raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - # And now honor the various memory location constraints (i.e., infer the memory - # location of tensors in skip_specs from the constraints) for this spec. - if mem_constraints.relative_loc_constraints_exist(): - mem_constraints.resolve_relative_loc_constraints(spec) + def plan( + self, + specs: set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + prev_state: Optional[MemoryPlanningState] = None, + ) -> MemoryPlanningState: + """Plan memory allocation for the given tensor specs.""" + # We do not use the `alignment` parameter and instead use the per-memory alignment + # constraints from `memory_config`. 
+ + state = prev_state or MemoryPlanningState(self.memory_config) + + # Iterate over all the specs in sorted order + for spec in sorted( + specs, + key=lambda spec: spec.allocated_memory, + reverse=True, + ): + self.plan_spec(spec, state) + if not state.is_placed(spec): + raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - # At the end, all the keys in relative_loc_constraints should have been visited - # and emptied. - assert not mem_constraints.relative_loc_constraints_exist() + logging.debug( + f"greedy by size for offset calculation with hierarchy returns bufsizes: {state.bufsizes}" + ) - logging.debug( - f"greedy by size for offset calculation with hierarchy returns bufsizes: {bufsizes}" - ) - return bufsizes + return state def find_peak_memory_usages_per_memory( @@ -436,6 +344,12 @@ def print_memory_planning_info( ) +ConstraintGenPassType: TypeAlias = Callable[ + [MemConstraints], + Callable[[torch.fx.GraphModule], Optional[PassResult]], +] + + class CadenceMemoryPlanning: def __init__( self, @@ -444,28 +358,48 @@ def __init__( mem_algo: int, alloc_graph_input: bool = True, alloc_graph_output: bool = True, - additional_constraint_gen_passes: Optional[ - List[ - typing.Callable[ - [MemConstraints], - typing.Callable[[torch.fx.GraphModule], Optional[PassResult]], - ] - ] - ] = None, + additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]] = None, ) -> None: - self._init_mem_algos() - self.memory_config = memory_config self.opt_level = opt_level - self.mem_algo = mem_algo self.alloc_graph_input = alloc_graph_input self.alloc_graph_output = alloc_graph_output - self.additional_constraint_gen_passes = additional_constraint_gen_passes - def _init_mem_algos(self) -> None: - self.available_mem_algos = [ - position_based_greedy_with_hierarchy, - greedy_by_size_for_offset_calculation_with_hierarchy, + self.algo: MemoryPlanningAlgo = self.get_mem_algos( + memory_config, + opt_level, + alloc_graph_input, + alloc_graph_output, + additional_constraint_gen_passes, + )[mem_algo] + + @staticmethod + def get_mem_algos( + memory_config: MemoryConfig, + opt_level: int, + alloc_graph_input: bool, + alloc_graph_output: bool, + additional_constraint_gen_passes: Optional[list[ConstraintGenPassType]], + ) -> list[MemoryPlanningAlgo]: + return [ + PositionBasedGreedyWithHierarchy( + memory_config, + MemConstraints( + opt_level=opt_level, + alloc_graph_input=alloc_graph_input, + alloc_graph_output=alloc_graph_output, + ), + additional_constraint_gen_passes, + ), + GreedyWithHeuristic( + memory_config, + MemConstraints( + opt_level=opt_level, + alloc_graph_input=alloc_graph_input, + alloc_graph_output=alloc_graph_output, + ), + additional_constraint_gen_passes, + ), ] def __call__( @@ -479,22 +413,11 @@ def run( graph_module: torch.fx.GraphModule, graph_signature: Optional[ExportGraphSignature] = None, ) -> PassResult: - mem_constraints = MemConstraints( - opt_level=self.opt_level, - alloc_graph_input=self.alloc_graph_input, - alloc_graph_output=self.alloc_graph_output, - ) - algo = partial( - self.available_mem_algos[self.mem_algo], - memory_config=self.memory_config, - mem_constraints=mem_constraints, - additional_constraint_gen_passes=self.additional_constraint_gen_passes, - ) # Create the memory planning pass. We allocate memory for input # (output) tensors if alloc_graph_input (alloc_graph_output) is # True. 
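+        # `self.algo` is itself callable (`MemoryPlanningAlgo.__call__` takes the
+        # same (alignment, specs, graph_module, graph_signature, extra_padding)
+        # arguments the old free functions did), so it replaces the previous
+        # functools.partial wiring directly.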
mem_planning = MemoryPlanningPass( - algo, + self.algo, allow_lifetime_and_storage_overlap=(self.opt_level >= 2), alloc_graph_input=self.alloc_graph_input, alloc_graph_output=self.alloc_graph_output, diff --git a/backends/cadence/aot/memory_planning_algo.py b/backends/cadence/aot/memory_planning_algo.py new file mode 100644 index 00000000000..5b67cc6c5fd --- /dev/null +++ b/backends/cadence/aot/memory_planning_algo.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. + +# pyre-strict + +import logging +import math +from abc import ABC, abstractmethod +from typing import Callable, Optional + +import torch +from executorch.backends.cadence.aot.memory_constraints import ( + GenerateMemConstraints, + MemConstraints, +) +from executorch.backends.cadence.aot.utils import MemoryConfig +from executorch.exir.memory_planning import Verifier +from executorch.exir.pass_base import PassResult +from executorch.exir.tensor import TensorSpec +from torch.export.exported_program import ExportGraphSignature + + +def get_aligned_offset(pre_aligned_offset: int, alignment: int) -> int: + return int(math.ceil(pre_aligned_offset / alignment) * alignment) + + +class MemoryPlanningState: + def __init__(self, memory_config: MemoryConfig) -> None: + self.num_memories: int = len(memory_config.memory_sizes) + 1 + alignment = memory_config.memory_alignments + assert alignment is not None + assert len(alignment) == self.num_memories - 1 + self.alignment: list[int] = [1] + alignment + # TODO: Maybe keep this sorted with heapq? + self.allocated_buffers: list[list[TensorSpec]] = [ + [] for _ in range(self.num_memories) + ] + self.bufsizes: list[int] = [0] * self.num_memories + + def place_spec(self, spec: TensorSpec) -> None: + """Place the spec at the given memory and offset.""" + assert self.get_overlapping_spec(spec) is None + self.allocated_buffers[spec.mem_id].append(spec) + self.bufsizes[spec.mem_id] = max( + self.bufsizes[spec.mem_id], + get_aligned_offset( + spec.mem_offset + spec.allocated_memory, self.alignment[spec.mem_id] + ), + ) + + def get_overlapping_spec(self, spec: TensorSpec) -> Optional[TensorSpec]: + """Get the overlapping spec for the given spec.""" + for allocated_spec in self.allocated_buffers[spec.mem_id]: + if Verifier.lifetime_overlap( + spec, allocated_spec + ) and Verifier.storage_overlap(spec, allocated_spec): + return allocated_spec + return None + + def is_placed(self, spec: TensorSpec) -> bool: + """Check if the spec is placed.""" + return spec in self.allocated_buffers[spec.mem_id] + + +class MemoryPlanningAlgo(ABC): + """Callable memory planning algorithm interface.""" + + def __init__( + self, + memory_config: MemoryConfig, + placement_constraints: MemConstraints, + additional_constraint_gen_passes: Optional[ + list[ + Callable[ + [MemConstraints], + Callable[[torch.fx.GraphModule], Optional[PassResult]], + ] + ] + ] = None, + ) -> None: + self.memory_config = memory_config + self.placement_constraints = placement_constraints + self.additional_constraint_gen_passes = additional_constraint_gen_passes + + def get_num_memories(self) -> int: + """Get num memories indexed from 1..N, compatible with EXIR's spec.mem_id.""" + return len(self.memory_config.memory_sizes) + 1 + + def get_size(self, exir_id: int) -> int: + # memory_space module provides num_memories indexed 0..num_memories-1. + return self.memory_config.memory_sizes[exir_id - 1] + + def get_alignment(self, exir_id: int) -> int: + # EXIR's spec.mem_id is indexed from 1..N. 
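+        # e.g. with memory_alignments == [16, 64] (hypothetical values),
+        # get_alignment(1) -> 16 and get_alignment(2) -> 64; mem_ids start
+        # at 1, so exir_id maps to memory_alignments[exir_id - 1].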
+ assert self.memory_config.memory_alignments is not None + return self.memory_config.memory_alignments[exir_id - 1] + + def populate_constraints(self, graph_module: torch.fx.GraphModule) -> None: + """Populate the constraints for the memory planning algorithm.""" + GenerateMemConstraints( + mem_constraints=self.placement_constraints, + additional_constraint_gen_passes=self.additional_constraint_gen_passes, + )(graph_module) + + def is_valid_placement(self, spec: TensorSpec) -> bool: + return get_aligned_offset( + spec.mem_offset + spec.allocated_memory, + self.get_alignment(spec.mem_id), + ) <= self.get_size(spec.mem_id) + + @abstractmethod + def plan( + self, + specs: set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + prev_state: Optional[MemoryPlanningState] = None, + ) -> MemoryPlanningState: + """Plan memory allocation for the given tensor specs.""" + pass + + def __call__( + self, + alignment: int, + specs: set[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + ) -> list[int]: + """Callable interface for ET memory planning.""" + self.populate_constraints(graph_module) + + # First plan the memory allocation for specs without relative constraints. + specs_without_relative_constraints = set( + filter( + lambda spec: not self.placement_constraints.skipped_spec(spec) + and not self.placement_constraints.is_mem_id_in_blocklist( + spec, spec.mem_id + ), + specs, + ) + ) + + # Call memory planning to get bufsizes. + state = self.plan( + specs_without_relative_constraints, + graph_module, + graph_signature, + extra_padding, + ) + + for spec in specs_without_relative_constraints: + # And now honor the various memory location constraints (i.e., infer the memory + # location of tensors in skip_specs from the constraints) for this spec. + self.placement_constraints.resolve_relative_loc_constraints(spec) + + # At the end, all the keys in relative_loc_constraints should have been visited + # and emptied. + assert not self.placement_constraints.relative_loc_constraints_exist() + + logging.debug(f"Memory planning algo found bufsizes: {state.bufsizes}") + return state.bufsizes
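
Reviewer note: a minimal sketch of how a new algorithm would slot into this
interface, illustrating the extensibility the summary describes. `NaiveFirstFit`
is hypothetical and not part of this diff; it is written against the
constructors and helpers introduced above.

    from typing import Optional, Set

    import torch
    from executorch.backends.cadence.aot.memory_planning_algo import (
        MemoryPlanningAlgo,
        MemoryPlanningState,
    )
    from executorch.exir.tensor import TensorSpec
    from torch.export.exported_program import ExportGraphSignature


    class NaiveFirstFit(MemoryPlanningAlgo):
        """Hypothetical: append each spec at the end of the first memory with room."""

        def plan(
            self,
            specs: Set[TensorSpec],
            graph_module: torch.fx.GraphModule,
            graph_signature: ExportGraphSignature,
            extra_padding: int = 0,
            prev_state: Optional[MemoryPlanningState] = None,
        ) -> MemoryPlanningState:
            state = prev_state or MemoryPlanningState(self.memory_config)
            for spec in sorted(specs, key=lambda s: s.allocated_memory, reverse=True):
                for spec.mem_id in range(1, self.get_num_memories()):
                    # bufsizes[mem_id] is kept aligned by place_spec, so it is
                    # already a valid candidate offset for this bank.
                    spec.mem_offset = state.bufsizes[spec.mem_id]
                    if self.is_valid_placement(spec):
                        state.place_spec(spec)
                        break
                if not state.is_placed(spec):
                    raise MemoryError(f"Cannot fit {spec} in any memory hierarchy")
            return state

Because the base class's `__call__` already generates the constraints, filters
out skipped/blocklisted specs, and resolves the relative-location constraints,
a subclass only has to implement `plan()` - which is the point of the refactor.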