Add support for absolute mem_id/offset placement constraints. #12266

Open
wants to merge 1 commit into base: main
16 changes: 16 additions & 0 deletions backends/cadence/aot/TARGETS
@@ -184,6 +184,21 @@ python_library(
],
)

python_library(
name = "program_builder",
srcs = [
"program_builder.py",
],
typing = True,
deps = [
":graph_builder",
"fbcode//caffe2:torch",
"fbcode//executorch/exir:lib",
"fbcode//executorch/exir:pass_base",
"fbcode//executorch/exir/verification:verifier",
],
)

python_library(
name = "fuse_ops",
srcs = [
@@ -508,6 +523,7 @@ python_unittest(
":typing_stubs",
":ops_registrations",
":pass_utils",
":program_builder",
"//caffe2:torch",
"//executorch/exir:memory",
"//executorch/exir/dialects:lib",
162 changes: 116 additions & 46 deletions backends/cadence/aot/memory_constraints.py
@@ -4,11 +4,10 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe
# pyre-strict

import logging
import math
import typing
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, cast, DefaultDict, Iterable, Optional, Sequence, TypeAlias
@@ -28,19 +27,38 @@


@dataclass(frozen=True)
class SourceInfo:
class RelativePlacementConstraint:
"""Information of source node and offset used for views."""

source: torch.fx.Node
offset: int = 0


@dataclass(frozen=True)
class AbsolutePlacementConstraint:
"""Information on placement constraint memory id and offset."""

pinned_memory_id: int

# If offset is None, then the tensor can be placed anywhere in the memory id.
offset: Optional[int] = None


class MemConstraints:
"""
This class contains all the tensor placement constraints that we create
during memory planning.
Any tensor whose placement is derived off another tensor via a constraint
is not included in memory planning, and is marked as skipped.

We have two types of placement constraints:
1. Relative placement constraints: these specify the placement of a tensor
relative to another tensor. For example, when the slice dim is 0, the slice
output can be placed relative to its input and the op can be replaced with
a nop.
2. Absolute placement constraints: these specify the absolute placement of
a tensor, either just a specific memory id, or a specific memory id and
offset. For example, for operators that require a specific memory id and
offset, we can use this constraint to specify the location of inputs,
outputs, or even temporary buffers.
"""

def __init__(
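
The two dataclasses above are the whole vocabulary of the constraint system. A minimal, runnable illustration of what each record expresses (the fx node here is a stand-in; in the real passes, nodes come from the exported program and carry a `TensorSpec` in `node.meta["spec"]`):

```python
import torch

# Stand-in fx node purely for illustration; real constraints are created on
# nodes of the exported graph during memory planning.
graph = torch.fx.Graph()
source = graph.placeholder("x")

# Relative: "place me 64 bytes past wherever `source` lands".
relative = RelativePlacementConstraint(source=source, offset=64)

# Absolute, memory id only: "put me anywhere in memory id 2".
pinned_anywhere = AbsolutePlacementConstraint(pinned_memory_id=2)

# Absolute, fully pinned: "put me in memory id 2 at byte offset 1024".
pinned_exact = AbsolutePlacementConstraint(pinned_memory_id=2, offset=1024)

print(relative, pinned_anywhere, pinned_exact)
```
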
@@ -62,29 +80,38 @@ def __init__(
# A set of tensor spec ids that must be skipped during memory allocation.
# The exact mem_id and offset of the skipped tensors will be computed from
# the constraints.
self._source_node: dict[int, SourceInfo] = {}
self._relative_placement_constraint: dict[int, RelativePlacementConstraint] = {}

# A map from `id(TensorSpec)` to a set of mem_ids that cannot be used for
# allocating the tensor.
self._mem_id_blocklist: dict[int, set[int]] = {}

def get_source_info(self, node: torch.fx.Node) -> Optional[SourceInfo]:
# A map from `id(TensorSpec)` to an AbsolutePlacementConstraint that specifies mem_id and optionally an exact offset.
self._absolute_placement_constraints: dict[int, AbsolutePlacementConstraint] = (
{}
)

def get_relative_placement_source(
self, node: torch.fx.Node
) -> Optional[RelativePlacementConstraint]:
spec = node.meta.get("spec")
spec_id = id(spec)
if spec_id not in self._source_node:
if spec_id not in self._relative_placement_constraint:
return None
return self._source_node[spec_id]
return self._relative_placement_constraint[spec_id]

def set_source_info(
self, dependent: torch.fx.Node, source_info: SourceInfo
def set_relative_placement_constraint(
self,
dependent: torch.fx.Node,
placement_constraint: RelativePlacementConstraint,
) -> None:
dependent_spec = dependent.meta.get("spec")
spec_id = id(dependent_spec)
self._source_node[spec_id] = source_info
if self.is_memory_planned(source_info.source):
self._relative_placement_constraint[spec_id] = placement_constraint
if self.is_memory_planned(placement_constraint.source):
# Only add dependent nodes if source node needs memory planning.
self.unresolved_loc_constraints[
id(source_info.source.meta.get("spec"))
id(placement_constraint.source.meta.get("spec"))
].add(dependent)

def add_mem_id_to_blocklist(self, spec: TensorSpec, mem_id: int) -> None:
@@ -111,7 +138,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool:
node --> view
--> relu (or some other op that can be in-place)
"""
if node_source_info := self.get_source_info(node):
if node_source_info := self.get_relative_placement_source(node):
node_spec = node.meta.get("spec")
node_source_spec = node_source_info.source.meta.get("spec")
return (
@@ -121,7 +148,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool:
and self.is_alias_of(node_source_info.source, other_node)
)

if self.get_source_info(other_node) is not None:
if self.get_relative_placement_source(other_node) is not None:
return self.is_alias_of(other_node, node)

return node == other_node
@@ -132,14 +159,14 @@ def relative_loc_constraints_exist(self) -> bool:

# Return true if the spec is marked as skipped
def skipped_spec(self, spec: TensorSpec) -> bool:
return id(spec) in self._source_node
return id(spec) in self._relative_placement_constraint

def is_memory_planned(
self,
node: torch.fx.Node,
) -> bool:
"""Return true if the node is either (1) a parameter, or (2) a placeholder."""
if (source_info := self.get_source_info(node)) is not None:
if (source_info := self.get_relative_placement_source(node)) is not None:
# If node has relative placement constraints, then check the source.
return self.is_memory_planned(source_info.source)
# Check if any node is a param.
@@ -183,7 +210,7 @@ def resolve_relative_loc_constraints(self, spec: TensorSpec) -> None:

assert isinstance(spec, TensorSpec)
for dependent_node in self.unresolved_loc_constraints[spec_id]:
source_info = self.get_source_info(dependent_node)
source_info = self.get_relative_placement_source(dependent_node)
assert source_info is not None
dependent_spec = cast(TensorSpec, dependent_node.meta.get("spec"))
dependent_spec.mem_id = spec.mem_id
@@ -202,19 +229,21 @@ def update_children_nodes(self, node: torch.fx.Node, update_lifetime: bool) -> None
children_nodes = self.unresolved_loc_constraints[id(node.meta.get("spec"))]
self.unresolved_loc_constraints.pop(id(node.meta.get("spec")))

source_info = self.get_source_info(node)
source_info = self.get_relative_placement_source(node)
assert source_info is not None

for child_node in children_nodes:
child_info = self._source_node.pop(id(child_node.meta.get("spec")))
self.generate_location_constraint(
child_info = self._relative_placement_constraint.pop(
id(child_node.meta.get("spec"))
)
self.add_relative_placement_constraint(
source_info.source,
child_node,
offset=source_info.offset + child_info.offset,
update_lifetime=update_lifetime,
)
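
The `source_info.offset + child_info.offset` sum above composes placements transitively; a toy check of the arithmetic (values hypothetical):

```python
# If child = dependent + 8 and dependent = source + 16, re-parenting the
# child onto source yields child = source + 24.
child_offset, dependent_offset = 8, 16
assert dependent_offset + child_offset == 24
```
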

def generate_location_constraint(
def add_relative_placement_constraint(
self,
source: torch.fx.Node,
dependent: torch.fx.Node,
@@ -230,29 +259,26 @@ def generate_location_constraint(
logging.debug(f"Adding constraint {dependent} = {source} + {offset=}")

# Assert that both source and dependent node are tensors.
if (info := self.get_source_info(source)) is not None:
return self.generate_location_constraint(
info.source, dependent, offset + info.offset, update_lifetime
)
if (info := self.get_relative_placement_source(source)) is not None:
source = info.source
offset += info.offset

if (info := self.get_source_info(dependent)) is not None:
if (info := self.get_relative_placement_source(dependent)) is not None:
# Dependent node can only be an alias (same size, offset = 0).
assert self.is_alias_of(
info.source, dependent
), f"Multiple constraints for allocation of {dependent}. Previous constraint: {info} new constraint: {source=} {offset=}"
return self.generate_location_constraint(
source, info.source, offset, update_lifetime=update_lifetime
)
dependent = info.source

# Add the dependent spec to skip list. Its memory offset will be computed
# after the output tensor is allocated space.
source_info = SourceInfo(source=source, offset=offset)
self.set_source_info(dependent, source_info)
source_info = RelativePlacementConstraint(source=source, offset=offset)
self.set_relative_placement_constraint(dependent, source_info)

# If update_lifetime is True, take a union of the lifetimes of the representative
# and dependent tensors; this will become the new lifetime of the source tensor.
dependent_spec = dependent.meta.get("spec")
if update_lifetime:
dependent_spec = dependent.meta.get("spec")
source_spec = source.meta.get("spec")
source.meta.get("spec").lifetime = [
min(source_spec.lifetime[0], dependent_spec.lifetime[0]),
@@ -261,6 +287,49 @@

self.update_children_nodes(dependent, update_lifetime)

abs_constraint = self.get_absolute_placement_constraint(dependent_spec)
if abs_constraint is None:
return

# Dependent node has an absolute placement constraint.
# If the offset is not 0, then we cannot add a relative placement constraint.
if not self.is_alias_of(dependent, source):
raise RuntimeError(
f"Cannot add relative placement constraint for {dependent} with non-zero offset {offset} when it has an absolute placement constraint {abs_constraint}"
)

# Add the absolute placement constraint to the source node.
self._absolute_placement_constraints.pop(id(dependent_spec))
self.add_absolute_placement_constraint(
source, abs_constraint.pinned_memory_id, abs_constraint.offset
)

def add_absolute_placement_constraint(
self, node: torch.fx.Node, pinned_memory_id: int, offset: Optional[int] = None
) -> None:
"""Add a memory pinning constraint for `node` to `mem_id`."""
logging.debug(
f"Adding memory pinning constraint {node=} = {pinned_memory_id=} at {offset=}"
)
source_node: torch.fx.Node = node
if (info := self.get_relative_placement_source(node)) is not None:
assert self.is_alias_of(info.source, node)
logging.debug(
f"Setting {node} to {info.source} + {offset=}. Pinned to {pinned_memory_id=}"
)
source_node = info.source
self._absolute_placement_constraints[id(source_node.meta.get("spec"))] = (
AbsolutePlacementConstraint(
pinned_memory_id=pinned_memory_id, offset=offset
)
)

def get_absolute_placement_constraint(
self, spec: TensorSpec
) -> Optional[AbsolutePlacementConstraint]:
"""Return true if `node` has an absolute placement constraint."""
return self._absolute_placement_constraints.get(id(spec), None)
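
A hedged sketch of the pinning round trip through the two methods above (the stand-in spec object and a default-constructed `MemConstraints` are assumptions made for illustration; real specs are `TensorSpec`s populated by memory planning):

```python
import torch

graph = torch.fx.Graph()
node = graph.placeholder("x")
node.meta["spec"] = spec = object()  # stand-in for a TensorSpec

constraints = MemConstraints()  # assuming defaults suffice here
constraints.add_absolute_placement_constraint(node, pinned_memory_id=1, offset=0)

# The pin is recorded against the node's spec and can be read back by id.
assert constraints.get_absolute_placement_constraint(
    spec
) == AbsolutePlacementConstraint(pinned_memory_id=1, offset=0)
```

Note the redirection in `add_absolute_placement_constraint`: if the node already has a relative-placement source, the pin is recorded on that source's spec instead, so an alias and its source resolve to a single placement.
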


def get_relative_offsets_of_cat_tensors(
cat_tensors: Sequence[torch.fx.Node],
@@ -342,7 +411,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:

def is_slice_view(self, node: torch.fx.Node) -> bool:
"""Return if `node` has constraints and is not an alias of another node."""
if (source_info := self.constraint.get_source_info(node)) is not None:
if (
source_info := self.constraint.get_relative_placement_source(node)
) is not None:
return not self.constraint.is_alias_of(source_info.source, node)
return False

@@ -426,7 +497,9 @@ def is_removable_cat_op(
return True

# Currently the contiguity constraints are generated by cat operator.
def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule):
def compute_cat_contiguity_constraints(
self, graph_module: torch.fx.GraphModule
) -> None:
for node in graph_module.graph.nodes:
# Only compute relative constraints if the cat node can be replaced with
# its nop version
@@ -448,7 +521,9 @@ def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule)
# Get the relative offsets for each tensor to be concatenated.
relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
for arg, offset in zip(cat_tensors, relative_offsets):
self.constraint.generate_location_constraint(node, arg, offset=offset)
self.constraint.add_relative_placement_constraint(
node, arg, offset=offset
)

# Update the lifetimes of the args to that of the output tensor, so
# that they don't get overwritten
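
A toy recomputation of the relative offsets for a dim-0 cat of contiguous tensors, mirroring what `get_relative_offsets_of_cat_tensors` produces conceptually (the real helper works on fx nodes and their specs):

```python
import torch

tensors = [torch.zeros(2, 4), torch.zeros(3, 4)]  # cat along dim 0
offsets, running = [], 0
for t in tensors:
    offsets.append(running)  # each input starts where the previous ended
    running += t.numel() * t.element_size()
print(offsets)  # [0, 32]: the second input lives 2*4*4 bytes into the output
```
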
@@ -474,7 +549,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
for node in graph_module.graph.nodes:
if node.op != "call_function" or node.target != memory.view:
continue
self.constraint.generate_location_constraint(node.args[0], node)
self.constraint.add_relative_placement_constraint(node.args[0], node)


@register_cadence_pass(CadencePassAttribute(opt_level=2))
@@ -544,7 +619,7 @@ def removable_slice_or_select_op(
# the input and output tensor.
def compute_slice_and_select_loc_constraints(
self, graph_module: torch.fx.GraphModule
):
) -> None:
for node in graph_module.graph.nodes:
# Only compute relative constraints if the slice node can be
# replaced with its nop version
@@ -563,7 +638,7 @@ def compute_slice_and_select_loc_constraints(
# And now generate location constraint between input and output
# tensors of slice node
arg = node.args[0]
self.constraint.generate_location_constraint(
self.constraint.add_relative_placement_constraint(
arg,
node,
offset=offset,
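
For a removable slice/select, the offset passed above is the byte distance of the slice start inside the input; a toy version of that computation (hypothetical shapes, contiguous layout assumed):

```python
import torch

inp = torch.zeros(4, 8)  # contiguous float32
start = 1                # slice dim 0 starting at row 1
offset = start * inp.stride(0) * inp.element_size()
print(offset)  # 32: the nop-slice output is placed at input + 32 bytes
```
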
@@ -607,12 +682,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
filtered_passes = [
mcg_pass(self.mem_constraints)
for mcg_pass in cast(
list[
typing.Callable[
[MemConstraints],
typing.Callable[[torch.fx.GraphModule], Optional[PassResult]],
]
],
list[ConstraintsGenPass],
# pyre-ignore[6]: Incompatible parameter type.
list(filter(pass_filter, constraint_gen_passes)),
)
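
`ConstraintsGenPass` presumably aliases the inline `Callable` type removed above; a sketch of what such an alias would look like (hypothetical reconstruction, not shown in this diff):

```python
from typing import Callable, Optional, TypeAlias

import torch
from torch.fx.passes.infra.pass_base import PassResult

# Hypothetical: mirrors the inline type that the cast previously spelled out.
ConstraintsGenPass: TypeAlias = Callable[
    [MemConstraints],
    Callable[[torch.fx.GraphModule], Optional[PassResult]],
]
```
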