Skip to content

Commit bb7932f

Browse files
committed
[Task] : Specify legal backend ops for tosa/linalg pipeline to avoid decomposing those ops.
1 parent 16b7e33 commit bb7932f

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

python/torch_mlir/fx_mw.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
from . import fx
1515
from torch._decomp import get_decompositions
1616

17-
1817
def import_exported_model(
1918
prog: torch.export.ExportedProgram,
2019
output_type: str,
@@ -24,18 +23,44 @@ def import_exported_model(
2423
decomp_table = get_decompositions(
2524
[torch.ops.aten.lstm.input, torch.ops.aten.gru.input]
2625
)
27-
2826
prog = prog.run_decompositions(decomp_table)
2927

28+
backend_legal_ops = None
29+
3030
match output_type:
3131
case "torch":
3232
output_type = OutputType.TORCH
3333
case "tosa":
3434
output_type = OutputType.TOSA
35+
backend_legal_ops = [
36+
"aten.flatten.using_ints",
37+
"aten.native_layer_norm",
38+
"aten.adaptive_avg_pool1d",
39+
"aten.adaptive_avg_pool2d",
40+
"aten.adaptive_max_pool1d",
41+
"aten.adaptive_max_pool2d",
42+
"aten.linear"]
3543
case "linalg_on_tensors":
3644
output_type = OutputType.LINALG_ON_TENSORS
45+
backend_legal_ops = [
46+
"aten.flatten.using_ints",
47+
"aten.adaptive_avg_pool1d",
48+
"aten.adaptive_avg_pool2d",
49+
"aten.adaptive_max_pool1d",
50+
"aten.adaptive_max_pool2d",
51+
"aten.unflatten.int",
52+
]
3753
case "tosa_linalg":
3854
output_type = OutputType.TOSA_LINALG
55+
backend_legal_ops = [
56+
"aten.flatten.using_ints",
57+
"aten.native_layer_norm",
58+
"aten.adaptive_avg_pool1d",
59+
"aten.adaptive_avg_pool2d",
60+
"aten.adaptive_max_pool1d",
61+
"aten.adaptive_max_pool2d",
62+
"aten.linear",
63+
"aten.unflatten.int"]
3964
case "raw":
4065
output_type = OutputType.RAW
4166
case _:
@@ -49,6 +74,12 @@ def import_exported_model(
4974

5075
if output_type != OutputType.RAW:
5176
backend_legal_op_arg_str = ""
77+
if backend_legal_ops is not None:
78+
if not len(backend_legal_ops) == 0:
79+
backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
80+
backend_legal_ops
81+
)
82+
5283
extra_library_file_name = ""
5384
option_string = (
5485
"{"

0 commit comments

Comments
 (0)