Skip to content

Commit d92fa11

Browse files
committed
Add preserve_ops to EdgeCompileConfig
1. Add `preserve_ops` to `EdgeCompileConfig` 2. Remove preserved ops from the decomposition table in `to_edge`. 3. Add checks to the verifier ensuring that preserved ops do not have mutations or views. 4. Update 'core_aten_exception_list' to be 'preserved_ops' in `to_edge_transform_and_lower`. Context/Usage **core_aten_ops_exception_list** - Contains operators that are missing a decomposition to core aten. - Exclude these so that verification can still be run on the rest of the graph. - Ideally, this list should be empty. **preserve_ops** - Contains operators that the user specifically does not want decomposed. - Must be aten; custom ops are ignored by verifier. Edge case: - If an aten operator does not have a decomp, and the user specifically wants it to be preserved, put it in preserve_ops rather than core_aten_ops_exception_list. Differential Revision: [D78298749](https://our.internmc.facebook.com/intern/diff/D78298749/) [ghstack-poisoned]
1 parent 1540659 commit d92fa11

File tree

5 files changed

+112
-37
lines changed

5 files changed

+112
-37
lines changed

exir/capture/_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,14 @@ class EdgeCompileConfig:
4040
# TODO(larryliu): remove this
4141
_use_edge_ops: bool = True
4242
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
43+
# Note: only use this for core ATen ops that are missing decompositions. This is temporary,
44+
# enabling verification on the rest of the program until decomposition coverage is improved.
4345
_core_aten_ops_exception_list: List[torch._ops.OpOverload] = field(
4446
default_factory=list
4547
)
48+
# Allow ops to be preserved in the graph, i.e., prevent them from being decomposed.
49+
# These may be core or non-core ATen ops; custom ops should not be here.
50+
_preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)
4651
# TODO(gasoonjia): remove this
4752
_skip_dim_order: bool = False
4853

exir/program/_program.py

Lines changed: 39 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -795,9 +795,19 @@ def _generate_edge_program(
795795
name: str,
796796
config: EdgeCompileConfig,
797797
program: ExportedProgram,
798-
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
798+
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
799+
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
799800
) -> ExportedProgram:
800-
801+
"""
802+
Args:
803+
name: The name of the program.
804+
config: The configuration for the edge program.
805+
program: The exported program to be converted to an edge program.
806+
core_aten_ops_exception_list: A list of aten ops that are missing decompositions to core aten.
807+
preserve_ops: A list of aten ops that should not be decomposed.
808+
Returns:
809+
An ExportedProgram in edge dialect.
810+
"""
801811
# Remove invalid assert ops, such as _assert_tensor_metadata
802812
gm = program.graph_module
803813
gm_res = RemoveNonCoreAtenOpGraphAssertsPass()(gm)
@@ -812,7 +822,8 @@ def _generate_edge_program(
812822
EXIRATenDialectVerifier(
813823
edge_compile_config=config,
814824
class_only=False,
815-
exception_list=ops_set_to_not_decompose,
825+
core_aten_ops_exception_list=core_aten_ops_exception_list,
826+
preserve_ops=preserve_ops,
816827
)(gm)
817828
except ExportError as e:
818829
logging.info(f"Input program {name} is not in ATen dialect.")
@@ -848,7 +859,8 @@ def _generate_edge_program(
848859
EXIREdgeDialectVerifier(
849860
edge_compile_config=config,
850861
class_only=True,
851-
exception_list=ops_set_to_not_decompose,
862+
core_aten_ops_exception_list=core_aten_ops_exception_list,
863+
preserve_ops=preserve_ops,
852864
)
853865
],
854866
)
@@ -864,7 +876,7 @@ def _replace_aten_ops_with_transformed_ops(
864876
program: ExportedProgram,
865877
partitioner,
866878
):
867-
ops_to_not_decompose = set()
879+
preserve_ops = set()
868880
partitioners = partitioner.get(name)
869881
if partitioners is None:
870882
return
@@ -889,7 +901,7 @@ def _replace_aten_ops_with_transformed_ops(
889901
and node.target in ops_set_to_not_decompose
890902
and is_op_supported
891903
):
892-
ops_to_not_decompose.add(node.target)
904+
preserve_ops.add(node.target)
893905
node.target = aten_op_to_transform_op[node.target]
894906

