Commit 0dc37ee

[Task] : Fix LIT test + disable pass failure for unlowered torch ops.

1 parent ae0e2bf

9 files changed: +199 −43 lines

include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h

Lines changed: 3 additions & 0 deletions
@@ -26,6 +26,9 @@ void getBackendTypeConversionDependentDialects(DialectRegistry &registry);
 void setupBackendTypeConversion(ConversionTarget &target,
                                 TypeConverter &typeConverter);
 
+void setupBackendTypeConversionForTosaLinalg(ConversionTarget &target,
+                                             TypeConverter &typeConverter);
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 void setupBackendTypeConversionForStablehlo(ConversionTarget &target,
                                             TypeConverter &typeConverter);

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,12 @@ createVerifyLinalgOnTensorsBackendContractPass();
 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyTosaLinalgBackendContractPass();
 
+std::unique_ptr<OperationPass<ModuleOp>>
+createFuncBackendTypeConversionForTosaLinalgPass();
+
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createFinalizingBackendTypeConversionForTosaLinalgPass();
+
 } // namespace TorchConversion
 
 /// Registers all Torch transformation passes.

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td

Lines changed: 20 additions & 0 deletions
@@ -73,6 +73,26 @@ def VerifyTosaLinalgBackendContract : Pass<"torch-verify-tosa-linalg-backend-con
   let constructor = "mlir::torch::TorchConversion::createVerifyTosaLinalgBackendContractPass()";
 }
 
+def FinalizingBackendTypeConversionForTosaLinalg
+    : InterfacePass<"torch-finalizing-backend-type-conversion-for-tosa-linalg", "mlir::FunctionOpInterface"> {
+  let summary = "Finalizes a partial conversion to builtin tensors for tosa+linalg";
+  let constructor =
+      "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForTosaLinalgPass()";
+  let description = [{
+    Analogous in scope to the upstream `finalizing-bufferize` pass.
+    See details there.
+  }];
+}
+
+def FuncBackendTypeConversionForTosaLinalg : Pass<"torch-func-backend-type-conversion-for-tosa-linalg", "ModuleOp"> {
+  let summary = "Convert functions to operate on builtin tensors for tosa+linalg backend";
+  let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForTosaLinalgPass()";
+  let description = [{
+    Partial type conversion pass analogous in scope to the upstream
+    `func-bufferize` pass. See details there.
+  }];
+}
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
   let summary = "Verifies conformity to the stablehlo backend contract";
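The two definitions above register the pipeline names torch-func-backend-type-conversion-for-tosa-linalg and torch-finalizing-backend-type-conversion-for-tosa-linalg. A minimal sketch of driving them by those names through the Python bindings (the input IR is illustrative, and this assumes a torch-mlir build with Python bindings enabled):

import torch_mlir
from torch_mlir.passmanager import PassManager

SRC = """
func.func @identity(%arg0: !torch.vtensor<[2],si32>) -> !torch.vtensor<[2],si32> {
  return %arg0 : !torch.vtensor<[2],si32>
}
"""

with torch_mlir.ir.Context() as ctx:
    torch_mlir.dialects.torch.register_dialect(ctx)
    module = torch_mlir.ir.Module.parse(SRC)
    # Convert function signatures first, then finalize the remaining
    # materializations, mirroring the pipeline in Passes.cpp below.
    PassManager.parse(
        "builtin.module(torch-func-backend-type-conversion-for-tosa-linalg,"
        "func.func(torch-finalizing-backend-type-conversion-for-tosa-linalg))"
    ).run(module.operation)
    print(module)  # the signature should now use signless tensor<2xi32>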

lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp

Lines changed: 2 additions & 1 deletion
@@ -65,7 +65,8 @@ class ConvertTorchToTosaLinalg
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
-    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+    TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
+                                                             typeConverter);
 
     RewritePatternSet patterns(context);
lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp

Lines changed: 27 additions & 0 deletions
@@ -185,6 +185,33 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion(
   setupTorchGeneratorToI64Conversion(target, typeConverter);
 }
 
+void mlir::torch::TorchConversion::setupBackendTypeConversionForTosaLinalg(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  auto valueTensorTypeConversion =
+      [](Torch::ValueTensorType type) -> std::optional<Type> {
+    auto builtinType = type.toBuiltinTensor();
+    if (!builtinType)
+      return std::nullopt;
+
+    // convert signed integer type to signless, keep unsigned as unsigned
+    if (type.getDtype().isUnsignedInteger()) {
+      return builtinType.clone(type.getDtype());
+    } else if (type.getDtype().isSignedInteger()) {
+      return builtinType.clone(IntegerType::get(
+          builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
+          IntegerType::Signless));
+    }
+
+    return builtinType;
+  };
+  setupValueTensorToBuiltinTensorConversion(target, typeConverter,
+                                            valueTensorTypeConversion);
+  setupTorchBoolToI1Conversion(target, typeConverter);
+  setupTorchIntToI64Conversion(target, typeConverter);
+  setupTorchFloatToF64Conversion(target, typeConverter);
+  setupTorchGeneratorToI64Conversion(target, typeConverter);
+}
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
     ConversionTarget &target, TypeConverter &typeConverter) {
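Relative to the generic setupBackendTypeConversion above it, the only difference is the value-tensor dtype rule: signed integer element types become signless, unsigned ones are preserved, and all other dtypes pass through unchanged. A pure-Python model of that rule using MLIR's textual dtype spellings (the helper name and string encoding are illustrative only):

# si32/ui8/i32 are MLIR's spellings for signed, unsigned, and signless ints.
def convert_dtype(dtype: str) -> str:
    if dtype.startswith("ui"):   # unsigned: kept as-is
        return dtype
    if dtype.startswith("si"):   # signed: rewritten to signless iN
        return "i" + dtype[2:]
    return dtype                 # floats and signless ints: unchanged

assert convert_dtype("si32") == "i32"  # !torch.vtensor<[2],si32> -> tensor<2xi32>
assert convert_dtype("ui8") == "ui8"   # !torch.vtensor<[2],ui8>  -> tensor<2xui8>
assert convert_dtype("f32") == "f32"   # floats unchanged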

lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp

Lines changed: 87 additions & 0 deletions
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

@@ -97,6 +98,33 @@ struct FuncBackendTypeConversionPass
   }
 };
 
+struct FuncBackendTypeConversionForTosaLinalgPass
+    : public FuncBackendTypeConversionForTosaLinalgBase<
+          FuncBackendTypeConversionForTosaLinalgPass> {
+  using FuncBackendTypeConversionForTosaLinalgBase<
+      FuncBackendTypeConversionForTosaLinalgPass>::
+      FuncBackendTypeConversionForTosaLinalgBase;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<TorchConversion::TorchConversionDialect>();
+  }
+  void runOnOperation() override {
+    auto module = getOperation();
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
+                                                             typeConverter);
+
+    populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);
+
+    if (failed(applyFullConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 struct FuncBackendTypeConversionForStablehloPass
     : public FuncBackendTypeConversionForStablehloBase<

@@ -132,6 +160,11 @@ mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() {
   return std::make_unique<FuncBackendTypeConversionPass>();
 }
 
+std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
+    createFuncBackendTypeConversionForTosaLinalgPass() {
+  return std::make_unique<FuncBackendTypeConversionForTosaLinalgPass>();
+}
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
     createFuncBackendTypeConversionForStablehloPass() {

@@ -240,6 +273,55 @@ struct FinalizingBackendTypeConversionPass
   }
 };
 
+struct FinalizingBackendTypeConversionForTosaLinalgPass
+    : public FinalizingBackendTypeConversionForTosaLinalgBase<
+          FinalizingBackendTypeConversionForTosaLinalgPass> {
+  using FinalizingBackendTypeConversionForTosaLinalgBase<
+      FinalizingBackendTypeConversionForTosaLinalgPass>::
+      FinalizingBackendTypeConversionForTosaLinalgBase;
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
+                                                             typeConverter);
+
+    // Mark materializations as illegal in this pass (since we are finalizing)
+    // and add patterns that eliminate them.
+    setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
+                      FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
+                      GeneratorToI64Op>(target, patterns, typeConverter);
+
+    // If all result types are legal, and all block arguments are legal, then
+    // all types in the program are legal.
+    //
+    // We also check that the operand types are legal to avoid creating invalid
+    // IR. For example, this prevents the patterns from updating
+    // the types of the operands to a return op without updating the enclosing
+    // function.
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) { return typeConverter.isLegal(op); });
+
+    target.addLegalDialect<Torch::TorchDialect>();
+    if (failed(applyFullConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+
+    RewritePatternSet greedyPatterns(context);
+    greedyPatterns.insert<ExtFTruncFPattern>(context);
+    if (failed(applyPatternsGreedily(func, std::move(greedyPatterns))))
+      signalPassFailure();
+
+    // Drop attributes that are no longer used after conversion out of Torch.
+    stripTorchAttrs(func);
+  }
+};
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 struct FinalizingBackendTypeConversionForStablehloPass
     : public FinalizingBackendTypeConversionForStablehloBase<

@@ -291,6 +373,11 @@ mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
   return std::make_unique<FinalizingBackendTypeConversionPass>();
 }
 
+std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
+    TorchConversion::createFinalizingBackendTypeConversionForTosaLinalgPass() {
+  return std::make_unique<FinalizingBackendTypeConversionForTosaLinalgPass>();
+}
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
     TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() {
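The target.addLegalDialect<Torch::TorchDialect>() call in the finalizing pass is what implements the "disable pass failure for unlowered torch ops" half of this commit: in a ConversionTarget, explicit dialect legality takes precedence over the markUnknownOpDynamicallyLegal fallback, so a torch op that never lowered is simply accepted instead of making applyFullConversion signal failure. A toy Python model of that precedence (the names are illustrative, not the MLIR API):

def op_is_legal(dialect: str, types_are_legal: bool) -> bool:
    explicitly_legal = {"torch"}      # target.addLegalDialect<TorchDialect>()
    if dialect in explicitly_legal:   # explicit legality is consulted first
        return True
    return types_are_legal            # markUnknownOpDynamicallyLegal fallback

assert op_is_legal("torch", types_are_legal=False)     # unlowered op survives
assert not op_is_legal("tosa", types_are_legal=False)  # others must convert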

lib/Dialect/TorchConversion/Transforms/Passes.cpp

Lines changed: 4 additions & 7 deletions
@@ -199,15 +199,12 @@ void TorchConversion::createTorchBackendToTosaLinalgBackendPipeline(
 
   // Finish the type conversion from `torch` types to the types of the
   // TOSA backend contract.
-  pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
+  pm.addPass(
+      TorchConversion::createFuncBackendTypeConversionForTosaLinalgPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(
-      TorchConversion::createFinalizingBackendTypeConversionPass());
-
-  // Verify that we have lowered to ops that are supported by union of TOSA and
-  // Linalg-on-tensors backend. This fails compilation (signalPassFailure) if
-  // the IR is not in the correct form.
-  pm.addPass(TorchConversion::createVerifyTosaLinalgBackendContractPass());
+      TorchConversion::
+          createFinalizingBackendTypeConversionForTosaLinalgPass());
 }
 #endif
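Note that the pipeline no longer ends with createVerifyTosaLinalgBackendContractPass(), so IR containing unlowered torch ops now flows out of the pipeline instead of aborting compilation; the verifier pass still exists and can be scheduled separately by callers that want the old strictness. A sketch of the updated pipeline tail in textual form (pass names come from Passes.td above; the nesting mirrors the C++):

# Equivalent textual pipeline for the updated tail (sketch).
PIPELINE_TAIL = (
    "builtin.module("
    "torch-func-backend-type-conversion-for-tosa-linalg,"
    "func.func(canonicalize),"
    "func.func(torch-finalizing-backend-type-conversion-for-tosa-linalg))"
)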

python/torch_mlir/fx_mw.py

Lines changed: 45 additions & 30 deletions
@@ -4,6 +4,7 @@
 # Also available under a BSD-style license. See LICENSE.
 
 import torch
+import torch_mlir
 from .compiler_utils import OutputType
 
 from .compiler_utils_mw import (

@@ -25,6 +26,28 @@ def import_exported_model(
     )
     prog = prog.run_decompositions(decomp_table)
 
+    mlir_module = fx.export_and_import(
+        prog,
+        output_type=OutputType.RAW,
+        experimental_support_mutation=experimental_support_mutation,
+    )
+
+    if output_type != 'raw':
+        mlir_module = lower_module(mlir_module, output_type)
+
+    return mlir_module
+
+def lower_module_from_file(mlir_file: str, output_type: str):
+    src = open(mlir_file, "r").read()
+    with torch_mlir.ir.Context() as ctx:
+        torch_mlir.dialects.torch.register_dialect(ctx)
+        with torch_mlir.ir.Location.unknown() as loc:
+            mlir_module = torch_mlir.ir.Module.parse(src)
+
+    return lower_module(mlir_module, output_type)
+
+def lower_module(mlir_module, output_type: str):
+
     backend_legal_ops = None
 
     match output_type:

@@ -65,36 +88,28 @@ def import_exported_model(
             output_type = OutputType.RAW
         case _:
             raise ValueError("Importing PyTorch model failed: Unsupported output type.")
+
+    backend_legal_op_arg_str = ""
+    if backend_legal_ops is not None:
+        if not len(backend_legal_ops) == 0:
+            backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
+                backend_legal_ops
+            )
 
-    mlir_module = fx.export_and_import(
-        prog,
-        output_type=OutputType.RAW,
-        experimental_support_mutation=experimental_support_mutation,
+    extra_library_file_name = ""
+    option_string = (
+        "{"
+        + backend_legal_op_arg_str
+        + " extra-library="
+        + extra_library_file_name
+        + "}"
+    )
+    run_pipeline_mw(
+        mlir_module,
+        f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
+        "Lowering TorchFX IR -> Torch Backend IR",
+        enable_ir_printing=False,
     )
 
-    if output_type != OutputType.RAW:
-        backend_legal_op_arg_str = ""
-        if backend_legal_ops is not None:
-            if not len(backend_legal_ops) == 0:
-                backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
-                    backend_legal_ops
-                )
-
-        extra_library_file_name = ""
-        option_string = (
-            "{"
-            + backend_legal_op_arg_str
-            + " extra-library="
-            + extra_library_file_name
-            + "}"
-        )
-        run_pipeline_mw(
-            mlir_module,
-            f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
-            "Lowering TorchFX IR -> Torch Backend IR",
-            enable_ir_printing=False,
-        )
-        verbose = False
-        mlir_module = lower_mlir_module_mw(verbose, output_type, mlir_module)
-
-    return mlir_module
+    verbose = False
+    return lower_mlir_module_mw(verbose, output_type, mlir_module)
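Splitting the lowering out of import_exported_model means already-serialized Torch-dialect IR can now be lowered without re-exporting a model. A hypothetical usage of the new helper (the path is a placeholder, and "tosa_linalg" stands in for whichever output-type strings the match statement accepts; only 'raw' is visible in this hunk):

from torch_mlir.fx_mw import lower_module_from_file

# Path and output-type string below are illustrative assumptions.
lowered = lower_module_from_file("model.torch.mlir", "tosa_linalg")
print(lowered)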

test/Conversion/TorchToTosa/basic.mlir

Lines changed: 5 additions & 5 deletions
@@ -2149,7 +2149,7 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
 // CHECK: %[[VAL_10:.*]] = tosa.greater %[[VAL_5]], %[[VAL_9]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi1>
 // CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
 // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_11]] : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
-// CHECK: %[[VAL_13:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_12]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
+// CHECK: %[[VAL_13:.*]] = tosa.intdiv %[[VAL_5]], %[[VAL_12]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
 // CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_13]], %[[VAL_14]] : (tensor<2xi32>, tensor<2xi32>, tensor<1xi8>) -> tensor<2xi32>
 // CHECK: %[[VAL_16:.*]] = tosa.sub %[[VAL_5]], %[[VAL_15]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>

@@ -3920,11 +3920,11 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
 // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
 // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
 // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
-// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
-// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
 // CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
