Commit 1dc89f4 (parent: ad83c6d)

[Task] : Add a separate pass pipeline to not dump mlir files in case of failure.

3 files changed: 149 additions, 0 deletions

python/CMakeLists.txt (2 additions & 0 deletions)

@@ -46,6 +46,8 @@ declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI
     compiler_utils.py
     fx.py
     extras/fx_decomp_util.py
+    compiler_utils_mw.py
+    fx_mw.py
 )

 declare_mlir_python_sources(TorchMLIRPythonSources.Tools

python/torch_mlir/compiler_utils_mw.py (new file, 78 additions & 0 deletions)

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.
from enum import Enum
from io import StringIO
import os
import sys
import tempfile
from typing import Union, List

import torch
from .passmanager import PassManager
from .ir import StringAttr

from torch_mlir.compiler_utils import OutputType


def run_pipeline_mw(
    module, pipeline: str, description: str, enable_ir_printing: bool = False
):
    """Runs `pipeline` on `module`"""
    with module.context as ctx:
        # TODO(#3506): Passes can emit errors but not signal failure,
        # which causes a native assert.
        ctx.emit_error_diagnostics = False
        pm = PassManager.parse(pipeline)
        if enable_ir_printing:
            ctx.enable_multithreading(False)
            pm.enable_ir_printing()
        pm.run(module.operation)


def lower_mlir_module_mw(verbose, output_type, module):
    if verbose:
        print("\n====================")
        print("Torch Backend IR")
        print(module)

    if output_type == OutputType.TORCH:
        return module

    if output_type == OutputType.TOSA:
        run_pipeline_mw(
            module,
            "builtin.module(torch-backend-to-tosa-backend-pipeline)",
            "Lowering Torch Backend IR -> TOSA Backend IR",
        )
        if verbose:
            print("\n====================")
            print("TOSA Backend IR")
            print(module)
        return module

    if output_type == OutputType.LINALG_ON_TENSORS:
        run_pipeline_mw(
            module,
            "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)",
            "Lowering Torch Backend IR -> Linalg-on-Tensors Backend IR",
        )
        if verbose:
            print("\n====================")
            print("LINALG Backend IR")
            print(module)
        return module

    elif output_type == OutputType.TOSA_LINALG:
        run_pipeline_mw(
            module,
            "builtin.module(torch-backend-to-tosa-linalg-backend-pipeline)",
            "Lowering Torch Backend IR -> TOSA_LINALG Backend IR",
        )
        if verbose:
            print("\n====================")
            print("TOSA_LINALG Backend IR")
            print(module)
        return module
    raise Exception(f"Unknown OutputType: {output_type}")
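
A minimal usage sketch of the new helpers (not part of the commit). It assumes a torch-mlir Python build that ships compiler_utils_mw.py; the parsed module and the canonicalize pipeline are illustrative only. Unlike run_pipeline_with_repro_report in compiler_utils.py, which writes the failing module to a temporary .mlir file as part of its repro report, a failure in run_pipeline_mw simply propagates as a Python exception.

# Minimal usage sketch (hypothetical, not part of this commit). Assumes a
# torch-mlir Python build that includes compiler_utils_mw.py; the parsed
# module and the canonicalize pipeline are illustrative only.
from torch_mlir import ir
from torch_mlir.compiler_utils_mw import run_pipeline_mw

with ir.Context():
    module = ir.Module.parse("module {}")

# A pass failure raises a normal Python exception here; no reproducer
# .mlir file is written, unlike run_pipeline_with_repro_report.
run_pipeline_mw(
    module,
    "builtin.module(canonicalize)",
    "Canonicalize (illustrative)",
)
print(module)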

python/torch_mlir/fx_mw.py (new file, 69 additions & 0 deletions)

# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import torch
from .compiler_utils import OutputType

from .compiler_utils_mw import (
    run_pipeline_mw,
    lower_mlir_module_mw,
)

from . import fx
from torch._decomp import get_decompositions


def import_exported_model(
    prog: torch.export.ExportedProgram,
    output_type: str,
    experimental_support_mutation: bool = True,
):

    decomp_table = get_decompositions(
        [torch.ops.aten.lstm.input, torch.ops.aten.gru.input]
    )

    prog = prog.run_decompositions(decomp_table)

    match output_type:
        case "torch":
            output_type = OutputType.TORCH
        case "tosa":
            output_type = OutputType.TOSA
        case "linalg_on_tensors":
            output_type = OutputType.LINALG_ON_TENSORS
        case "tosa_linalg":
            output_type = OutputType.TOSA_LINALG
        case "raw":
            output_type = OutputType.RAW
        case _:
            raise ValueError("Importing PyTorch model failed: Unsupported output type.")

    mlir_module = fx.export_and_import(
        prog,
        output_type=OutputType.RAW,
        experimental_support_mutation=experimental_support_mutation,
    )

    if output_type != OutputType.RAW:
        backend_legal_op_arg_str = ""
        extra_library_file_name = ""
        option_string = (
            "{"
            + backend_legal_op_arg_str
            + " extra-library="
            + extra_library_file_name
            + "}"
        )
        run_pipeline_mw(
            mlir_module,
            f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
            "Lowering TorchFX IR -> Torch Backend IR",
            enable_ir_printing=False,
        )
        verbose = False
        mlir_module = lower_mlir_module_mw(verbose, output_type, mlir_module)

    return mlir_module
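
For orientation, a hypothetical end-to-end driver for the new entry point (not part of the commit); the toy model, shapes, and chosen output type are illustrative assumptions.

# Hypothetical driver for import_exported_model (illustrative only).
import torch
from torch_mlir import fx_mw


class AddMul(torch.nn.Module):
    def forward(self, x, y):
        return x * y + y


prog = torch.export.export(AddMul(), (torch.randn(3, 4), torch.randn(3, 4)))

# Accepted output_type strings: "torch", "tosa", "linalg_on_tensors",
# "tosa_linalg", "raw".
module = fx_mw.import_exported_model(prog, output_type="linalg_on_tensors")
print(module)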
