# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

- # pyre-unsafe
+ # pyre-strict

import logging
import math
- import typing
from collections import defaultdict
from dataclasses import dataclass
from typing import Callable, cast, DefaultDict, Iterable, Optional, Sequence, TypeAlias

@dataclass(frozen=True)
- class SourceInfo:
+ class RelativePlacementConstraint:
    """Information of source node and offset used for views."""

    source: torch.fx.Node
    offset: int = 0


+ @dataclass(frozen=True)
+ class AbsolutePlacementConstraint:
+     """Information on placement constraint memory id and offset."""
+
+     pinned_memory_id: int
+
+     # If offset is None, then the tensor can be placed anywhere in the memory id.
+     offset: Optional[int] = None
+
+
class MemConstraints:
    """
    This class contains all the tensor placement constraints that we create
    during memory planning.
-     Any tensor whose placement is derived off another tensor via a constraint
-     is not included in memory planning, and is marked as skipped.
+
+     We have two types of placement constraints:
+     1. Relative placement constraints: These are constraints that specify the
+        relative placement of a tensor with respect to another tensor. For
+        example, when the slice dim is 0, the slice output can be placed relative
+        to its input and the op can be replaced with a nop.
+     2. Absolute placement constraints: These are constraints that specify the
+        absolute placement of a tensor, either in a specific memory id, or in both
+        a specific memory id and offset. For example, for operators that require
+        a specific memory id + offset, we can use this constraint to specify the
+        location of inputs/outputs or even temporary buffers.
    """

    def __init__(
@@ -62,29 +80,38 @@ def __init__(
        # A set of tensor spec ids that must be skipped during memory allocation.
        # The exact mem_id and offset of the skipped tensors will be computed from
        # the constraints.
-         self._source_node: dict[int, SourceInfo] = {}
+         self._relative_placement_constraint: dict[int, RelativePlacementConstraint] = {}

        # A map from `id(TensorSpec)` to a set of mem_ids that cannot be used for
        # allocating the tensor.
        self._mem_id_blocklist: dict[int, set[int]] = {}

-     def get_source_info(self, node: torch.fx.Node) -> Optional[SourceInfo]:
+         # A map from `id(TensorSpec)` to an AbsolutePlacementConstraint that specifies mem_id and optionally exact offset.
+         self._absolute_placement_constraints: dict[int, AbsolutePlacementConstraint] = (
+             {}
+         )
+
+     def get_relative_placement_source(
+         self, node: torch.fx.Node
+     ) -> Optional[RelativePlacementConstraint]:
        spec = node.meta.get("spec")
        spec_id = id(spec)
-         if spec_id not in self._source_node:
+         if spec_id not in self._relative_placement_constraint:
            return None
-         return self._source_node[spec_id]
+         return self._relative_placement_constraint[spec_id]

-     def set_source_info(
-         self, dependent: torch.fx.Node, source_info: SourceInfo
+     def set_relative_placement_constraint(
+         self,
+         dependent: torch.fx.Node,
+         placement_constraint: RelativePlacementConstraint,
    ) -> None:
        dependent_spec = dependent.meta.get("spec")
        spec_id = id(dependent_spec)
-         self._source_node[spec_id] = source_info
-         if self.is_memory_planned(source_info.source):
+         self._relative_placement_constraint[spec_id] = placement_constraint
+         if self.is_memory_planned(placement_constraint.source):
            # Only add dependent nodes if source node needs memory planning.
            self.unresolved_loc_constraints[
-                 id(source_info.source.meta.get("spec"))
+                 id(placement_constraint.source.meta.get("spec"))
            ].add(dependent)

    def add_mem_id_to_blocklist(self, spec: TensorSpec, mem_id: int) -> None:
@@ -111,7 +138,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool:
            node --> view
                 --> relu (or some other op that can be in-place)
        """
-         if node_source_info := self.get_source_info(node):
+         if node_source_info := self.get_relative_placement_source(node):
            node_spec = node.meta.get("spec")
            node_source_spec = node_source_info.source.meta.get("spec")
            return (
@@ -121,7 +148,7 @@ def is_alias_of(self, node: torch.fx.Node, other_node: torch.fx.Node) -> bool:
                and self.is_alias_of(node_source_info.source, other_node)
            )

-         if self.get_source_info(other_node) is not None:
+         if self.get_relative_placement_source(other_node) is not None:
            return self.is_alias_of(other_node, node)

        return node == other_node
@@ -132,14 +159,14 @@ def relative_loc_constraints_exist(self) -> bool:

    # Return true if the spec is marked as skipped
    def skipped_spec(self, spec: TensorSpec) -> bool:
-         return id(spec) in self._source_node
+         return id(spec) in self._relative_placement_constraint

    def is_memory_planned(
        self,
        node: torch.fx.Node,
    ) -> bool:
        """Return true if the node is either (1) a parameter, or (2) a placeholder."""
-         if (source_info := self.get_source_info(node)) is not None:
+         if (source_info := self.get_relative_placement_source(node)) is not None:
            # If node has relative placement constraints, then check the source.
            return self.is_memory_planned(source_info.source)
        # Check if any node is a param.
@@ -183,7 +210,7 @@ def resolve_relative_loc_constraints(self, spec: TensorSpec) -> None:

        assert isinstance(spec, TensorSpec)
        for dependent_node in self.unresolved_loc_constraints[spec_id]:
-             source_info = self.get_source_info(dependent_node)
+             source_info = self.get_relative_placement_source(dependent_node)
            assert source_info is not None
            dependent_spec = cast(TensorSpec, dependent_node.meta.get("spec"))
            dependent_spec.mem_id = spec.mem_id
@@ -202,19 +229,21 @@ def update_children_nodes(self, node: torch.fx.Node, update_lifetime: bool) -> None:
        children_nodes = self.unresolved_loc_constraints[id(node.meta.get("spec"))]
        self.unresolved_loc_constraints.pop(id(node.meta.get("spec")))

-         source_info = self.get_source_info(node)
+         source_info = self.get_relative_placement_source(node)
        assert source_info is not None

        for child_node in children_nodes:
-             child_info = self._source_node.pop(id(child_node.meta.get("spec")))
-             self.generate_location_constraint(
+             child_info = self._relative_placement_constraint.pop(
+                 id(child_node.meta.get("spec"))
+             )
+             self.add_relative_placement_constraint(
                source_info.source,
                child_node,
                offset=source_info.offset + child_info.offset,
                update_lifetime=update_lifetime,
            )

-     def generate_location_constraint(
+     def add_relative_placement_constraint(
        self,
        source: torch.fx.Node,
        dependent: torch.fx.Node,
@@ -230,29 +259,26 @@ def generate_location_constraint(
        logging.debug(f"Adding constraint {dependent} = {source} + {offset=}")

        # Assert that both source and dependent node are tensors.
-         if (info := self.get_source_info(source)) is not None:
-             return self.generate_location_constraint(
-                 info.source, dependent, offset + info.offset, update_lifetime
-             )
+         if (info := self.get_relative_placement_source(source)) is not None:
+             source = info.source
+             offset += info.offset

-         if (info := self.get_source_info(dependent)) is not None:
+         if (info := self.get_relative_placement_source(dependent)) is not None:
            # Dependent node can only be an alias (same size, offset = 0).
            assert self.is_alias_of(
                info.source, dependent
            ), f"Multiple constraints for allocation of {dependent}. Previous constraint: {info} new constraint: {source=} {offset=}"
-             return self.generate_location_constraint(
-                 source, info.source, offset, update_lifetime=update_lifetime
-             )
+             dependent = info.source

        # Add the dependent spec to skip list. Its memory offset will be computed
        # after the output tensor is allocated space.
-         source_info = SourceInfo(source=source, offset=offset)
-         self.set_source_info(dependent, source_info)
+         source_info = RelativePlacementConstraint(source=source, offset=offset)
+         self.set_relative_placement_constraint(dependent, source_info)

        # If update_lifetime is True, take a union of the lifetime of representative
        # and dependent tensors; this will become the new lifetime of the source tensor.
+         dependent_spec = dependent.meta.get("spec")
        if update_lifetime:
-             dependent_spec = dependent.meta.get("spec")
            source_spec = source.meta.get("spec")
            source.meta.get("spec").lifetime = [
                min(source_spec.lifetime[0], dependent_spec.lifetime[0]),
@@ -261,6 +287,49 @@ def generate_location_constraint(

        self.update_children_nodes(dependent, update_lifetime)

+         abs_constraint = self.get_absolute_placement_constraint(dependent_spec)
+         if abs_constraint is None:
+             return
+
+         # Dependent node has an absolute placement constraint.
+         # If the offset is not 0, then we cannot add a relative placement constraint.
+         if not self.is_alias_of(dependent, source):
+             raise RuntimeError(
+                 f"Cannot add relative placement constraint for {dependent} with non-zero offset {offset} when it has an absolute placement constraint {abs_constraint}"
+             )
+
+         # Add the absolute placement constraint to the source node.
+         self._absolute_placement_constraints.pop(id(dependent_spec))
+         self.add_absolute_placement_constraint(
+             source, abs_constraint.pinned_memory_id, abs_constraint.offset
+         )
+
+     def add_absolute_placement_constraint(
+         self, node: torch.fx.Node, pinned_memory_id: int, offset: Optional[int] = None
+     ) -> None:
+         """Add a memory pinning constraint for `node` to `pinned_memory_id`."""
+         logging.debug(
+             f"Adding memory pinning constraint {node=} = {pinned_memory_id=} at {offset=}"
+         )
+         source_node: torch.fx.Node = node
+         if (info := self.get_relative_placement_source(node)) is not None:
+             assert self.is_alias_of(info.source, node)
+             logging.debug(
+                 f"Setting {node} to {info.source} + {offset=}. Pinned to {pinned_memory_id=}"
+             )
+             source_node = info.source
+         self._absolute_placement_constraints[id(source_node.meta.get("spec"))] = (
+             AbsolutePlacementConstraint(
+                 pinned_memory_id=pinned_memory_id, offset=offset
+             )
+         )
+
+     def get_absolute_placement_constraint(
+         self, spec: TensorSpec
+     ) -> Optional[AbsolutePlacementConstraint]:
+         """Return the absolute placement constraint for `spec`, if any."""
+         return self._absolute_placement_constraints.get(id(spec), None)
+


def get_relative_offsets_of_cat_tensors(
    cat_tensors: Sequence[torch.fx.Node],
@@ -342,7 +411,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:

    def is_slice_view(self, node: torch.fx.Node) -> bool:
        """Return True if `node` has constraints and is not an alias of another node."""
-         if (source_info := self.constraint.get_source_info(node)) is not None:
+         if (
+             source_info := self.constraint.get_relative_placement_source(node)
+         ) is not None:
            return not self.constraint.is_alias_of(source_info.source, node)
        return False

@@ -426,7 +497,9 @@ def is_removable_cat_op(
        return True

    # Currently the contiguity constraints are generated by cat operator.
-     def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule):
+     def compute_cat_contiguity_constraints(
+         self, graph_module: torch.fx.GraphModule
+     ) -> None:
        for node in graph_module.graph.nodes:
            # Only compute relative constraints if the cat node can be replaced with
            # its nop version
@@ -448,7 +521,9 @@ def compute_cat_contiguity_constraints(self, graph_module: torch.fx.GraphModule)
            # Get the relative offsets for each tensor to be concatenated.
            relative_offsets = get_relative_offsets_of_cat_tensors(cat_tensors)
            for arg, offset in zip(cat_tensors, relative_offsets):
-                 self.constraint.generate_location_constraint(node, arg, offset=offset)
+                 self.constraint.add_relative_placement_constraint(
+                     node, arg, offset=offset
+                 )

            # Update the lifetimes of the args to that of the output tensor, so
            # that they don't get overwritten
@@ -474,7 +549,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> Optional[PassResult]:
        for node in graph_module.graph.nodes:
            if node.op != "call_function" or node.target != memory.view:
                continue
-             self.constraint.generate_location_constraint(node.args[0], node)
+             self.constraint.add_relative_placement_constraint(node.args[0], node)


@register_cadence_pass(CadencePassAttribute(opt_level=2))
@@ -544,7 +619,7 @@ def removable_slice_or_select_op(
    # the input and output tensor.
    def compute_slice_and_select_loc_constraints(
        self, graph_module: torch.fx.GraphModule
-     ):
+     ) -> None:
        for node in graph_module.graph.nodes:
            # Only compute relative constraints if the slice node can be
            # replaced with its nop version
@@ -563,7 +638,7 @@ def compute_slice_and_select_loc_constraints(
            # And now generate location constraint between input and output
            # tensors of slice node
            arg = node.args[0]
-             self.constraint.generate_location_constraint(
+             self.constraint.add_relative_placement_constraint(
                arg,
                node,
                offset=offset,
@@ -607,12 +682,7 @@ def __call__(self, graph_module: torch.fx.GraphModule) -> PassResult:
        filtered_passes = [
            mcg_pass(self.mem_constraints)
            for mcg_pass in cast(
-                 list[
-                     typing.Callable[
-                         [MemConstraints],
-                         typing.Callable[[torch.fx.GraphModule], Optional[PassResult]],
-                     ]
-                 ],
+                 list[ConstraintsGenPass],
                # pyre-ignore[6]: Incompatible parameter type.
                list(filter(pass_filter, constraint_gen_passes)),
            )
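Below is a small standalone sketch (not part of the diff above) illustrating how the two constraint kinds described in the `MemConstraints` docstring compose: a chain of relative placements resolves to a root tensor, and the root's absolute pin (memory id, optional offset) then determines the derived tensor's placement. All names in the sketch are illustrative assumptions, not the executorch API.

```python
# Hypothetical illustration of relative + absolute placement constraints.
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class RelPlacement:
    source: str      # name of the tensor this one is placed relative to
    offset: int = 0  # byte offset from the source tensor's start


@dataclass(frozen=True)
class AbsPlacement:
    pinned_memory_id: int
    offset: Optional[int] = None  # None: anywhere within the pinned memory id


def resolve(
    name: str,
    rel: dict[str, RelPlacement],
    abs_pins: dict[str, AbsPlacement],
) -> tuple[int, Optional[int]]:
    """Walk relative constraints to the root tensor, then apply its absolute pin."""
    total_offset = 0
    while name in rel:
        constraint = rel[name]
        total_offset += constraint.offset
        name = constraint.source
    pin = abs_pins[name]
    if pin.offset is None:
        # Only the memory id is fixed; the memory planner is free to pick the offset.
        return pin.pinned_memory_id, None
    # Derived tensors inherit the pinned offset plus their accumulated relative offset.
    return pin.pinned_memory_id, pin.offset + total_offset


if __name__ == "__main__":
    rel = {"slice_out": RelPlacement(source="cat_out", offset=128)}
    abs_pins = {"cat_out": AbsPlacement(pinned_memory_id=2, offset=0)}
    print(resolve("slice_out", rel, abs_pins))  # (2, 128)
```

Running the sketch prints `(2, 128)`: `slice_out` inherits memory id 2 from `cat_out` and lands 128 bytes past its pinned offset, which mirrors how a relative constraint on an aliased tensor is folded into its source's absolute placement in the diff above.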