From 0dc37ee2f0a35b4e7c081cf586a4d1dc677c57b0 Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Mon, 5 May 2025 08:57:25 -0400
Subject: [PATCH 1/2] [Task] : Fix LIT test + disable pass failure for unlowered torch ops.

---
 .../Transforms/BackendTypeConversion.h        |  3 +
 .../TorchConversion/Transforms/Passes.h       |  6 ++
 .../TorchConversion/Transforms/Passes.td      | 20 +++++
 .../TorchToTosaLinalg/TorchToTosaLinalg.cpp   |  3 +-
 .../Transforms/BackendTypeConversion.cpp      | 27 ++++++
 .../BackendTypeConversionPasses.cpp           | 87 +++++++++++++++++++
 .../TorchConversion/Transforms/Passes.cpp     | 11 +--
 python/torch_mlir/fx_mw.py                    | 75 +++++++++-------
 test/Conversion/TorchToTosa/basic.mlir        | 10 +--
 9 files changed, 199 insertions(+), 43 deletions(-)

diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h
index b0a085eab7f0..882297504eb9 100644
--- a/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h
+++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h
@@ -26,6 +26,9 @@ void getBackendTypeConversionDependentDialects(DialectRegistry &registry);
 void setupBackendTypeConversion(ConversionTarget &target,
                                 TypeConverter &typeConverter);
 
+void setupBackendTypeConversionForTosaLinalg(ConversionTarget &target,
+                                             TypeConverter &typeConverter);
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 void setupBackendTypeConversionForStablehlo(ConversionTarget &target,
                                             TypeConverter &typeConverter);
diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
index 675f2e7ae948..55579746f269 100644
--- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
+++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h
@@ -99,6 +99,12 @@ createVerifyLinalgOnTensorsBackendContractPass();
 std::unique_ptr<OperationPass<ModuleOp>>
 createVerifyTosaLinalgBackendContractPass();
 
+std::unique_ptr<OperationPass<ModuleOp>>
+createFuncBackendTypeConversionForTosaLinalgPass();
+
+std::unique_ptr<InterfacePass<FunctionOpInterface>>
+createFinalizingBackendTypeConversionForTosaLinalgPass();
+
 } // namespace TorchConversion
 
 /// Registers all Torch transformation passes.
diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
index 0719f12e8566..72157788b759 100644
--- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
+++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
@@ -73,6 +73,26 @@ def VerifyTosaLinalgBackendContract : Pass<"torch-verify-tosa-linalg-backend-con
   let constructor = "mlir::torch::TorchConversion::createVerifyTosaLinalgBackendContractPass()";
 }
 
+def FinalizingBackendTypeConversionForTosaLinalg
+    : InterfacePass<"torch-finalizing-backend-type-conversion-for-tosa-linalg", "mlir::FunctionOpInterface"> {
+  let summary = "Finalizes a partial conversion to builtin tensors for tosa+linalg";
+  let constructor =
+      "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForTosaLinalgPass()";
+  let description = [{
+    Analogous in scope to the upstream `finalizing-bufferize` pass.
+    See details there.
+  }];
+}
+
+def FuncBackendTypeConversionForTosaLinalg : Pass<"torch-func-backend-type-conversion-for-tosa-linalg", "ModuleOp"> {
+  let summary = "Convert functions to operate on builtin tensors for tosa+linalg backend";
+  let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForTosaLinalgPass()";
+  let description = [{
+    Partial type conversion pass analogous in scope to the upstream
+    `func-bufferize` pass. See details there.
+  }];
+}
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
   let summary = "Verifies conformity to the stablehlo backend contract";
diff --git a/lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp b/lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp
index 3609d158d603..2bbe0caaf7d3 100644
--- a/lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp
+++ b/lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp
@@ -65,7 +65,8 @@ class ConvertTorchToTosaLinalg
 
     TypeConverter typeConverter;
     typeConverter.addConversion([](Type type) { return type; });
-    TorchConversion::setupBackendTypeConversion(target, typeConverter);
+    TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
+                                                             typeConverter);
 
     RewritePatternSet patterns(context);
diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
index 53de48f21934..23a2f00e115e 100644
--- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -185,6 +185,33 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion(
   setupTorchGeneratorToI64Conversion(target, typeConverter);
 }
 
