diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index cdbfcef36f5..d93ea9315ed 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -184,6 +184,21 @@ python_library( ], ) +python_library( + name = "program_builder", + srcs = [ + "program_builder.py", + ], + typing = True, + deps = [ + ":graph_builder", + "fbcode//caffe2:torch", + "fbcode//executorch/exir:lib", + "fbcode//executorch/exir:pass_base", + "fbcode//executorch/exir/verification:verifier", + ], +) + python_library( name = "fuse_ops", srcs = [ @@ -508,6 +523,7 @@ python_unittest( ":typing_stubs", ":ops_registrations", ":pass_utils", + ":program_builder", "//caffe2:torch", "//executorch/exir:memory", "//executorch/exir/dialects:lib", diff --git a/backends/cadence/aot/memory_constraints.py b/backends/cadence/aot/memory_constraints.py index 62eeb80fd65..8e784cd2779 100644 --- a/backends/cadence/aot/memory_constraints.py +++ b/backends/cadence/aot/memory_constraints.py @@ -4,11 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -# pyre-unsafe +# pyre-strict import logging import math -import typing from collections import defaultdict from dataclasses import dataclass from typing import Callable, cast, DefaultDict, Iterable, Optional, Sequence, TypeAlias @@ -28,19 +27,38 @@ @dataclass(frozen=True) -class SourceInfo: +class RelativePlacementConstraint: """Information of source node and offset used for views.""" source: torch.fx.Node offset: int = 0 +@dataclass(frozen=True) +class AbsolutePlacementConstraint: + """Information on placement constraint memory id and offset.""" + + pinned_memory_id: int + + # If offset is None, then the tensor can be placed anywhere in the memory id. + offset: Optional[int] = None + + class MemConstraints: """ This class contains all the tensor placement constraints that we create during memory planning. - Any tensor whose placement is derived off another tensor via a constraint - is not included in memory planning, and is marked as skipped. + + We have two types of placement constraints: + 1. Relative placement constraints: These are constraints that specify the + relative placement of a tensor with respect to another tensor. For + example, when slice dim is 0, slice output can be placed relative to + their inputs and the op can be replaced with a nop. + 2. Absolute placement constraints: These are constraints that specify the + absolute placement of a tensor either in a specific memory id, or both + a specific memory id and offset. For example, for operators that require + a specific memory id + offset for we can use this constraint to specify + location of inputs/outputs or even temporary buffers. """ def __init__( @@ -62,29 +80,38 @@ def __init__( # A set of tensor spec ids that must be skipped during memory allocation. # The exact mem_id and offset of the skipped tensors will be computed from # the constraints. - self._source_node: dict[int, SourceInfo] = {} + self._relative_placement_constraint: dict[int, RelativePlacementConstraint] = {} # A map from `id(TensorSpec)` to a set of mem_ids that cannot be used for # allocating the tensor. self._mem_id_blocklist: dict[int, set[int]] = {} - def get_source_info(self, node: torch.fx.Node) -> Optional[SourceInfo]: + # A map from `id(TensorSpec)` to a AbsolutePlacementConstraint that specifies mem_id and optionally exact offset. 
+ self._absolute_placement_constraints: dict[int, AbsolutePlacementConstraint] = ( + {} + ) + + def get_relative_placement_source( + self, node: torch.fx.Node + ) -> Optional[RelativePlacementConstraint]: spec = node.meta.get("spec") spec_id = id(spec) - if spec_id not in self._source_node: + if spec_id not in self._relative_placement_constraint: return None - return self._source_node[spec_id] + return self._relative_placement_constraint[spec_id] - def set_source_info( - self, dependent: torch.fx.Node, source_info: SourceInfo + def set_relative_placement_constraint( + self, + dependent: torch.fx.Node, + placement_constraint: RelativePlacementConstraint, ) -> None: dependent_spec = dependent.meta.get("spec") spec_id = id(dependent_spec) - self._source_node[spec_id] = source_info - if self.is_memory_planned(source_info.source): + self._relative_placement_constraint[spec_id] = placement_constraint + if self.is_memory_planned(placement_constraint.source): # Only add dependent nodes if source node needs memory planning. self.unresolved_loc_constraints[ - id(source_info.source.meta.get("spec")) + id(placement_constraint.source.meta.get("spec")) ].add(dependent) def add_mem_id_to_blocklist(self, spec: TensorSpec, mem_id: int) -> None: @@ -111,7 +138,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool: node --> view --> relu (or some other op that can be in-place) """ - if node_source_info := self.get_source_info(node): + if node_source_info := self.get_relative_placement_source(node): node_spec = node.meta.get("spec") node_source_spec = node_source_info.source.meta.get("spec") return ( @@ -121,7 +148,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool: and self.is_alias_of(node_source_info.source, other_node) ) - if self.get_source_info(other_node) is not None: + if self.get_relative_placement_source(other_node) is not None: return self.is_alias_of(other_node, node) return node == other_node @@ -132,14 +159,14 @@ def relative_loc_constraints_exist(self) -> bool: # Return true if the spec is marked as skipped def skipped_spec(self, spec: TensorSpec) -> bool: - return id(spec) in self._source_node + return id(spec) in self._relative_placement_constraint def is_memory_planned( self, node: torch.fx.Node, ) -> bool: """Return true if the node is either (1) a parameter, or (2) a placeholder.""" - if (source_info := self.get_source_info(node)) is not None: + if (source_info := self.get_relative_placement_source(node)) is not None: # If node has relative placement constraints, then check the source. return self.is_memory_planned(source_info.source) # Check if any node is a param. 
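A minimal usage sketch of the renamed constraint API in this file (illustrative only, not part of this diff). It assumes `source` and `dependent` are torch.fx nodes that already carry a TensorSpec in node.meta["spec"], e.g. after SpecPropPass has run; the offset value is arbitrary.

    constraints = MemConstraints()
    # Place `dependent` in the same buffer as `source`, 64 bytes past its start.
    constraints.add_relative_placement_constraint(
        source, dependent, offset=64, update_lifetime=False
    )
    # Pin `source` to memory id 2; offset=None lets the planner pick the offset.
    constraints.add_absolute_placement_constraint(source, pinned_memory_id=2)
    # The dependent spec is now skipped by the planner and resolved off `source`.
    assert constraints.skipped_spec(dependent.meta["spec"])
    assert constraints.get_relative_placement_source(dependent).source is source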
@@ -183,7 +210,7 @@ def resolve_relative_loc_constraints(self, spec: TensorSpec) -> None: assert isinstance(spec, TensorSpec) for dependent_node in self.unresolved_loc_constraints[spec_id]: - source_info = self.get_source_info(dependent_node) + source_info = self.get_relative_placement_source(dependent_node) assert source_info is not None dependent_spec = cast(TensorSpec, dependent_node.meta.get("spec")) dependent_spec.mem_id = spec.mem_id @@ -202,19 +229,21 @@ def update_children_nodes(self, node: torch.fx.Node, update_lifetime: bool) -> N children_nodes = self.unresolved_loc_constraints[id(node.meta.get("spec"))] self.unresolved_loc_constraints.pop(id(node.meta.get("spec"))) - source_info = self.get_source_info(node) + source_info = self.get_relative_placement_source(node) assert source_info is not None for child_node in children_nodes: - child_info = self._source_node.pop(id(child_node.meta.get("spec"))) - self.generate_location_constraint( + child_info = self._relative_placement_constraint.pop( + id(child_node.meta.get("spec")) + ) + self.add_relative_placement_constraint( source_info.source, child_node, offset=source_info.offset + child_info.offset, update_lifetime=update_lifetime, ) - def generate_location_constraint( + def add_relative_placement_constraint( self, source: torch.fx.Node, dependent: torch.fx.Node, @@ -230,29 +259,26 @@ def generate_location_constraint( logging.debug(f"Adding constraint {dependent} = {source} + {offset=}") # Assert that both source and dependent node are tensors. - if (info := self.get_source_info(source)) is not None: - return self.generate_location_constraint( - info.source, dependent, offset + info.offset, update_lifetime - ) + if (info := self.get_relative_placement_source(source)) is not None: + source = info.source + offset += info.offset - if (info := self.get_source_info(dependent)) is not None: + if (info := self.get_relative_placement_source(dependent)) is not None: # Dependent node can only be an alias (same size, offset = 0). assert self.is_alias_of( info.source, dependent ), f"Multiple constraints for allocation of {dependent}. Previous constraint: {info} new constraint: {source=} {offset=}" - return self.generate_location_constraint( - source, info.source, offset, update_lifetime=update_lifetime - ) + dependent = info.source # Add the dependent spec to skip list. Its memory offset will be computed # after the output tensor is allocated space. - source_info = SourceInfo(source=source, offset=offset) - self.set_source_info(dependent, source_info) + source_info = RelativePlacementConstraint(source=source, offset=offset) + self.set_relative_placement_constraint(dependent, source_info) # If update_lifetime is True, take a union of the lifetime of representaitve # and dependent tensors; this will become the new lifetime of source tensor. + dependent_spec = dependent.meta.get("spec") if update_lifetime: - dependent_spec = dependent.meta.get("spec") source_spec = source.meta.get("spec") source.meta.get("spec").lifetime = [ min(source_spec.lifetime[0], dependent_spec.lifetime[0]), @@ -261,6 +287,49 @@ def generate_location_constraint( self.update_children_nodes(dependent, update_lifetime) + abs_constraint = self.get_absolute_placement_constraint(dependent_spec) + if abs_constraint is None: + return + + # Dependent node has an absolute placement constraint. + # If the offset is not 0, then we cannot add a relative placement constraint. 
+ if not self.is_alias_of(dependent, source): + raise RuntimeError( + f"Cannot add relative placement constraint for {dependent} with non-zero offset {offset} when it has an absolute placement constraint {abs_constraint}" + ) + + # Add the absolute placement constraint to the source node. + self._absolute_placement_constraints.pop(id(dependent_spec)) + self.add_absolute_placement_constraint( + source, abs_constraint.pinned_memory_id, abs_constraint.offset + ) + + def add_absolute_placement_constraint( + self, node: torch.fx.Node, pinned_memory_id: int, offset: Optional[int] = None + ) -> None: + """Add a memory pinning constraint for `node` to `mem_id`.""" + logging.debug( + f"Adding memory pinning constraint {node=} = {pinned_memory_id=} at {offset=}" + ) + source_node: torch.fx.Node = node + if (info := self.get_relative_placement_source(node)) is not None: + assert self.is_alias_of(info.source, node) + logging.debug( + f"Setting {node} to {info.source} + {offset=}. Pinned to {pinned_memory_id=}" + ) + source_node = info.source + self._absolute_placement_constraints[id(source_node.meta.get("spec"))] = ( + AbsolutePlacementConstraint( + pinned_memory_id=pinned_memory_id, offset=offset + ) + ) + + def get_absolute_placement_constraint( + self, spec: TensorSpec + ) -> Optional[AbsolutePlacementConstraint]: + """Return true if `node` has an absolute placement constraint.""" + return self._absolute_placement_constraints.get(id(spec), None) + def get_relative_offsets_of_cat_tensors( cat_tensors: Sequence[torch.fx.Node], @@ -342,7 +411,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]: def is_slice_view(self, node: torch.fx.Node) -> bool: """Return if `node` has constraints and is not an alias of another node.""" - if (source_info := self.constraint.get_source_info(node)) is not None: + if ( + source_info := self.constraint.get_relative_placement_source(node) + ) is not None: return not self.constraint.is_alias_of(source_info.source, node) return False @@ -426,7 +497,9 @@ def is_removable_cat_op( return True # Currently the contiguity constraints are generated by cat operator. - def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule): + def compute_cat_contiguity_constraints( + self, graph_module: torch.fx.GraphModule + ) -> None: for node in graph_module.graph.nodes: # Only compute relative constraints if the cat node can be replaced with # its nop version @@ -448,7 +521,9 @@ def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule) # Get the relative offsets for each tensor to be concatenated. relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors) for arg, offset in zip(cat_tensors, relative_offsets): - self.constraint.generate_location_constraint(node, arg, offset=offset) + self.constraint.add_relative_placement_constraint( + node, arg, offset=offset + ) # Update the lifetimes of the args to that of the output tensor, so # that they don't get overwritten @@ -474,7 +549,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]: for node in graph_module.graph.nodes: if node.op != "call_function" or node.target != memory.view: continue - self.constraint.generate_location_constraint(node.args[0], node) + self.constraint.add_relative_placement_constraint(node.args[0], node) @register_cadence_pass(CadencePassAttribute(opt_level=2)) @@ -544,7 +619,7 @@ def removable_slice_or_select_op( # the input and output tensor. 
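As a worked example of the slice constraint generated by the pass below: slicing a float32 tensor of shape (3, 5) along dim 0 starting at row 2 produces an output whose data begins 2 * 5 * 4 = 40 bytes into the input buffer, so the pass emits add_relative_placement_constraint(arg, node, offset=40) and the slice can then be replaced with its nop version. The new test_slice_pinned_output test later in this diff asserts exactly this offset.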
def compute_slice_and_select_loc_constraints( self, graph_module: torch.fx.GraphModule - ): + ) -> None: for node in graph_module.graph.nodes: # Only compute relative constraints if the slice node can be # replaced with its nop version @@ -563,7 +638,7 @@ def compute_slice_and_select_loc_constraints( # And now generate location constraint between input and output # tensors of slice node arg = node.args[0] - self.constraint.generate_location_constraint( + self.constraint.add_relative_placement_constraint( arg, node, offset=offset, @@ -607,12 +682,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult: filtered_passes = [ mcg_pass(self.mem_constraints) for mcg_pass in cast( - list[ - typing.Callable[ - [MemConstraints], - typing.Callable[[torch.fx.GraphModule], Optional[PassResult]], - ] - ], + list[ConstraintsGenPass], # pyre-ignore[6]: Incompatible parameter type. list(filter(pass_filter, constraint_gen_passes)), ) diff --git a/backends/cadence/aot/memory_planning.py b/backends/cadence/aot/memory_planning.py index 8baaaa203d0..0634af6ea61 100644 --- a/backends/cadence/aot/memory_planning.py +++ b/backends/cadence/aot/memory_planning.py @@ -9,7 +9,7 @@ import collections import itertools import logging -from typing import Iterable, List, Optional, Sequence, Set, Tuple +from typing import Iterable, Optional, Sequence import torch from executorch.backends.cadence.aot.memory_constraints import MemConstraints @@ -52,13 +52,18 @@ def collect_specs_from_graph_module( class PositionBasedGreedyWithHierarchy(MemoryPlanningAlgo): """Greedily place tensor in the fastest memory available.""" - def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: + def plan_spec( + self, + spec: TensorSpec, + state: MemoryPlanningState, + placement_constraints: MemConstraints, + ) -> None: """ Greedily place the spec in the first memory that can fit it. """ for spec.mem_id in range(1, self.get_num_memories()): spec.mem_offset = 0 - while self.is_valid_placement(spec) and ( + while self.is_valid_placement(spec, placement_constraints) and ( overlapped := state.get_overlapping_spec(spec) ): # Found an overlapping spec, so we need to adjust the offset = end of the overlapping spec + alignment. @@ -67,20 +72,20 @@ def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: self.get_alignment(spec.mem_id), ) - if self.is_valid_placement(spec): + if self.is_valid_placement(spec, placement_constraints): # Found a valid `spec.mem_offset` which is both valid and has no overlap. 
state.place_spec(spec) break def plan( self, - specs: Set[TensorSpec], + specs: Iterable[TensorSpec], graph_module: torch.fx.GraphModule, graph_signature: ExportGraphSignature, + state: MemoryPlanningState, + placement_constraints: MemConstraints, extra_padding: int = 0, - prev_state: Optional[MemoryPlanningState] = None, - ) -> MemoryPlanningState: - state = prev_state or MemoryPlanningState(self.memory_config) + ) -> None: # Iterate over all the specs in sorted order for spec in sorted( @@ -88,17 +93,20 @@ def plan( key=lambda spec: spec.allocated_memory, reverse=True, ): - self.plan_spec(spec, state) + self.plan_spec(spec, state, placement_constraints) if not state.is_placed(spec): raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") - return state - class GreedyWithHeuristic(MemoryPlanningAlgo): """Greedy tensor placement with the heuristics from arxiv.org/pdf/2001.03288.pdf.""" - def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: + def plan_spec( + self, + spec: TensorSpec, + state: MemoryPlanningState, + placement_constraints: MemConstraints, + ) -> None: """ Greedily place the spec in the first memory that can fit it. """ @@ -128,7 +136,7 @@ def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: ) if spec.mem_offset is None: spec.mem_offset = prev_offset - if not self.is_valid_placement(spec): + if not self.is_valid_placement(spec, placement_constraints): spec.mem_offset = None continue else: @@ -142,25 +150,24 @@ def plan_spec(self, spec: TensorSpec, state: MemoryPlanningState) -> None: def plan( self, - specs: set[TensorSpec], + specs: Iterable[TensorSpec], graph_module: torch.fx.GraphModule, graph_signature: ExportGraphSignature, + state: MemoryPlanningState, + placement_constraints: MemConstraints, extra_padding: int = 0, - prev_state: Optional[MemoryPlanningState] = None, - ) -> MemoryPlanningState: + ) -> None: """Plan memory allocation for the given tensor specs.""" # We do not use the `alignment` parameter and instead use the per-memory alignment # constraints from `memory_config`. - state = prev_state or MemoryPlanningState(self.memory_config) - # Iterate over all the specs in sorted order for spec in sorted( specs, key=lambda spec: spec.allocated_memory, reverse=True, ): - self.plan_spec(spec, state) + self.plan_spec(spec, state, placement_constraints) if not state.is_placed(spec): raise MemoryError(f"Cannot fit {spec} in any memory hierarchy") @@ -168,8 +175,6 @@ def plan( f"greedy by size for offset calculation with hierarchy returns bufsizes: {state.bufsizes}" ) - return state - def find_peak_memory_usages_per_memory( graph_module: torch.fx.GraphModule, @@ -177,7 +182,7 @@ def find_peak_memory_usages_per_memory( alloc_graph_input: bool, alloc_graph_output: bool, mem_constraints: Optional[MemConstraints] = None, -) -> List[int]: +) -> list[int]: """ Given a GraphModule with a memory plan, find the peak memory usages for each memory in the memory hierarchy. @@ -216,7 +221,7 @@ def find_peak_memory_usage( alloc_graph_input: bool, alloc_graph_output: bool, mem_constraints: Optional[MemConstraints] = None, -) -> Tuple[int, int]: +) -> tuple[int, int]: """ Given a GraphModule with a memory plan, find the peak usage over time across all memories in the memory hierarchy. 
The resulting peak memory usage should be: @@ -377,22 +382,18 @@ def get_mem_algos( ) -> list[MemoryPlanningAlgo]: return [ PositionBasedGreedyWithHierarchy( - memory_config, - MemConstraints( - opt_level=opt_level, - alloc_graph_input=alloc_graph_input, - alloc_graph_output=alloc_graph_output, - ), - additional_constraint_gen_passes, + memory_config=memory_config, + opt_level=opt_level, + alloc_graph_input=alloc_graph_input, + alloc_graph_output=alloc_graph_output, + additional_constraint_gen_passes=additional_constraint_gen_passes, ), GreedyWithHeuristic( - memory_config, - MemConstraints( - opt_level=opt_level, - alloc_graph_input=alloc_graph_input, - alloc_graph_output=alloc_graph_output, - ), - additional_constraint_gen_passes, + memory_config=memory_config, + opt_level=opt_level, + alloc_graph_input=alloc_graph_input, + alloc_graph_output=alloc_graph_output, + additional_constraint_gen_passes=additional_constraint_gen_passes, ), ] diff --git a/backends/cadence/aot/memory_planning_algo.py b/backends/cadence/aot/memory_planning_algo.py index ffff2e6aab1..8193b73c9fd 100644 --- a/backends/cadence/aot/memory_planning_algo.py +++ b/backends/cadence/aot/memory_planning_algo.py @@ -5,10 +5,12 @@ import logging import math from abc import ABC, abstractmethod -from typing import Optional, Sequence +from contextlib import contextmanager +from typing import Iterable, Iterator, Optional, Sequence import torch from executorch.backends.cadence.aot.memory_constraints import ( + AbsolutePlacementConstraint, ConstraintsGenPass, GenerateMemConstraints, MemConstraints, @@ -38,6 +40,7 @@ def __init__(self, memory_config: MemoryConfig) -> None: def place_spec(self, spec: TensorSpec) -> None: """Place the spec at the given memory and offset.""" + logging.debug(f"Placing spec {spec}: {spec.mem_id=}, {spec.mem_offset=}") assert self.get_overlapping_spec(spec) is None self.allocated_buffers[spec.mem_id].append(spec) self.bufsizes[spec.mem_id] = max( @@ -58,7 +61,22 @@ def get_overlapping_spec(self, spec: TensorSpec) -> Optional[TensorSpec]: def is_placed(self, spec: TensorSpec) -> bool: """Check if the spec is placed.""" - return spec in self.allocated_buffers[spec.mem_id] + return spec.mem_id is not None and spec in self.allocated_buffers[spec.mem_id] + + def __str__(self) -> str: + allocated_buffers_str = "" + for i, specs in enumerate(self.allocated_buffers): + allocated_buffers_str += ( + f"Memory {i}: " + + ", ".join( + [ + f"<{s.shape=} {s.mem_id=} {s.mem_offset=} {s.allocated_memory=}>" + for s in specs + ] + ) + + "\n" + ) + return f"MemoryPlanningState(bufsizes={self.bufsizes}, allocated_buffers={allocated_buffers_str})" class MemoryPlanningAlgo(ABC): @@ -67,14 +85,19 @@ class MemoryPlanningAlgo(ABC): def __init__( self, memory_config: MemoryConfig, - placement_constraints: MemConstraints, + opt_level: int = 1, + alloc_graph_input: bool = True, + alloc_graph_output: bool = True, additional_constraint_gen_passes: Optional[Sequence[ConstraintsGenPass]] = None, ) -> None: self.memory_config: MemoryConfig = memory_config - self.placement_constraints: MemConstraints = placement_constraints self.additional_constraint_gen_passes: Optional[ Sequence[ConstraintsGenPass] ] = additional_constraint_gen_passes + self.opt_level: int = opt_level + self.alloc_graph_input: bool = alloc_graph_input + self.alloc_graph_output: bool = alloc_graph_output + self.memory_id_is_valid: list[bool] = [True] * self.get_num_memories() def get_num_memories(self) -> int: """Get num memories indexed from 1..N, compatible with EXIR's 
spec.mem_id.""" @@ -89,70 +112,230 @@ def get_alignment(self, exir_id: int) -> int: assert self.memory_config.memory_alignments is not None return self.memory_config.memory_alignments[exir_id - 1] - def populate_constraints(self, graph_module: torch.fx.GraphModule) -> None: + def populate_constraints( + self, graph_module: torch.fx.GraphModule + ) -> tuple[MemoryPlanningState, MemConstraints]: """Populate the constraints for the memory planning algorithm.""" + state = MemoryPlanningState(self.memory_config) + placement_constraints = MemConstraints( + self.opt_level, self.alloc_graph_input, self.alloc_graph_output + ) GenerateMemConstraints( - mem_constraints=self.placement_constraints, + mem_constraints=placement_constraints, additional_constraint_gen_passes=self.additional_constraint_gen_passes, )(graph_module) + return state, placement_constraints - def is_valid_placement(self, spec: TensorSpec) -> bool: + def is_valid_placement( + self, spec: TensorSpec, placement_constraints: MemConstraints + ) -> bool: """Returns true if the spec can be placed at the given memory id.""" end_of_allocation = get_aligned_offset( spec.mem_offset + spec.allocated_memory, self.get_alignment(spec.mem_id), ) - return end_of_allocation <= self.get_size( - spec.mem_id - ) and not self.placement_constraints.is_mem_id_in_blocklist(spec, spec.mem_id) + return ( + self.memory_id_is_valid[spec.mem_id] + and end_of_allocation <= self.get_size(spec.mem_id) + and not placement_constraints.is_mem_id_in_blocklist(spec, spec.mem_id) + ) + + @contextmanager + def block_memories_except(self, memory_id: int) -> Iterator[None]: + """Block all memories except the given memory_id.""" + try: + prev_valid = self.memory_id_is_valid.copy() + self.memory_id_is_valid = [False] * self.get_num_memories() + self.memory_id_is_valid[memory_id] = prev_valid[memory_id] + yield + finally: + self.memory_id_is_valid = prev_valid @abstractmethod def plan( self, - specs: set[TensorSpec], + specs: Iterable[TensorSpec], graph_module: torch.fx.GraphModule, graph_signature: ExportGraphSignature, + state: MemoryPlanningState, + placement_constraints: MemConstraints, extra_padding: int = 0, - prev_state: Optional[MemoryPlanningState] = None, - ) -> MemoryPlanningState: + ) -> None: """Plan memory allocation for the given tensor specs.""" pass - def __call__( + def _place_pinned_specs( self, - alignment: int, - specs: set[TensorSpec], + spec_with_abs_constraint: dict[ + TensorSpec, Optional[AbsolutePlacementConstraint] + ], + state: MemoryPlanningState, + placement_constraints: MemConstraints, + ) -> None: + """Place pinned specs with fixed mem_id AND offset.""" + # All specs that have absolute constraints that pin spec to mem id and offset. 
+ pinned_specs = { + spec: c + for spec, c in spec_with_abs_constraint.items() + if c is not None and c.offset is not None + } + for spec, constraint in pinned_specs.items(): + spec.mem_id = constraint.pinned_memory_id + spec.mem_offset = constraint.offset + state.place_spec(spec) + placement_constraints.resolve_relative_loc_constraints(spec) + + def _place_memory_id_pinned_specs( + self, + spec_with_abs_constraint: dict[ + TensorSpec, Optional[AbsolutePlacementConstraint] + ], graph_module: torch.fx.GraphModule, graph_signature: ExportGraphSignature, + state: MemoryPlanningState, + placement_constraints: MemConstraints, extra_padding: int = 0, - ) -> list[int]: + ) -> None: """Callable interface for ET memory planning.""" - self.populate_constraints(graph_module) - # First plan the memory allocation for specs without relative constraints. - specs_without_relative_constraints = set( - filter( - lambda spec: not self.placement_constraints.skipped_spec(spec), - specs, - ) - ) + for mem_id in range(1, self.get_num_memories()): + mem_id_pinned_specs: dict[TensorSpec, AbsolutePlacementConstraint] = { + spec: c + for spec, c in spec_with_abs_constraint.items() + if c is not None and c.pinned_memory_id == mem_id and c.offset is None + } + logging.error(f"Placing specs {mem_id_pinned_specs} for {mem_id=}") + + with self.block_memories_except(mem_id): + self.plan( + mem_id_pinned_specs, + graph_module, + graph_signature, + state, + placement_constraints, + extra_padding, + ) - # Call memory planning to get bufsizes. - state = self.plan( + for spec, constraint in spec_with_abs_constraint.items(): + if constraint is None: + continue + + logging.error(f"Placing spec {spec} with {constraint}") + + if not state.is_placed(spec): + raise MemoryError( + f"Cannot fit {spec} in memory {constraint.pinned_memory_id}" + ) + if ( + # Memory id is pinned, so we can't change it. + spec.mem_id != constraint.pinned_memory_id + or ( + # Memory offset is pinned, so we can't change it. + constraint.offset is not None + and spec.mem_offset != constraint.offset + ) + ): + raise MemoryError( + f"Incorrect memory planning for {spec} with {spec.mem_id=} and {spec.mem_offset=} for constraint {constraint}" + ) + # Resolve the relative constraints for the spec. + placement_constraints.resolve_relative_loc_constraints(spec) + + def _place_specs_with_no_absolute_constraints( + self, + spec_with_abs_constraint: dict[ + TensorSpec, Optional[AbsolutePlacementConstraint] + ], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + state: MemoryPlanningState, + placement_constraints: MemConstraints, + extra_padding: int = 0, + ) -> None: + # Plan the memory allocation for specs without absolute or relative constraints. + specs_without_relative_constraints = { + spec: c + for spec, c in spec_with_abs_constraint.items() + if c is None and not placement_constraints.skipped_spec(spec) + } + self.plan( specs_without_relative_constraints, graph_module, graph_signature, + state, + placement_constraints, extra_padding, ) for spec in specs_without_relative_constraints: # And now honor the various memory location constraints (i.e., infer the memory # location of tensors in skip_specs from the constraints) for this spec. 
- self.placement_constraints.resolve_relative_loc_constraints(spec) + placement_constraints.resolve_relative_loc_constraints(spec) + + def plan_with_constraints( + self, + specs: Iterable[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + state: MemoryPlanningState, + placement_constraints: MemConstraints, + extra_padding: int = 0, + ) -> None: + """Callable interface for ET memory planning.""" + + spec_and_abs_constraints = { + spec: placement_constraints.get_absolute_placement_constraint(spec) + for spec in specs + } + + # Place specs that have both mem_id and offset constraints. + self._place_pinned_specs(spec_and_abs_constraints, state, placement_constraints) + + # Place specs that have both mem_id constraints. + self._place_memory_id_pinned_specs( + spec_and_abs_constraints, + graph_module, + graph_signature, + state, + placement_constraints, + extra_padding, + ) + + # Place specs that have no constraints. + self._place_specs_with_no_absolute_constraints( + spec_and_abs_constraints, + graph_module, + graph_signature, + state, + placement_constraints, + extra_padding, + ) + + def __call__( + self, + alignment: int, + specs: Iterable[TensorSpec], + graph_module: torch.fx.GraphModule, + graph_signature: ExportGraphSignature, + extra_padding: int = 0, + ) -> list[int]: + """Callable interface for ET memory planning.""" + + # Initialize state and constraints. + state, placement_constraints = self.populate_constraints(graph_module) + + self.plan_with_constraints( + specs, + graph_module, + graph_signature, + state, + placement_constraints, + extra_padding, + ) # At the end, all the keys in relative_loc_constraints should have been visited # and emptied. - assert not self.placement_constraints.relative_loc_constraints_exist() + assert not placement_constraints.relative_loc_constraints_exist() logging.debug(f"Memory planning algo found bufsizes: {state.bufsizes}") return state.bufsizes diff --git a/backends/cadence/aot/pass_utils.py b/backends/cadence/aot/pass_utils.py index 170c81f571e..935a448a8ae 100644 --- a/backends/cadence/aot/pass_utils.py +++ b/backends/cadence/aot/pass_utils.py @@ -7,14 +7,14 @@ # pyre-strict from dataclasses import dataclass -from typing import Callable, List, Optional, Set, Union +from typing import Callable, List, Optional, Set, Type, Union import torch from executorch.backends.cadence.aot.utils import get_edge_overload_packet from executorch.exir.dialects.edge._ops import EdgeOpOverload, EdgeOpOverloadPacket +from executorch.exir.pass_base import PassBase -from executorch.exir.pass_base import ExportPass from torch._ops import OpOverloadPacket @@ -25,40 +25,40 @@ def allow_lifetime_and_storage_overlap(opt_level: int) -> bool: # A dataclass that stores the attributes of an ExportPass. -@dataclass +@dataclass(frozen=True) class CadencePassAttribute: opt_level: Optional[int] = None debug_pass: bool = False # A dictionary that maps an ExportPass to its attributes. -ALL_CADENCE_PASSES: dict[ExportPass, CadencePassAttribute] = {} +ALL_CADENCE_PASSES: dict[Type[PassBase], CadencePassAttribute] = {} -def get_cadence_pass_attribute(p: ExportPass) -> Optional[CadencePassAttribute]: +def get_cadence_pass_attribute(p: Type[PassBase]) -> Optional[CadencePassAttribute]: return ALL_CADENCE_PASSES.get(p, None) # A decorator that registers a pass. 
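Since the registry now keys on pass classes (Type[PassBase]) rather than instances, a pass is registered and looked up by its class. A short sketch of the intended usage with the decorator updated just below (illustrative, not part of this diff):

    from executorch.backends.cadence.aot.pass_utils import (
        CadencePassAttribute,
        get_cadence_pass_attribute,
        register_cadence_pass,
    )
    from executorch.exir.pass_base import ExportPass

    @register_cadence_pass(CadencePassAttribute(opt_level=1))
    class MyNopPass(ExportPass):
        pass

    # The class itself, not an instance, is the registry key.
    assert get_cadence_pass_attribute(MyNopPass) == CadencePassAttribute(opt_level=1)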
def register_cadence_pass( pass_attribute: CadencePassAttribute, -) -> Callable[[ExportPass], ExportPass]: - def wrapper(cls: ExportPass) -> ExportPass: +) -> Callable[[Type[PassBase]], Type[PassBase]]: + def wrapper(cls: Type[PassBase]) -> Type[PassBase]: ALL_CADENCE_PASSES[cls] = pass_attribute return cls return wrapper -def get_all_available_cadence_passes() -> Set[ExportPass]: +def get_all_available_cadence_passes() -> Set[Type[PassBase]]: return set(ALL_CADENCE_PASSES.keys()) # Create a new filter to filter out relevant passes from all passes. def create_cadence_pass_filter( opt_level: int, debug: bool = False -) -> Callable[[ExportPass], bool]: - def _filter(p: ExportPass) -> bool: +) -> Callable[[Type[PassBase]], bool]: + def _filter(p: Type[PassBase]) -> bool: pass_attribute = get_cadence_pass_attribute(p) return ( pass_attribute is not None diff --git a/backends/cadence/aot/passes.py b/backends/cadence/aot/passes.py index 8355f7ef432..6b1921400b0 100644 --- a/backends/cadence/aot/passes.py +++ b/backends/cadence/aot/passes.py @@ -6,7 +6,7 @@ # pyre-strict -from typing import Any, List, Optional +from typing import Any, List, Type import torch import torch.fx @@ -71,7 +71,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: Argument = Any # pyre-ignore -def get_passes_in_default_order() -> List[ExportPass]: +def get_passes_in_default_order() -> list[Type[ExportPass]]: passes = [ InitializePipeline, RemoveRedundantOps.passes, @@ -91,12 +91,10 @@ def get_passes_in_default_order() -> List[ExportPass]: def get_cadence_passes( opt_level: int, -) -> List[Optional[PassResult]]: +) -> list[ExportPass]: passes = get_passes_in_default_order() pass_filter = create_cadence_pass_filter(opt_level) filtered_passes = [ - # pyre-ignore[20]: Expect argument graph_module - filtered_pass() - for filtered_pass in list(filter(pass_filter, passes)) + filtered_pass() for filtered_pass in list(filter(pass_filter, passes)) ] return filtered_passes diff --git a/backends/cadence/aot/program_builder.py b/backends/cadence/aot/program_builder.py new file mode 100644 index 00000000000..d5d9d2b0c29 --- /dev/null +++ b/backends/cadence/aot/program_builder.py @@ -0,0 +1,85 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+ +# pyre-strict + +from typing import Optional + +from executorch.backends.cadence.aot.graph_builder import GraphBuilder +from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir.pass_base import ProxyValue +from executorch.exir.verification.verifier import EXIREdgeDialectVerifier + +from torch import Tensor +from torch.export import ExportedProgram +from torch.export.graph_signature import ( + ExportGraphSignature, + InputKind, + InputSpec, + OutputKind, + OutputSpec, + TensorArgument, +) + + +class ProgramBuilder(GraphBuilder): + """Utility class to build a program from a graph module.""" + + def __init__(self) -> None: + self.input_specs: list[InputSpec] = [] + self.output_specs: list[OutputSpec] = [] + self.constants: dict[str, Tensor] = {} + self.state_dict: dict[str, Tensor] = {} + super().__init__() + + def insert_input_spec( + self, target: str, input_kind: InputKind, value: Tensor + ) -> None: + if input_kind == InputKind.USER_INPUT: + self.input_specs.append( + InputSpec(input_kind, TensorArgument(target), target=target) + ) + + def placeholder( + self, + target: str, + fake_tensor: Tensor, + input_kind: InputKind = InputKind.USER_INPUT, + ) -> ProxyValue: + placeholder = super().placeholder(target, fake_tensor) + self.insert_input_spec(target, input_kind, fake_tensor) + return placeholder + + def output( + self, results: list[ProxyValue], output_kinds: Optional[list[OutputKind]] = None + ) -> ProxyValue: + if output_kinds is None: + output_kinds = [OutputKind.USER_OUTPUT] * len(results) + for result, out_kind in zip(results, output_kinds): + self.output_specs.append( + OutputSpec(out_kind, TensorArgument(result.node.name), target=None) + ) + return super().output(results) + + def get_program(self) -> ExportedProgram: + gm = self.get_graph_module() + return ExportedProgram( + root=gm, + graph=gm.graph, + graph_signature=ExportGraphSignature( + input_specs=self.input_specs, output_specs=self.output_specs + ), + # pyre-ignore[6]: Incompatible parameter type. 
+ constants=self.constants, + state_dict=self.state_dict, + range_constraints={}, + module_call_graph=[], + verifiers=[ + EXIREdgeDialectVerifier( + edge_compile_config=EdgeCompileConfig(_check_ir_validity=False), + class_only=True, + ) + ], + ) + + def get_edge_program(self) -> EdgeProgramManager: + return EdgeProgramManager(self.get_program()) diff --git a/backends/cadence/aot/quantizer/patterns.py b/backends/cadence/aot/quantizer/patterns.py index cd6a7287793..88c16139733 100644 --- a/backends/cadence/aot/quantizer/patterns.py +++ b/backends/cadence/aot/quantizer/patterns.py @@ -109,7 +109,7 @@ def get_anchors( ) def replacement_op(self) -> OpOverload: - return torch.ops.cadence.quantized_linear + return torch.ops.cadence.quantized_linear.default class AddPattern(QuantizationPattern): diff --git a/backends/cadence/aot/tests/test_memory_passes.py b/backends/cadence/aot/tests/test_memory_passes.py index df44ded8516..fe2eddc05ed 100644 --- a/backends/cadence/aot/tests/test_memory_passes.py +++ b/backends/cadence/aot/tests/test_memory_passes.py @@ -6,6 +6,7 @@ # pyre-strict +import logging import math import unittest from typing import cast, List, Optional, Sequence @@ -14,23 +15,36 @@ import torch from executorch.backends.cadence.aot import compiler from executorch.backends.cadence.aot.graph_builder import GraphBuilder -from executorch.backends.cadence.aot.memory_constraints import ConstraintsGenPass +from executorch.backends.cadence.aot.memory_constraints import ( + ConstraintsGenPass, + MemConstraints, +) from executorch.backends.cadence.aot.memory_planning import ( CadenceMemoryPlanning, find_peak_memory_usage, + PositionBasedGreedyWithHierarchy, +) +from executorch.backends.cadence.aot.memory_planning_algo import ( + MemoryPlanningAlgo, + MemoryPlanningState, ) from executorch.backends.cadence.aot.pass_utils import ( CadencePassAttribute, count_node, register_cadence_pass, ) +from executorch.backends.cadence.aot.program_builder import ProgramBuilder from executorch.backends.cadence.aot.typing_stubs import expand from executorch.backends.cadence.aot.utils import ( get_default_memory_config, MemoryConfig, ) +from executorch.exir import EdgeProgramManager, ExportedProgram from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.memory_planning import collect_specs_from_nodes +from executorch.exir.memory_planning import ( + collect_specs_from_nodes, + update_all_tensors_lifetime, +) from executorch.exir.pass_base import PassBase, PassResult from executorch.exir.passes.spec_prop_pass import SpecPropPass from executorch.exir.tests.models import MultiLayerPerceptron @@ -1019,7 +1033,6 @@ class DummyMemIdBlockConstraintGen(PassBase): """Blocks placement based on op type. 
add: blocks 2, 3 mul: blocks 1, 3 - """ def __init__(self, memory_constraints: MemoryConfig): @@ -1030,12 +1043,14 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: op="call_function", target=torch.ops.aten.add.Scalar ): spec = node.meta["spec"] + logging.error(f"add node: {node} {id(spec)=}") for mem_id in add_scalar_block_mem_ids: self.memory_constraints.add_mem_id_to_blocklist(spec, mem_id) for node in graph_module.graph.find_nodes( op="call_function", target=torch.ops.aten.mul.Scalar ): spec = node.meta["spec"] + logging.error(f"mul node: {node} {id(spec)=}") for mem_id in mul_scalar_block_mem_ids: self.memory_constraints.add_mem_id_to_blocklist(spec, mem_id) @@ -1057,3 +1072,183 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult: spec = node.meta["spec"] self.assertIsNotNone(spec.mem_id) self.assertNotIn(spec.mem_id, mul_scalar_block_mem_ids) + + +class TestConstraintsBase(unittest.TestCase): + def get_view_then_add_graph(self) -> EdgeProgramManager: + builder = ProgramBuilder() + x = builder.placeholder("x", torch.ones(3, 5, dtype=torch.float32)) + y = builder.placeholder("y", torch.ones(2, 15, dtype=torch.float32)) + x_reshape = builder.call_operator( + op=exir_ops.edge.aten.view_copy.default, + args=(x, [15]), + ) + add_x_y = builder.call_operator( + op=exir_ops.edge.aten.add.Tensor, + args=(x_reshape, y), + ) + builder.output([add_x_y]) + edge_program = builder.get_edge_program() + edge_program = edge_program.transform([SpecPropPass()]) + return edge_program + + @staticmethod + def get_aligned(num: int) -> int: + return ((num + 16 - 1) // 16) * 16 + + def _run_mem_planning( + self, + program: ExportedProgram, + memory_planning: MemoryPlanningAlgo, + state: MemoryPlanningState, + placement_constraints: MemConstraints, + ) -> None: + gm = program.graph_module + graph_signature = program.graph_signature + # Difficult to just filter the list of specs returned by this due to + # how we flag trainable weights. + _ = update_all_tensors_lifetime(gm, graph_signature) + + # Filter specs based on alloc_graph_input and alloc_graph_output + specs = set( + collect_specs_from_nodes( + gm.graph.nodes, + graph_signature, + do_assertion=False, + ignore_graph_input=False, + ignore_graph_output=False, + ignore_mutable_buffers=False, + ) + ) + memory_planning.plan_with_constraints( + specs, + gm, + # pyre-ignore[6] + None, + state, + placement_constraints, + ) + + +class TestAbsolutePlacementConstraint(TestConstraintsBase): + + def test_manually_planned_specs(self) -> None: + edge_program = self.get_view_then_add_graph() + x, y, x_view, add, _ = edge_program.exported_program().graph_module.graph.nodes + + # Create constraints for all nodes. 
+ memory_config = MemoryConfig([1000, 10000]) + mem_planning = PositionBasedGreedyWithHierarchy(memory_config) + state = MemoryPlanningState(memory_config=memory_config) + placement_constraints = MemConstraints() + x_offset = 8000 + y_offset = 7000 + x_view_offset = 20 + add_offset = 400 + placement_constraints.add_absolute_placement_constraint(x, 2, x_offset) + placement_constraints.add_absolute_placement_constraint(y, 2, y_offset) + placement_constraints.add_absolute_placement_constraint( + x_view, 1, x_view_offset + ) + placement_constraints.add_absolute_placement_constraint(add, 1, add_offset) + + self._run_mem_planning( + edge_program.exported_program(), mem_planning, state, placement_constraints + ) + self.assertListEqual( + state.bufsizes, + [ + 0, + self.get_aligned(add_offset + 2 * 3 * 5 * 4), + self.get_aligned(x_offset + 3 * 5 * 4), + ], + msg=f"{state}", + ) + + def test_pinned_memory_id(self) -> None: + edge_program = self.get_view_then_add_graph() + x, y, x_view, add, _ = edge_program.exported_program().graph_module.graph.nodes + # Create both mem_id+mem_offset and mem_offset constraints for all nodes. + memory_config = MemoryConfig([1000, 10000]) + mem_planning = PositionBasedGreedyWithHierarchy(memory_config) + state = MemoryPlanningState(memory_config=memory_config) + placement_constraints = MemConstraints() + x_offset = None + y_offset = 8000 + x_view_offset = 800 + add_offset = None + placement_constraints.add_absolute_placement_constraint(x, 2, x_offset) + placement_constraints.add_absolute_placement_constraint(y, 2, y_offset) + placement_constraints.add_absolute_placement_constraint( + x_view, 1, x_view_offset + ) + placement_constraints.add_absolute_placement_constraint(add, 1, add_offset) + + self._run_mem_planning( + edge_program.exported_program(), mem_planning, state, placement_constraints + ) + self.assertListEqual( + state.bufsizes, + [ + 0, + self.get_aligned(x_view_offset + 3 * 5 * 4), + self.get_aligned(y_offset + 2 * 3 * 5 * 4), + ], + msg=f"{state}", + ) + + +class TestMixedPlacementConstraints(TestConstraintsBase): + def get_slice_graph(self) -> EdgeProgramManager: + builder = ProgramBuilder() + x = builder.placeholder("x", torch.ones(3, 5, dtype=torch.float32)) + x_slice = builder.call_operator( + op=exir_ops.edge.aten.slice_copy.Tensor, + args=(x, 0, 2), + ) + builder.output([x_slice]) + edge_program = builder.get_edge_program() + edge_program = edge_program.transform([SpecPropPass()]) + return edge_program + + def test_slice_pinned_output(self) -> None: + edge_program = self.get_slice_graph() + x, x_slice, _ = edge_program.exported_program().graph_module.graph.nodes + # Create both mem_id+mem_offset and mem_offset constraints for all nodes. + memory_config = MemoryConfig([1000]) + mem_planning = PositionBasedGreedyWithHierarchy(memory_config) + state = MemoryPlanningState(memory_config=memory_config) + placement_constraints = MemConstraints() + x_offset = 20 + placement_constraints.add_absolute_placement_constraint(x, 1, x_offset) + placement_constraints.add_relative_placement_constraint( + x, x_slice, 40, update_lifetime=False + ) + self._run_mem_planning( + edge_program.exported_program(), mem_planning, state, placement_constraints + ) + + # Check that x is placed correctly at `x_offset` and x_slice is placed at `x_offset + 40`. 
+ self.assertEqual(x.meta["spec"].mem_id, 1) + self.assertEqual(x.meta["spec"].mem_offset, x_offset) + self.assertEqual(x_slice.meta["spec"].mem_id, 1) + self.assertEqual(x_slice.meta["spec"].mem_offset, x_offset + 2 * 5 * 4) + + def test_slice_pinned_input_fail(self) -> None: + edge_program = self.get_slice_graph() + x, x_slice, _ = edge_program.exported_program().graph_module.graph.nodes + # Create both mem_id+mem_offset and mem_offset constraints for all nodes. + placement_constraints = MemConstraints() + x_slice_offset = 20 + x_offset = 40 + pin_memory_id = 1 + placement_constraints.add_absolute_placement_constraint( + x_slice, pin_memory_id, x_slice_offset + ) + with self.assertRaisesRegex( + RuntimeError, + f"Cannot add relative placement constraint for aten_slice_copy_tensor with non-zero offset {x_offset} when it has an absolute placement constraint AbsolutePlacementConstraint\\(pinned_memory_id={pin_memory_id}, offset={x_slice_offset}\\)", + ): + placement_constraints.add_relative_placement_constraint( + x, x_slice, x_offset, update_lifetime=False + ) diff --git a/backends/cadence/aot/tests/test_pass_filter.py b/backends/cadence/aot/tests/test_pass_filter.py index 9bfd71556bd..ad89ff06f4f 100644 --- a/backends/cadence/aot/tests/test_pass_filter.py +++ b/backends/cadence/aot/tests/test_pass_filter.py @@ -10,7 +10,7 @@ import unittest from copy import deepcopy -from typing import Callable, Dict +from typing import Callable, Type from executorch.backends.cadence.aot import pass_utils from executorch.backends.cadence.aot.pass_utils import ( @@ -20,7 +20,7 @@ register_cadence_pass, ) -from executorch.exir.pass_base import ExportPass +from executorch.exir.pass_base import ExportPass, PassBase class TestBase(unittest.TestCase): @@ -36,9 +36,9 @@ def tearDown(self) -> None: pass_utils.ALL_CADENCE_PASSES = self._all_passes_original def get_filtered_passes( - self, filter_: Callable[[ExportPass], bool] - ) -> Dict[ExportPass, CadencePassAttribute]: - return {cls: attr for cls, attr in ALL_CADENCE_PASSES.items() if filter_(cls)} + self, filter_: Callable[[Type[PassBase]], bool] + ) -> dict[Type[PassBase], CadencePassAttribute]: + return {c: attr for c, attr in ALL_CADENCE_PASSES.items() if filter_(c)} # Test pass registration
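For reference, an end-to-end sketch of how the new pieces fit together, closely following the new tests above (illustrative, not part of this diff; the shapes, memory sizes, and offsets are arbitrary): build a small edge program with ProgramBuilder, propagate specs, pin a tensor with an absolute placement constraint, and hand everything to a planner via plan_with_constraints.

    import torch
    from executorch.backends.cadence.aot.memory_constraints import MemConstraints
    from executorch.backends.cadence.aot.memory_planning import (
        PositionBasedGreedyWithHierarchy,
    )
    from executorch.backends.cadence.aot.memory_planning_algo import MemoryPlanningState
    from executorch.backends.cadence.aot.program_builder import ProgramBuilder
    from executorch.backends.cadence.aot.utils import MemoryConfig
    from executorch.exir.dialects._ops import ops as exir_ops
    from executorch.exir.memory_planning import (
        collect_specs_from_nodes,
        update_all_tensors_lifetime,
    )
    from executorch.exir.passes.spec_prop_pass import SpecPropPass

    # Build a tiny program: out = x + x.
    builder = ProgramBuilder()
    x = builder.placeholder("x", torch.ones(4, 4, dtype=torch.float32))
    out = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(x, x))
    builder.output([out])
    edge_program = builder.get_edge_program().transform([SpecPropPass()])

    program = edge_program.exported_program()
    gm = program.graph_module
    x_node = next(iter(gm.graph.nodes))  # the "x" placeholder

    # Set up the planner, planning state, and constraints.
    memory_config = MemoryConfig([1000, 10000])
    planner = PositionBasedGreedyWithHierarchy(memory_config)
    state = MemoryPlanningState(memory_config=memory_config)
    constraints = MemConstraints()
    # Pin the input to memory id 1 at byte offset 32.
    constraints.add_absolute_placement_constraint(x_node, 1, 32)

    # Collect specs the same way the new tests do, then plan with constraints.
    _ = update_all_tensors_lifetime(gm, program.graph_signature)
    specs = set(
        collect_specs_from_nodes(
            gm.graph.nodes,
            program.graph_signature,
            do_assertion=False,
            ignore_graph_input=False,
            ignore_graph_output=False,
            ignore_mutable_buffers=False,
        )
    )
    planner.plan_with_constraints(
        specs, gm, program.graph_signature, state, constraints
    )
    assert x_node.meta["spec"].mem_id == 1
    assert x_node.meta["spec"].mem_offset == 32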