895907
for _, submod, _ in get_control_flow_submodules(program.graph_module):
@@ -900,10 +912,10 @@ def _replace_aten_ops_with_transformed_ops(
900912
and node.target in ops_set_to_not_decompose
901913
and is_op_supported
902914
):
903-
ops_to_not_decompose.add(node.target)
915+
preserve_ops.add(node.target)
904916
node.target = aten_op_to_transform_op[node.target]
905917

906-
return ops_to_not_decompose
918+
return preserve_ops
907919

908920

909921
def _restore_transformed_ops_to_aten_ops(program: ExportedProgram):
@@ -1014,7 +1026,7 @@ def _sanity_check_graph_for_non_decomp_ops(
10141026

10151027

10161028
def _remove_invalid_ops_for_not_decompose(
1017-
ops_to_not_decompose: List[torch._ops.OpOverload],
1029+
preserve_ops: List[torch._ops.OpOverload],
10181030
) -> List[torch._ops.OpOverload]:
10191031
_logged_warnings = set()
10201032

@@ -1079,7 +1091,7 @@ def keep(op):
10791091
return False
10801092
return True
10811093

1082-
return list(filter(keep, ops_to_not_decompose))
1094+
return list(filter(keep, preserve_ops))
10831095

10841096

10851097
def _gen_edge_manager_for_partitioners(
@@ -1136,7 +1148,7 @@ def _gen_edge_manager_for_partitioners(
11361148
name,
11371149
config,
11381150
program,
1139-
list(ops_set_to_not_decompose_by_program.get(name, [])),
1151+
preserve_ops=list(ops_set_to_not_decompose_by_program.get(name, [])),
11401152
)
11411153

11421154
edge_manager = EdgeProgramManager(
@@ -1281,7 +1293,7 @@ def to_edge_transform_and_lower(
12811293
EXIREdgeDialectVerifier(
12821294
edge_compile_config=config,
12831295
class_only=True,
1284-
exception_list=list(ops_set_to_not_decompose),
1296+
preserve_ops=list(ops_set_to_not_decompose),
12851297
)()(program.graph_module)
12861298

12871299
return edge_manager
@@ -1328,7 +1340,7 @@ def to_edge_with_preserved_ops(
13281340
table.pop(op, None)
13291341
program = program.run_decompositions(table)
13301342
edge_programs[name] = _generate_edge_program(
1331-
name, config, program, list(preserve_ops)
1343+
name, config, program, preserve_ops=list(preserve_ops)
13321344
)
13331345

13341346
return EdgeProgramManager(
@@ -1367,8 +1379,16 @@ def to_edge(
13671379

13681380
for name, program in aten_programs.items():
13691381
# Decompose to Core ATen
1370-
program = program.run_decompositions(_default_decomposition_table())
1371-
edge_programs[name] = _generate_edge_program(name, config, program)
1382+
table = _default_decomposition_table()
1383+
preserve_ops = []
1384+
if compile_config:
1385+
preserve_ops = compile_config._preserve_ops
1386+
for op in compile_config._preserve_ops:
1387+
table.pop(op, None)
1388+
program = program.run_decompositions(table)
1389+
edge_programs[name] = _generate_edge_program(
1390+
name, config, program, preserve_ops=preserve_ops
1391+
)
13721392

13731393
return EdgeProgramManager(edge_programs, constant_methods, config)
13741394

@@ -1389,7 +1409,8 @@ def __init__(
13891409
edge_programs: Union[ExportedProgram, Dict[str, ExportedProgram]],
13901410
constant_methods: Optional[Dict[str, Any]] = None,
13911411
compile_config: Optional[EdgeCompileConfig] = None,
1392-
ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None,
1412+
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
1413+
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
13931414
):
13941415
"""
13951416
Should not be called directly by users. User should use :func:'to_edge' instead.
@@ -1404,7 +1425,8 @@ def __init__(
14041425
try:
14051426
EXIREdgeDialectVerifier(
14061427
edge_compile_config=self.compile_config,
1407-
exception_list=ops_set_to_not_decompose,
1428+
core_aten_ops_exception_list=core_aten_ops_exception_list,
1429+
preserve_ops=preserve_ops,
14081430
)(program.graph_module)
14091431
except ExportError as e:
14101432
logging.info(f"Input program {name} is not in aten dialect.")

exir/program/test/test_program.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
ExecutorchProgramManager,
2828
to_edge,
2929
to_edge_transform_and_lower,
30-
to_edge_with_preserved_ops,
3130
)
3231
from executorch.exir.tracer import _default_decomposition_table
3332
from executorch.exir.verification.verifier import EXIREdgeDialectVerifier
@@ -784,7 +783,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
784783
def _test_to_edge_with_preserved_ops(
785784
self, program, preserved_ops, expected_preserved_ops
786785
):
787-
edge = to_edge_with_preserved_ops(program, preserve_ops=preserved_ops)
786+
edge = to_edge(
787+
program, compile_config=EdgeCompileConfig(_preserve_ops=preserved_ops)
788+
)
788789

789790
def count_nodes(graph_module, target):
790791
count = 0

exir/verification/test/test_verifier.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -161,3 +161,17 @@ def forward(self, input, label):
161161
edge_verifier = EXIREdgeDialectVerifier()
162162

163163
edge_verifier(edge.exported_program())
164+
165+
def test_verifier_preserve_ops_view(self) -> None:
166+
class TestExpand(nn.Module):
167+
def __init__(self):
168+
super().__init__()
169+
170+
def forward(self, x):
171+
return x.expand(2, 2, 2, 2)
172+
173+
model = TestExpand()
174+
config = EdgeCompileConfig(_preserve_ops=[torch.ops.aten.expand.default])
175+
export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True)
176+
with self.assertRaises(RuntimeError):
177+
to_edge(export_model, compile_config=config)

exir/verification/verifier.py

Lines changed: 51 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -81,26 +81,33 @@ def __call__(self, *args, **kwargs):
8181
def EXIRATenDialectVerifier( # noqa: C901
8282
edge_compile_config: Optional[EdgeCompileConfig] = None,
8383
class_only: bool = False,
84-
exception_list: Optional[List[torch._ops.OpOverload]] = None,
84+
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
85+
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
8586
):
8687
"""
8788
Returns a verifier class that runs ATen dialect specific checks on the graph module.
8889
"""
90+
_core_aten_ops_exception_list = core_aten_ops_exception_list or []
91+
_preserve_ops = preserve_ops or []
8992
# merge the exception list from edge_compile_config and exception_list
90-
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
91-
exception_list = edge_compile_config._core_aten_ops_exception_list + (
92-
exception_list or []
93-
)
93+
if edge_compile_config:
94+
if edge_compile_config._core_aten_ops_exception_list:
95+
_core_aten_ops_exception_list.extend(
96+
edge_compile_config._core_aten_ops_exception_list
97+
)
98+
if edge_compile_config._preserve_ops:
99+
_preserve_ops.extend(edge_compile_config._preserve_ops)
94100

95101
class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
96102
dialect = "OLD_EXIR_ATEN"
97103

98104
def __init__(self) -> None:
99105
super().__init__()
100106
# Note: here we are using the exception list passed from EXIRATenDialectVerifier function!
101-
self._exception_list = exception_list if exception_list else []
107+
self._core_aten_ops_exception_list = _core_aten_ops_exception_list
108+
self._preserve_ops = _preserve_ops
102109

103-
def _get_exception_list(self) -> List[torch._ops.OpOverload]:
110+
def _get_core_aten_ops_exception_list(self) -> List[torch._ops.OpOverload]:
104111
exception_list = (
105112
[
106113
torch.ops.aten.mkldnn_rnn_layer.default,
@@ -113,15 +120,30 @@ def _get_exception_list(self) -> List[torch._ops.OpOverload]:
113120
]
114121
+ list(_EXECUTORCH_SYM_OPS)
115122
+ DISALLOW_LIST
116-
+ self._exception_list
123+
+ self._core_aten_ops_exception_list
117124
)
118125

119126
return exception_list
120127

121128
def check_valid_op(self, op):
122129
if isinstance(op, OpOverload):
123130
# TODO These special ops should be removable easily.
124-
if op.namespace != "aten" or op in self._get_exception_list():
131+
if (
132+
op.namespace != "aten"
133+
or op in self._get_core_aten_ops_exception_list()
134+
):
135+
return
136+
if op in self._preserve_ops:
137+
# Preserved ops should not include mutation or view,
138+
# which may affect memory planning.
139+
if op._schema.is_mutable or op.is_view:
140+
raise RuntimeError(
141+
f"Cannot preserve operator {op} because it is a view or mutation."
142+
)
143+
if op.namespace != "aten":
144+
raise RuntimeError(
145+
f"Only preserve aten ops. Received op {op} with namespace {op.namespace}."
146+
)
125147
return
126148
if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags:
127149
# NOTE(qihan): whether view_copy operators are marked as canonical is still under
@@ -149,7 +171,9 @@ def check_valid_op(self, op):
149171
def get_aten_verifier(config: EdgeCompileConfig):
150172
return (
151173
EXIRATenDialectVerifier(
152-
class_only=True, exception_list=config._core_aten_ops_exception_list
174+
class_only=True,
175+
core_aten_ops_exception_list=config._core_aten_ops_exception_list,
176+
preserve_ops=config._preserve_ops,
153177
)
154178
if config._check_ir_validity
155179
else EXIRATenDialectVerifierBase
@@ -210,13 +234,19 @@ def _check_tensor_args_matching_op_allowed_dtype(gm: GraphModule) -> None:
210234
def EXIREdgeDialectVerifier( # noqa: C901
211235
edge_compile_config: Optional[EdgeCompileConfig] = None,
212236
class_only: bool = False,
213-
exception_list: Optional[List[torch._ops.OpOverload]] = None,
237+
core_aten_ops_exception_list: Optional[List[torch._ops.OpOverload]] = None,
238+
preserve_ops: Optional[List[torch._ops.OpOverload]] = None,
214239
):
240+
_core_aten_ops_exception_list = core_aten_ops_exception_list or []
241+
_preserve_ops = preserve_ops or []
215242
# merge the exception list from edge_compile_config and exception_list
216-
if edge_compile_config and edge_compile_config._core_aten_ops_exception_list:
217-
exception_list = edge_compile_config._core_aten_ops_exception_list + (
218-
exception_list or []
219-
)
243+
if edge_compile_config:
244+
if edge_compile_config._core_aten_ops_exception_list:
245+
_core_aten_ops_exception_list.extend(
246+
edge_compile_config._core_aten_ops_exception_list
247+
)
248+
if edge_compile_config._preserve_ops:
249+
_preserve_ops.extend(edge_compile_config._preserve_ops)
220250

221251
class _EXIREdgeDialectVerifier(Verifier):
222252
dialect = "EDGE"
@@ -228,16 +258,19 @@ def __init__(self) -> None:
228258
self.check_edge_ops = _edge_compile_config._use_edge_ops
229259
self.use_dim_order = not _edge_compile_config._skip_dim_order
230260

261+
self._core_aten_ops_exception_list = _core_aten_ops_exception_list
262+
self._preserve_ops = _preserve_ops
263+
231264
self.aten_op_verifier = EXIRATenDialectVerifier(
232-
exception_list=exception_list
265+
core_aten_ops_exception_list=_core_aten_ops_exception_list,
266+
preserve_ops=_preserve_ops,
233267
)
234268
self.check_valid_aten_op = self.aten_op_verifier.check_valid_op
235269

236270
if self.check_edge_ops:
237271
self.check_valid_op = self.check_valid_edge_op
238272
else:
239273
self.check_valid_op = self.check_valid_aten_op
240-
self._exception_list = exception_list if exception_list else []
241274

242275
def allowed_getattr_types(self) -> Tuple[Type[Any], ...]:
243276
return (
@@ -258,7 +291,7 @@ def check_valid_edge_op(self, op):
258291
in [operator.getitem]
259292
+ DISALLOW_LIST
260293
+ list(_EXECUTORCH_SYM_OPS)
261-
+ self._exception_list
294+
+ self._core_aten_ops_exception_list
262295
):
263296
return
264297

0 commit comments

Comments
 (0)