+void mlir::torch::TorchConversion::setupBackendTypeConversionForTosaLinalg(
+    ConversionTarget &target, TypeConverter &typeConverter) {
+  auto valueTensorTypeConversion =
+      [](Torch::ValueTensorType type) -> std::optional<Type> {
+    auto builtinType = type.toBuiltinTensor();
+    if (!builtinType)
+      return std::nullopt;
+
+    // convert signed integer type to signless, keep unsigned as unsigned
+    if (type.getDtype().isUnsignedInteger()) {
+      return builtinType.clone(type.getDtype());
+    } else if (type.getDtype().isSignedInteger()) {
+      return builtinType.clone(IntegerType::get(
+          builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
+          IntegerType::Signless));
+    }
+
+    return builtinType;
+  };
+  setupValueTensorToBuiltinTensorConversion(target, typeConverter,
+                                            valueTensorTypeConversion);
+  setupTorchBoolToI1Conversion(target, typeConverter);
+  setupTorchIntToI64Conversion(target, typeConverter);
+  setupTorchFloatToF64Conversion(target, typeConverter);
+  setupTorchGeneratorToI64Conversion(target, typeConverter);
+}
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
     ConversionTarget &target, TypeConverter &typeConverter) {
diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp
index dadd865a54a7..6a7b2b2a7b9f 100644
--- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp
@@ -15,6 +15,7 @@
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
@@ -97,6 +98,33 @@ struct FuncBackendTypeConversionPass
   }
 };
 
+struct FuncBackendTypeConversionForTosaLinalgPass
+    : public FuncBackendTypeConversionForTosaLinalgBase<
+          FuncBackendTypeConversionForTosaLinalgPass> {
+  using FuncBackendTypeConversionForTosaLinalgBase<
+      FuncBackendTypeConversionForTosaLinalgPass>::
+      FuncBackendTypeConversionForTosaLinalgBase;
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert();
+  }
+  void runOnOperation() override {
+    auto module = getOperation();
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
+                                                             typeConverter);
+
+    populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);
+
+    if (failed(applyFullConversion(module, target, std::move(patterns))))
+      signalPassFailure();
+  }
+};
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 struct FuncBackendTypeConversionForStablehloPass
     : public FuncBackendTypeConversionForStablehloBase<
@@ -132,6 +160,11 @@ mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() {
   return std::make_unique<FuncBackendTypeConversionPass>();
 }
 
+std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
+    createFuncBackendTypeConversionForTosaLinalgPass() {
+  return std::make_unique<FuncBackendTypeConversionForTosaLinalgPass>();
+}
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
     createFuncBackendTypeConversionForStablehloPass() {
@@ -240,6 +273,55 @@ struct FinalizingBackendTypeConversionPass
   }
 };
 
+struct FinalizingBackendTypeConversionForTosaLinalgPass
+    : public FinalizingBackendTypeConversionForTosaLinalgBase<
+          FinalizingBackendTypeConversionForTosaLinalgPass> {
+  using FinalizingBackendTypeConversionForTosaLinalgBase<
+      FinalizingBackendTypeConversionForTosaLinalgPass>::
+      FinalizingBackendTypeConversionForTosaLinalgBase;
+
+  void runOnOperation() override {
+    auto func = getOperation();
+    auto *context = &getContext();
+
+    TypeConverter typeConverter;
+    RewritePatternSet patterns(context);
+    ConversionTarget target(*context);
+
+    typeConverter.addConversion([](Type type) { return type; });
+    TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
+                                                             typeConverter);
+
+    // Mark materializations as illegal in this pass (since we are finalizing)
+    // and add patterns that eliminate them.
+    setupFinalization(target, patterns, typeConverter);
+
+    // If all result types are legal, and all block arguments are legal, then
+    // all types in the program are legal.
+    //
+    // We also check that the operand types are legal to avoid creating invalid
+    // IR. For example, this prevents the patterns from updating
+    // the types of the operands to a return op without updating the enclosing
+    // function.
+    target.markUnknownOpDynamicallyLegal(
+        [&](Operation *op) { return typeConverter.isLegal(op); });
+
+    target.addLegalDialect();
+    if (failed(applyFullConversion(func, target, std::move(patterns))))
+      signalPassFailure();
+
+    RewritePatternSet greedyPatterns(context);
+    greedyPatterns.insert(context);
+    if (failed(applyPatternsGreedily(func, std::move(greedyPatterns))))
+      signalPassFailure();
+
+    // Drop attributes that are no longer used after conversion out of Torch.
+    stripTorchAttrs(func);
+  }
+};
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 struct FinalizingBackendTypeConversionForStablehloPass
     : public FinalizingBackendTypeConversionForStablehloBase<
