Commit 4b1585f

Add support for absolute mem_id/offset placement constraints. (#12266)

Authored by hsharma35, committed by facebook-github-bot.

Summary:
Add placement constraints to allow placement of tensors in a specific mem ID. For the Cadence backend, this allows placement of tensors in specific DTCM banks for iDMA ops.

Reviewed By: skrtskrtfb

Differential Revision: D77061574

1 parent a022db0 commit 4b1585f
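For illustration only (not part of this commit), here is a minimal sketch of how a Cadence constraint-generation pass might use the new API to pin iDMA-related tensors into a DTCM bank. The method name and signature add_absolute_placement_constraint(node, pinned_memory_id, offset=None) come from the memory_constraints.py diff below; the import path, the "idma" target check, and the mem_id value 1 are illustrative assumptions.

import torch

from executorch.backends.cadence.aot.memory_constraints import MemConstraints


def pin_idma_tensors(
    graph_module: torch.fx.GraphModule, constraints: MemConstraints
) -> None:
    """Pin every tensor produced by a (hypothetical) iDMA op to DTCM bank 1."""
    for node in graph_module.graph.nodes:
        if node.op != "call_function":
            continue
        # Hypothetical predicate; a real pass would match concrete op targets.
        if "idma" in str(node.target):
            # offset=None pins only the mem_id and lets the planner choose the
            # offset; pass an integer offset to pin the exact location as well.
            constraints.add_absolute_placement_constraint(
                node, pinned_memory_id=1, offset=None
            )

During planning, the pin for a tensor spec can be read back with constraints.get_absolute_placement_constraint(spec), which returns an AbsolutePlacementConstraint (pinned_memory_id plus optional offset) or None, per the new methods added in this commit.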

File tree

10 files changed: +685 −137 lines


backends/cadence/aot/TARGETS

Lines changed: 16 additions & 0 deletions
@@ -184,6 +184,21 @@ python_library(
     ],
 )
 
+python_library(
+    name = "program_builder",
+    srcs = [
+        "program_builder.py",
+    ],
+    typing = True,
+    deps = [
+        ":graph_builder",
+        "fbcode//caffe2:torch",
+        "fbcode//executorch/exir:lib",
+        "fbcode//executorch/exir:pass_base",
+        "fbcode//executorch/exir/verification:verifier",
+    ],
+)
+
 python_library(
     name = "fuse_ops",
     srcs = [
@@ -508,6 +523,7 @@ python_unittest(
         ":typing_stubs",
         ":ops_registrations",
         ":pass_utils",
+        ":program_builder",
        "//caffe2:torch",
        "//executorch/exir:memory",
        "//executorch/exir/dialects:lib",

backends/cadence/aot/memory_constraints.py

Lines changed: 116 additions & 46 deletions
@@ -4,11 +4,10 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
-# pyre-unsafe
+# pyre-strict
 
 import logging
 import math
-import typing
 from collections import defaultdict
 from dataclasses import dataclass
 from typing import Callable, cast, DefaultDict, Iterable, Optional, Sequence, TypeAlias
@@ -28,19 +27,38 @@
 
 
 @dataclass(frozen=True)
-class SourceInfo:
+class RelativePlacementConstraint:
     """Information of source node and offset used for views."""
 
     source: torch.fx.Node
     offset: int = 0
 
 
+@dataclass(frozen=True)
+class AbsolutePlacementConstraint:
+    """Information on placement constraint memory id and offset."""
+
+    pinned_memory_id: int
+
+    # If offset is None, then the tensor can be placed anywhere in the memory id.
+    offset: Optional[int] = None
+
+
 class MemConstraints:
     """
     This class contains all the tensor placement constraints that we create
     during memory planning.
-    Any tensor whose placement is derived off another tensor via a constraint
-    is not included in memory planning, and is marked as skipped.
+
+    We have two types of placement constraints:
+    1. Relative placement constraints: These are constraints that specify the
+       relative placement of a tensor with respect to another tensor. For
+       example, when slice dim is 0, the slice output can be placed relative to
+       its input and the op can be replaced with a nop.
+    2. Absolute placement constraints: These are constraints that specify the
+       absolute placement of a tensor either in a specific memory id, or both
+       a specific memory id and offset. For example, for operators that require
+       a specific memory id + offset, we can use this constraint to specify the
+       location of inputs/outputs or even temporary buffers.
     """
 
     def __init__(
@@ -62,29 +80,38 @@ def __init__(
         # A set of tensor spec ids that must be skipped during memory allocation.
         # The exact mem_id and offset of the skipped tensors will be computed from
         # the constraints.
-        self._source_node: dict[int, SourceInfo] = {}
+        self._relative_placement_constraint: dict[int, RelativePlacementConstraint] = {}
 
         # A map from `id(TensorSpec)` to a set of mem_ids that cannot be used for
         # allocating the tensor.
         self._mem_id_blocklist: dict[int, set[int]] = {}
 
-    def get_source_info(self, node: torch.fx.Node) -> Optional[SourceInfo]:
+        # A map from `id(TensorSpec)` to a AbsolutePlacementConstraint that specifies mem_id and optionally exact offset.
+        self._absolute_placement_constraints: dict[int, AbsolutePlacementConstraint] = (
+            {}
+        )
+
+    def get_relative_placement_source(
+        self, node: torch.fx.Node
+    ) -> Optional[RelativePlacementConstraint]:
         spec = node.meta.get("spec")
         spec_id = id(spec)
-        if spec_id not in self._source_node:
+        if spec_id not in self._relative_placement_constraint:
             return None
-        return self._source_node[spec_id]
+        return self._relative_placement_constraint[spec_id]
 
-    def set_source_info(
-        self, dependent: torch.fx.Node, source_info: SourceInfo
+    def set_relative_placement_constraint(
+        self,
+        dependent: torch.fx.Node,
+        placement_constraint: RelativePlacementConstraint,
     ) -> None:
         dependent_spec = dependent.meta.get("spec")
         spec_id = id(dependent_spec)
-        self._source_node[spec_id] = source_info
-        if self.is_memory_planned(source_info.source):
+        self._relative_placement_constraint[spec_id] = placement_constraint
+        if self.is_memory_planned(placement_constraint.source):
             # Only add dependent nodes if source node needs memory planning.
             self.unresolved_loc_constraints[
-                id(source_info.source.meta.get("spec"))
+                id(placement_constraint.source.meta.get("spec"))
             ].add(dependent)
 
     def add_mem_id_to_blocklist(self, spec: TensorSpec, mem_id: int) -> None:
@@ -111,7 +138,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool:
         node --> view
              --> relu (or some other op that can be in-place)
         """
-        if node_source_info := self.get_source_info(node):
+        if node_source_info := self.get_relative_placement_source(node):
             node_spec = node.meta.get("spec")
             node_source_spec = node_source_info.source.meta.get("spec")
             return (
@@ -121,7 +148,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool:
                 and self.is_alias_of(node_source_info.source, other_node)
             )
 
-        if self.get_source_info(other_node) is not None:
+        if self.get_relative_placement_source(other_node) is not None:
             return self.is_alias_of(other_node, node)
 
         return node == other_node
@@ -132,14 +159,14 @@ def relative_loc_constraints_exist(self) -> bool:
 
     # Return true if the spec is marked as skipped
     def skipped_spec(self, spec: TensorSpec) -> bool:
-        return id(spec) in self._source_node
+        return id(spec) in self._relative_placement_constraint
 
     def is_memory_planned(
         self,
         node: torch.fx.Node,
     ) -> bool:
         """Return true if the node is either (1) a parameter, or (2) a placeholder."""
-        if (source_info := self.get_source_info(node)) is not None:
+        if (source_info := self.get_relative_placement_source(node)) is not None:
             # If node has relative placement constraints, then check the source.
             return self.is_memory_planned(source_info.source)
         # Check if any node is a param.
@@ -183,7 +210,7 @@ def resolve_relative_loc_constraints(self, spec: TensorSpec) -> None:
 
         assert isinstance(spec, TensorSpec)
         for dependent_node in self.unresolved_loc_constraints[spec_id]:
-            source_info = self.get_source_info(dependent_node)
+            source_info = self.get_relative_placement_source(dependent_node)
             assert source_info is not None
             dependent_spec = cast(TensorSpec, dependent_node.meta.get("spec"))
             dependent_spec.mem_id = spec.mem_id
@@ -202,19 +229,21 @@ def update_children_nodes(self, node: torch.fx.Node, update_lifetime: bool) -> None:
         children_nodes = self.unresolved_loc_constraints[id(node.meta.get("spec"))]
         self.unresolved_loc_constraints.pop(id(node.meta.get("spec")))
 
-        source_info = self.get_source_info(node)
+        source_info = self.get_relative_placement_source(node)
         assert source_info is not None
 
         for child_node in children_nodes:
-            child_info = self._source_node.pop(id(child_node.meta.get("spec")))
-            self.generate_location_constraint(
+            child_info = self._relative_placement_constraint.pop(
+                id(child_node.meta.get("spec"))
+            )
+            self.add_relative_placement_constraint(
                 source_info.source,
                 child_node,
                 offset=source_info.offset + child_info.offset,
                 update_lifetime=update_lifetime,
             )
 
-    def generate_location_constraint(
+    def add_relative_placement_constraint(
         self,
         source: torch.fx.Node,
         dependent: torch.fx.Node,
@@ -230,29 +259,26 @@
         logging.debug(f"Adding constraint {dependent} = {source} + {offset=}")
 
         # Assert that both source and dependent node are tensors.
-        if (info := self.get_source_info(source)) is not None:
-            return self.generate_location_constraint(
-                info.source, dependent, offset + info.offset, update_lifetime
-            )
+        if (info := self.get_relative_placement_source(source)) is not None:
+            source = info.source
+            offset += info.offset
 
-        if (info := self.get_source_info(dependent)) is not None:
+        if (info := self.get_relative_placement_source(dependent)) is not None:
             # Dependent node can only be an alias (same size, offset = 0).
             assert self.is_alias_of(
                 info.source, dependent
             ), f"Multiple constraints for allocation of {dependent}. Previous constraint: {info} new constraint: {source=} {offset=}"
-            return self.generate_location_constraint(
-                source, info.source, offset, update_lifetime=update_lifetime
-            )
+            dependent = info.source
 
         # Add the dependent spec to skip list. Its memory offset will be computed
         # after the output tensor is allocated space.
-        source_info = SourceInfo(source=source, offset=offset)
-        self.set_source_info(dependent, source_info)
+        source_info = RelativePlacementConstraint(source=source, offset=offset)
+        self.set_relative_placement_constraint(dependent, source_info)
 
         # If update_lifetime is True, take a union of the lifetime of representaitve
         # and dependent tensors; this will become the new lifetime of source tensor.
+        dependent_spec = dependent.meta.get("spec")
         if update_lifetime:
-            dependent_spec = dependent.meta.get("spec")
             source_spec = source.meta.get("spec")
             source.meta.get("spec").lifetime = [
                 min(source_spec.lifetime[0], dependent_spec.lifetime[0]),
@@ -261,6 +287,49 @@
 
         self.update_children_nodes(dependent, update_lifetime)
 
+        abs_constraint = self.get_absolute_placement_constraint(dependent_spec)
+        if abs_constraint is None:
+            return
+
+        # Dependent node has an absolute placement constraint.
+        # If the offset is not 0, then we cannot add a relative placement constraint.
+        if not self.is_alias_of(dependent, source):
+            raise RuntimeError(
+                f"Cannot add relative placement constraint for {dependent} with non-zero offset {offset} when it has an absolute placement constraint {abs_constraint}"
+            )
+
+        # Add the absolute placement constraint to the source node.
+        self._absolute_placement_constraints.pop(id(dependent_spec))
+        self.add_absolute_placement_constraint(
+            source, abs_constraint.pinned_memory_id, abs_constraint.offset
+        )
+
+    def add_absolute_placement_constraint(
+        self, node: torch.fx.Node, pinned_memory_id: int, offset: Optional[int] = None
+    ) -> None:
+        """Add a memory pinning constraint for `node` to `mem_id`."""
+        logging.debug(
+            f"Adding memory pinning constraint {node=} = {pinned_memory_id=} at {offset=}"
+        )
+        source_node: torch.fx.Node = node
+        if (info := self.get_relative_placement_source(node)) is not None:
+            assert self.is_alias_of(info.source, node)
+            logging.debug(
+                f"Setting {node} to {info.source} + {offset=}. Pinned to {pinned_memory_id=}"
+            )
+            source_node = info.source
+        self._absolute_placement_constraints[id(source_node.meta.get("spec"))] = (
+            AbsolutePlacementConstraint(
+                pinned_memory_id=pinned_memory_id, offset=offset
+            )
+        )
+
+    def get_absolute_placement_constraint(
+        self, spec: TensorSpec
+    ) -> Optional[AbsolutePlacementConstraint]:
+        """Return true if `node` has an absolute placement constraint."""
+        return self._absolute_placement_constraints.get(id(spec), None)
+
 
 
 def get_relative_offsets_of_cat_tensors(
     cat_tensors: Sequence[torch.fx.Node],
342411

343412
def is_slice_view(self, node: torch.fx.Node) -> bool:
344413
"""Return if `node` has constraints and is not an alias of another node."""
345-
if (source_info := self.constraint.get_source_info(node)) is not None:
414+
if (
415+
source_info := self.constraint.get_relative_placement_source(node)
416+
) is not None:
346417
return not self.constraint.is_alias_of(source_info.source, node)
347418
return False
348419

@@ -426,7 +497,9 @@ def is_removable_cat_op(
426497
return True
427498

428499
# Currently the contiguity constraints are generated by cat operator.
429-
def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule):
500+
def compute_cat_contiguity_constraints(
501+
self, graph_module: torch.fx.GraphModule
502+
) -> None:
430503
for node in graph_module.graph.nodes:
431504
# Only compute relative constraints if the cat node can be replaced with
432505
# its nop version
@@ -448,7 +521,9 @@ def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule)
448521
# Get the relative offsets for each tensor to be concatenated.
449522
relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
450523
for arg, offset in zip(cat_tensors, relative_offsets):
451-
self.constraint.generate_location_constraint(node, arg, offset=offset)
524+
self.constraint.add_relative_placement_constraint(
525+
node, arg, offset=offset
526+
)
452527

453528
# Update the lifetimes of the args to that of the output tensor, so
454529
# that they don't get overwritten
@@ -474,7 +549,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
474549
for node in graph_module.graph.nodes:
475550
if node.op != "call_function" or node.target != memory.view:
476551
continue
477-
self.constraint.generate_location_constraint(node.args[0], node)
552+
self.constraint.add_relative_placement_constraint(node.args[0], node)
478553

479554

480555
@register_cadence_pass(CadencePassAttribute(opt_level=2))
@@ -544,7 +619,7 @@ def removable_slice_or_select_op(
544619
# the input and output tensor.
545620
def compute_slice_and_select_loc_constraints(
546621
self, graph_module: torch.fx.GraphModule
547-
):
622+
) -> None:
548623
for node in graph_module.graph.nodes:
549624
# Only compute relative constraints if the slice node can be
550625
# replaced with its nop version
@@ -563,7 +638,7 @@ def compute_slice_and_select_loc_constraints(
563638
# And now generate location constraint between input and output
564639
# tensors of slice node
565640
arg = node.args[0]
566-
self.constraint.generate_location_constraint(
641+
self.constraint.add_relative_placement_constraint(
567642
arg,
568643
node,
569644
offset=offset,
@@ -607,12 +682,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
607682
filtered_passes = [
608683
mcg_pass(self.mem_constraints)
609684
for mcg_pass in cast(
610-
list[
611-
typing.Callable[
612-
[MemConstraints],
613-
typing.Callable[[torch.fx.GraphModule], Optional[PassResult]],
614-
]
615-
],
685+
list[ConstraintsGenPass],
616686
# pyre-ignore[6]: Incompatible parameter type.
617687
list(filter(pass_filter, constraint_gen_passes)),
618688
)
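As an aside, not part of the commit, the relative-placement case described in the new MemConstraints docstring can be illustrated with plain PyTorch: for a dim-0 slice of a contiguous tensor, the output is the input buffer at a fixed byte offset, which is exactly the offset a relative placement constraint (dependent = source + offset) records so the slice op can be replaced with a nop.

import torch

# A contiguous (6, 4) float32 tensor and a slice along dim 0.
x = torch.arange(24, dtype=torch.float32).reshape(6, 4)
start = 2
y = x[start:5]

# Byte offset of the slice output inside the input buffer:
# start * stride(0) elements, times 4 bytes per float32 element.
relative_offset = start * x.stride(0) * x.element_size()  # 2 * 4 * 4 = 32

# The slice aliases the same storage at exactly that offset, which is the
# relationship memory planning exploits when it skips allocating the slice
# output and derives its location from the input tensor.
assert y.data_ptr() - x.data_ptr() == relative_offset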
