Skip to content

Commit ff825e4

Browse files
authored
Remove usages of 'to_edge_with_preserve_ops'
Differential Revision: D78311705 Pull Request resolved: #12471
1 parent c8d898d commit ff825e4

File tree

7 files changed

+26
-26
lines changed

7 files changed

+26
-26
lines changed

backends/cadence/aot/compiler.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
)
3535
from executorch.exir.passes import ToOutVarPass
3636
from executorch.exir.passes.sym_shape_eval_pass import HintBasedSymShapeEvalPass
37-
from executorch.exir.program._program import to_edge_with_preserved_ops
37+
from executorch.exir.program._program import to_edge
3838
from torch._inductor.decomposition import remove_decompositions
3939

4040
from torch.export.exported_program import ExportedProgram
@@ -219,9 +219,9 @@ def quantize_pt2(
219219
torch.ops.aten.angle.default,
220220
torch.ops.aten.rms_norm.default,
221221
]
222-
TO_EDGE_PRESERVE_OPS: tuple[torch._ops.OpOverload, ...] = (
222+
TO_EDGE_PRESERVE_OPS: list[torch._ops.OpOverload, ...] = [
223223
torch.ops.aten.rms_norm.default,
224-
)
224+
]
225225

226226

227227
def _lower_ep_to_edge(
@@ -233,18 +233,18 @@ def _lower_ep_to_edge(
233233
"""
234234
Lower an ExportedProgram to an EdgeProgramManager (in edge IR).
235235
"""
236-
# Call to_edge_with_preserved_ops to convert the graph to edge IR.
236+
# Call to_edge to convert the graph to edge IR.
237237
# Note: dim_order is skipped (https://github.com/pytorch/executorch/issues/3704)
238-
edge_prog_manager = to_edge_with_preserved_ops(
238+
edge_prog_manager = to_edge(
239239
expo_program,
240240
compile_config=EdgeCompileConfig(
241241
_skip_dim_order=True,
242242
# Allow specific non-core aten ops in the IR.
243243
_core_aten_ops_exception_list=TO_EDGE_OP_EXCEPTION_LIST
244244
+ (core_aten_exceptions or []),
245+
preserve_ops=TO_EDGE_PRESERVE_OPS,
245246
),
246247
constant_methods=constant_methods,
247-
preserve_ops=TO_EDGE_PRESERVE_OPS,
248248
)
249249

250250
if dump_graphs:

examples/apple/coreml/llama/export.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from executorch.exir.passes import MemoryPlanningPass
2828
from executorch.exir.passes.quant_fusion_pass import QuantFusionPass
2929
from executorch.exir.passes.sym_shape_eval_pass import ConstraintBasedSymShapeEvalPass
30-
from executorch.exir.program._program import to_edge_with_preserved_ops
30+
from executorch.exir.program._program import to_edge
3131
from executorch.extension.export_util.utils import save_pte_program
3232

3333

@@ -196,17 +196,17 @@ def main() -> None:
196196
print("Exported program")
197197
print(ep)
198198

199-
edge_manager = to_edge_with_preserved_ops(
199+
edge_manager = to_edge(
200200
ep,
201-
preserve_ops=[
202-
torch.ops.aten.scaled_dot_product_attention.default,
203-
# preserve norm op for numerical stability
204-
torch.ops.aten.linalg_vector_norm.default,
205-
torch.ops.aten.reciprocal.default,
206-
],
207201
compile_config=EdgeCompileConfig(
208202
_check_ir_validity=False,
209203
_skip_dim_order=True,
204+
preserve_ops=[
205+
torch.ops.aten.scaled_dot_product_attention.default,
206+
# preserve norm op for numerical stability
207+
torch.ops.aten.linalg_vector_norm.default,
208+
torch.ops.aten.reciprocal.default,
209+
],
210210
),
211211
)
212212
print("Edge program")

exir/capture/_config.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ class EdgeCompileConfig:
3939
_check_ir_validity: bool = True
4040
# TODO(larryliu): remove this
4141
_use_edge_ops: bool = True
42+
# TODO(gasoonjia): remove this
43+
_skip_dim_order: bool = False
4244
# Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks.
4345
# Note: only use this for core ATen ops that are missing decompositions. This is temporary,
4446
# enabling verification on the rest of the program until decomposition coverage is improved.
@@ -47,9 +49,7 @@ class EdgeCompileConfig:
4749
)
4850
# Allow ops to be preserved in the graph, i.e., prevent them from being decomposed.
4951
# These may be core or non-core ATen ops; custom ops should not be here.
50-
_preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)
51-
# TODO(gasoonjia): remove this
52-
_skip_dim_order: bool = False
52+
preserve_ops: List[torch.torch._ops.OpOverload] = field(default_factory=list)
5353

5454

5555
@compatibility(is_backward_compatible=False)

exir/program/_program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1382,8 +1382,8 @@ def to_edge(
13821382
table = _default_decomposition_table()
13831383
preserve_ops = []
13841384
if compile_config:
1385-
preserve_ops = compile_config._preserve_ops
1386-
for op in compile_config._preserve_ops:
1385+
preserve_ops = compile_config.preserve_ops
1386+
for op in compile_config.preserve_ops:
13871387
table.pop(op, None)
13881388
program = program.run_decompositions(table)
13891389
edge_programs[name] = _generate_edge_program(

exir/program/test/test_program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -784,7 +784,7 @@ def _test_to_edge_with_preserved_ops(
784784
self, program, preserved_ops, expected_preserved_ops
785785
):
786786
edge = to_edge(
787-
program, compile_config=EdgeCompileConfig(_preserve_ops=preserved_ops)
787+
program, compile_config=EdgeCompileConfig(preserve_ops=preserved_ops)
788788
)
789789

790790
def count_nodes(graph_module, target):

exir/verification/test/test_verifier.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def forward(self, x):
171171
return x.expand(2, 2, 2, 2)
172172

173173
model = TestExpand()
174-
config = EdgeCompileConfig(_preserve_ops=[torch.ops.aten.expand.default])
174+
config = EdgeCompileConfig(preserve_ops=[torch.ops.aten.expand.default])
175175
export_model = export(model, (torch.randn(2, 2, 2, 2),), strict=True)
176176
with self.assertRaises(RuntimeError):
177177
to_edge(export_model, compile_config=config)

exir/verification/verifier.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ def EXIRATenDialectVerifier( # noqa: C901
9898
_core_aten_ops_exception_list.extend(
9999
edge_compile_config._core_aten_ops_exception_list
100100
)
101-
if edge_compile_config._preserve_ops:
102-
_preserve_ops.extend(edge_compile_config._preserve_ops)
101+
if edge_compile_config.preserve_ops:
102+
_preserve_ops.extend(edge_compile_config.preserve_ops)
103103

104104
class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase):
105105
dialect = "OLD_EXIR_ATEN"
@@ -181,7 +181,7 @@ def get_aten_verifier(config: EdgeCompileConfig):
181181
EXIRATenDialectVerifier(
182182
class_only=True,
183183
core_aten_ops_exception_list=config._core_aten_ops_exception_list,
184-
preserve_ops=config._preserve_ops,
184+
preserve_ops=config.preserve_ops,
185185
)
186186
if config._check_ir_validity
187187
else EXIRATenDialectVerifierBase
@@ -253,8 +253,8 @@ def EXIREdgeDialectVerifier( # noqa: C901
253253
_core_aten_ops_exception_list.extend(
254254
edge_compile_config._core_aten_ops_exception_list
255255
)
256-
if edge_compile_config._preserve_ops:
257-
_preserve_ops.extend(edge_compile_config._preserve_ops)
256+
if edge_compile_config.preserve_ops:
257+
_preserve_ops.extend(edge_compile_config.preserve_ops)
258258

259259
class _EXIREdgeDialectVerifier(Verifier):
260260
dialect = "EDGE"

0 commit comments

Comments
 (0)