@@ -291,6 +373,11 @@ mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
   return std::make_unique<FinalizingBackendTypeConversionPass>();
 }
 
+std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
+    TorchConversion::createFinalizingBackendTypeConversionForTosaLinalgPass() {
+  return std::make_unique<FinalizingBackendTypeConversionForTosaLinalgPass>();
+}
+
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
     TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() {
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index 16cb8bf72081..84234fade9d5 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -199,15 +199,12 @@ void TorchConversion::createTorchBackendToTosaLinalgBackendPipeline(
 
   // Finish the type conversion from `torch` types to the types of the
   // TOSA backend contract.
-  pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
+  pm.addPass(
+      TorchConversion::createFuncBackendTypeConversionForTosaLinalgPass());
   pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
   pm.addNestedPass<func::FuncOp>(
-      TorchConversion::createFinalizingBackendTypeConversionPass());
-
-  // Verify that we have lowered to ops that are supported by union of TOSA and
-  // Linalg-on-tensors backend. This fails compilation (signalPassFailure) if
-  // the IR is not in the correct form.
-  pm.addPass(TorchConversion::createVerifyTosaLinalgBackendContractPass());
+      TorchConversion::
+          createFinalizingBackendTypeConversionForTosaLinalgPass());
 }
 #endif
diff --git a/python/torch_mlir/fx_mw.py b/python/torch_mlir/fx_mw.py
index be579f719f11..dc5be99f0d64 100644
--- a/python/torch_mlir/fx_mw.py
+++ b/python/torch_mlir/fx_mw.py
@@ -4,6 +4,7 @@
 # Also available under a BSD-style license. See LICENSE.
 
 import torch
+import torch_mlir
 
 from .compiler_utils import OutputType
 from .compiler_utils_mw import (
@@ -25,6 +26,28 @@ def import_exported_model(
         )
         prog = prog.run_decompositions(decomp_table)
 
+    mlir_module = fx.export_and_import(
+        prog,
+        output_type=OutputType.RAW,
+        experimental_support_mutation=experimental_support_mutation,
+    )
+
+    if output_type != 'raw':
+        mlir_module = lower_module(mlir_module, output_type)
+
+    return mlir_module
+
+def lower_module_from_file(mlir_file: str, output_type: str):
+    src = open(mlir_file, "r").read()
+    with torch_mlir.ir.Context() as ctx:
+        torch_mlir.dialects.torch.register_dialect(ctx)
+        with torch_mlir.ir.Location.unknown() as loc:
+            mlir_module = torch_mlir.ir.Module.parse(src)
+
+    return lower_module(mlir_module, output_type)
+
+def lower_module(mlir_module, output_type: str):
+
     backend_legal_ops = None
 
     match output_type:
@@ -65,36 +88,28 @@ def import_exported_model(
             output_type = OutputType.RAW
         case _:
             raise ValueError("Importing PyTorch model failed: Unsupported output type.")
+
+    backend_legal_op_arg_str = ""
+    if backend_legal_ops is not None:
+        if not len(backend_legal_ops) == 0:
+            backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
+                backend_legal_ops
+            )
 
-    mlir_module = fx.export_and_import(
-        prog,
-        output_type=OutputType.RAW,
-        experimental_support_mutation=experimental_support_mutation,
+    extra_library_file_name = ""
+    option_string = (
+        "{"
+        + backend_legal_op_arg_str
+        + " extra-library="
+        + extra_library_file_name
+        + "}"
     )
+    run_pipeline_mw(
+        mlir_module,
+        f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
+        "Lowering TorchFX IR -> Torch Backend IR",
+        enable_ir_printing=False,
     )
 
-    if output_type != OutputType.RAW:
-        backend_legal_op_arg_str = ""
-        if backend_legal_ops is not None:
-            if not len(backend_legal_ops) == 0:
-                backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
-                    backend_legal_ops
-                )
-
-        extra_library_file_name = ""
-        option_string = (
-            "{"
-            + backend_legal_op_arg_str
-            + " extra-library="
-            + extra_library_file_name
-            + "}"
-        )
-        run_pipeline_mw(
-            mlir_module,
-            f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
-            "Lowering TorchFX IR -> Torch Backend IR",
-            enable_ir_printing=False,
-        )
-        verbose = False
-        mlir_module = lower_mlir_module_mw(verbose, output_type, mlir_module)
-
-    return mlir_module
+    verbose = False
+    return lower_mlir_module_mw(verbose, output_type, mlir_module)
diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir
index 57c754d277e9..c4c9ac8180c4 100644
--- a/test/Conversion/TorchToTosa/basic.mlir
+++ b/test/Conversion/TorchToTosa/basic.mlir
@@ -2149,7 +2149,7 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
 // CHECK: %[[VAL_10:.*]] = tosa.greater %[[VAL_5]], %[[VAL_9]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi1>
 // CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
 // CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_11]] : (tensor, !tosa.shape<1>) -> tensor<1xi32>
-// CHECK: %[[VAL_13:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_12]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
+// CHECK: %[[VAL_13:.*]] = tosa.intdiv %[[VAL_5]], %[[VAL_12]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
 // CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
 // CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_13]], %[[VAL_14]] : (tensor<2xi32>, tensor<2xi32>, tensor<1xi8>) -> tensor<2xi32>
 // CHECK: %[[VAL_16:.*]] = tosa.sub %[[VAL_5]], %[[VAL_15]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
@@ -3920,11 +3920,11 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
 // CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
 // CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array} : (tensor) -> tensor
 // CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
-// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor
-// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
-// CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
 // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor, !tosa.shape<4>, !tosa.shape<4>) -> tensor
 // CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
 // CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>

From a044cc45de7c321ff7902ddc35a76f4ff4816d01 Mon Sep 17 00:00:00 2001
From: Sayan Saha
Date: Fri, 9 May 2025 13:36:42 -0400
Subject: [PATCH 2/2] [Task] : Fix pre-commit.

---
 python/torch_mlir/fx_mw.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/python/torch_mlir/fx_mw.py b/python/torch_mlir/fx_mw.py
index dc5be99f0d64..7641e469d352 100644
--- a/python/torch_mlir/fx_mw.py
+++ b/python/torch_mlir/fx_mw.py
@@ -15,6 +15,7 @@
 from . import fx
 from torch._decomp import get_decompositions
 
+
 def import_exported_model(
     prog: torch.export.ExportedProgram,
     output_type: str,
@@ -32,20 +33,22 @@ def import_exported_model(
         experimental_support_mutation=experimental_support_mutation,
     )
 
-    if output_type != 'raw':
+    if output_type != "raw":
         mlir_module = lower_module(mlir_module, output_type)
 
     return mlir_module
 
+
 def lower_module_from_file(mlir_file: str, output_type: str):
     src = open(mlir_file, "r").read()
     with torch_mlir.ir.Context() as ctx:
         torch_mlir.dialects.torch.register_dialect(ctx)
         with torch_mlir.ir.Location.unknown() as loc:
             mlir_module = torch_mlir.ir.Module.parse(src)
-    
+
     return lower_module(mlir_module, output_type)
 
+
 def lower_module(mlir_module, output_type: str):
 
     backend_legal_ops = None
@@ -62,7 +65,8 @@ def lower_module(mlir_module, output_type: str):
             "aten.adaptive_avg_pool2d",
             "aten.adaptive_max_pool1d",
             "aten.adaptive_max_pool2d",
-            "aten.linear"]
+            "aten.linear",
+            ]
         case "linalg_on_tensors":
             output_type = OutputType.LINALG_ON_TENSORS
             backend_legal_ops = [
@@ -70,9 +74,9 @@ def lower_module(mlir_module, output_type: str):
             "aten.adaptive_avg_pool1d",
             "aten.adaptive_avg_pool2d",
             "aten.adaptive_max_pool1d",
-            "aten.adaptive_max_pool2d", 
+            "aten.adaptive_max_pool2d",
             "aten.unflatten.int",
-        ]
+            ]
         case "tosa_linalg":
            output_type = OutputType.TOSA_LINALG
            backend_legal_ops = [
@@ -83,12 +87,13 @@ def lower_module(mlir_module, output_type: str):
             "aten.adaptive_max_pool1d",
             "aten.adaptive_max_pool2d",
             "aten.linear",
-            "aten.unflatten.int"]
+            "aten.unflatten.int",
+            ]
         case "raw":
             output_type = OutputType.RAW
         case _:
             raise ValueError("Importing PyTorch model failed: Unsupported output type.")
-    
+
     backend_legal_op_arg_str = ""
     if backend_legal_ops is not None:
         if not len(backend_legal_ops) == 0: