[Task] : Fix LIT test + disable pass failure for unlowered torch ops. #1


Merged: 2 commits, May 14, 2025
@@ -26,6 +26,9 @@ void getBackendTypeConversionDependentDialects(DialectRegistry &registry);
void setupBackendTypeConversion(ConversionTarget &target,
TypeConverter &typeConverter);

void setupBackendTypeConversionForTosaLinalg(ConversionTarget &target,
TypeConverter &typeConverter);

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
void setupBackendTypeConversionForStablehlo(ConversionTarget &target,
TypeConverter &typeConverter);
@@ -99,6 +99,12 @@ createVerifyLinalgOnTensorsBackendContractPass();
std::unique_ptr<OperationPass<ModuleOp>>
createVerifyTosaLinalgBackendContractPass();

std::unique_ptr<OperationPass<ModuleOp>>
createFuncBackendTypeConversionForTosaLinalgPass();

std::unique_ptr<InterfacePass<FunctionOpInterface>>
createFinalizingBackendTypeConversionForTosaLinalgPass();

} // namespace TorchConversion

/// Registers all Torch transformation passes.
20 changes: 20 additions & 0 deletions include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td
@@ -73,6 +73,26 @@ def VerifyTosaLinalgBackendContract : Pass<"torch-verify-tosa-linalg-backend-con
let constructor = "mlir::torch::TorchConversion::createVerifyTosaLinalgBackendContractPass()";
}

def FinalizingBackendTypeConversionForTosaLinalg
: InterfacePass<"torch-finalizing-backend-type-conversion-for-tosa-linalg", "mlir::FunctionOpInterface"> {
let summary = "Finalizes a partial conversion to builtin tensors for tosa+linalg";
let constructor =
"mlir::torch::TorchConversion::createFinalizingBackendTypeConversionForTosaLinalgPass()";
let description = [{
Analogous in scope to the upstream `finalizing-bufferize` pass.
See details there.
}];
}

def FuncBackendTypeConversionForTosaLinalg : Pass<"torch-func-backend-type-conversion-for-tosa-linalg", "ModuleOp"> {
let summary = "Convert functions to operate on builtin tensors for tosa+linalg backend";
let constructor = "mlir::torch::TorchConversion::createFuncBackendTypeConversionForTosaLinalgPass()";
let description = [{
Partial type conversion pass analogous in scope to the upstream
`func-bufferize` pass. See details there.
}];
}

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
let summary = "Verifies conformity to the stablehlo backend contract";
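
As a usage sketch (not part of this PR), the two pass names defined above can be driven from the torch-mlir Python bindings once a module is already in the torch backend contract. The file name and the standalone setup below are assumptions made for illustration; the pass ordering mirrors createTorchBackendToTosaLinalgBackendPipeline in Passes.cpp further down.

from torch_mlir import ir
from torch_mlir.passmanager import PassManager
import torch_mlir.dialects.torch as torch_dialect

with ir.Context() as ctx:
    torch_dialect.register_dialect(ctx)
    # "torch_backend.mlir" is a hypothetical module already lowered to the
    # torch backend contract.
    module = ir.Module.parse(open("torch_backend.mlir").read())
    pm = PassManager.parse(
        "builtin.module("
        "torch-func-backend-type-conversion-for-tosa-linalg,"
        "func.func(canonicalize),"
        "func.func(torch-finalizing-backend-type-conversion-for-tosa-linalg))",
        context=ctx,
    )
    pm.run(module.operation)
    print(module)
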
3 changes: 2 additions & 1 deletion lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp
@@ -65,7 +65,8 @@ class ConvertTorchToTosaLinalg

TypeConverter typeConverter;
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversion(target, typeConverter);
TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
typeConverter);

RewritePatternSet patterns(context);

27 changes: 27 additions & 0 deletions lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp
@@ -185,6 +185,33 @@ void mlir::torch::TorchConversion::setupBackendTypeConversion(
setupTorchGeneratorToI64Conversion(target, typeConverter);
}

void mlir::torch::TorchConversion::setupBackendTypeConversionForTosaLinalg(
ConversionTarget &target, TypeConverter &typeConverter) {
auto valueTensorTypeConversion =
[](Torch::ValueTensorType type) -> std::optional<Type> {
auto builtinType = type.toBuiltinTensor();
if (!builtinType)
return std::nullopt;

// convert signed integer type to signless, keep unsigned as unsigned
if (type.getDtype().isUnsignedInteger()) {
return builtinType.clone(type.getDtype());
} else if (type.getDtype().isSignedInteger()) {
return builtinType.clone(IntegerType::get(
builtinType.getContext(), type.getDtype().getIntOrFloatBitWidth(),
IntegerType::Signless));
}

return builtinType;
};
setupValueTensorToBuiltinTensorConversion(target, typeConverter,
valueTensorTypeConversion);
setupTorchBoolToI1Conversion(target, typeConverter);
setupTorchIntToI64Conversion(target, typeConverter);
setupTorchFloatToF64Conversion(target, typeConverter);
setupTorchGeneratorToI64Conversion(target, typeConverter);
}

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
void mlir::torch::TorchConversion::setupBackendTypeConversionForStablehlo(
ConversionTarget &target, TypeConverter &typeConverter) {
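
A minimal sketch of the dtype handling added above: signed integer element types are converted to signless builtin integers, unsigned element types keep their sign, and other dtypes pass through toBuiltinTensor unchanged. The concrete examples below are inferred from the conversion lambda, not taken from this PR's tests.

from torch_mlir import ir

# Expected mappings under setupBackendTypeConversionForTosaLinalg:
#   !torch.vtensor<[2,3],si32> -> tensor<2x3xi32>   (signed -> signless)
#   !torch.vtensor<[2,3],ui8>  -> tensor<2x3xui8>   (unsigned preserved)
#   !torch.vtensor<[2,3],f32>  -> tensor<2x3xf32>   (unchanged)
with ir.Context():
    for asm in ("tensor<2x3xi32>", "tensor<2x3xui8>", "tensor<2x3xf32>"):
        print(ir.Type.parse(asm))
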
@@ -15,6 +15,7 @@
#include "mlir/IR/BuiltinOps.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"
@@ -97,6 +98,33 @@ struct FuncBackendTypeConversionPass
}
};

struct FuncBackendTypeConversionForTosaLinalgPass
: public FuncBackendTypeConversionForTosaLinalgBase<
FuncBackendTypeConversionForTosaLinalgPass> {
using FuncBackendTypeConversionForTosaLinalgBase<
FuncBackendTypeConversionForTosaLinalgPass>::
FuncBackendTypeConversionForTosaLinalgBase;
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<TorchConversion::TorchConversionDialect>();
}
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();

TypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);
typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
typeConverter);

populateFuncBackendTypeConversionPatterns(typeConverter, patterns, target);

if (failed(applyFullConversion(module, target, std::move(patterns))))
signalPassFailure();
}
};

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
struct FuncBackendTypeConversionForStablehloPass
: public FuncBackendTypeConversionForStablehloBase<
@@ -132,6 +160,11 @@ mlir::torch::TorchConversion::createFuncBackendTypeConversionPass() {
return std::make_unique<FuncBackendTypeConversionPass>();
}

std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
createFuncBackendTypeConversionForTosaLinalgPass() {
return std::make_unique<FuncBackendTypeConversionForTosaLinalgPass>();
}

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
std::unique_ptr<OperationPass<ModuleOp>> mlir::torch::TorchConversion::
createFuncBackendTypeConversionForStablehloPass() {
@@ -240,6 +273,55 @@ struct FinalizingBackendTypeConversionPass
}
};

struct FinalizingBackendTypeConversionForTosaLinalgPass
: public FinalizingBackendTypeConversionForTosaLinalgBase<
FinalizingBackendTypeConversionForTosaLinalgPass> {
using FinalizingBackendTypeConversionForTosaLinalgBase<
FinalizingBackendTypeConversionForTosaLinalgPass>::
FinalizingBackendTypeConversionForTosaLinalgBase;

void runOnOperation() override {
auto func = getOperation();
auto *context = &getContext();

TypeConverter typeConverter;
RewritePatternSet patterns(context);
ConversionTarget target(*context);

typeConverter.addConversion([](Type type) { return type; });
TorchConversion::setupBackendTypeConversionForTosaLinalg(target,
typeConverter);

// Mark materializations as illegal in this pass (since we are finalizing)
// and add patterns that eliminate them.
setupFinalization<ToBuiltinTensorOp, FromBuiltinTensorOp, FromI1Op, ToI1Op,
FromI64Op, ToI64Op, FromF64Op, ToF64Op, I64ToGeneratorOp,
GeneratorToI64Op>(target, patterns, typeConverter);

// If all result types are legal, and all block arguments are legal, then
// all types in the program are legal.
//
// We also check that the operand types are legal to avoid creating invalid
// IR. For example, this prevents the patterns from updating
// the types of the operands to a return op without updating the enclosing
// function.
target.markUnknownOpDynamicallyLegal(
[&](Operation *op) { return typeConverter.isLegal(op); });

target.addLegalDialect<Torch::TorchDialect>();
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();

RewritePatternSet greedyPatterns(context);
greedyPatterns.insert<ExtFTruncFPattern>(context);
if (failed(applyPatternsGreedily(func, std::move(greedyPatterns))))
signalPassFailure();

// Drop attributes that are no longer used after conversion out of Torch.
stripTorchAttrs(func);
}
};

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
struct FinalizingBackendTypeConversionForStablehloPass
: public FinalizingBackendTypeConversionForStablehloBase<
@@ -291,6 +373,11 @@ mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() {
return std::make_unique<FinalizingBackendTypeConversionPass>();
}

std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
TorchConversion::createFinalizingBackendTypeConversionForTosaLinalgPass() {
return std::make_unique<FinalizingBackendTypeConversionForTosaLinalgPass>();
}

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
std::unique_ptr<InterfacePass<FunctionOpInterface>> mlir::torch::
TorchConversion::createFinalizingBackendTypeConversionForStablehloPass() {
11 changes: 4 additions & 7 deletions lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -199,15 +199,12 @@ void TorchConversion::createTorchBackendToTosaLinalgBackendPipeline(

// Finish the type conversion from `torch` types to the types of the
// TOSA backend contract.
pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
pm.addPass(
TorchConversion::createFuncBackendTypeConversionForTosaLinalgPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
pm.addNestedPass<func::FuncOp>(
TorchConversion::createFinalizingBackendTypeConversionPass());

// Verify that we have lowered to ops that are supported by union of TOSA and
// Linalg-on-tensors backend. This fails compilation (signalPassFailure) if
// the IR is not in the correct form.
pm.addPass(TorchConversion::createVerifyTosaLinalgBackendContractPass());
TorchConversion::
createFinalizingBackendTypeConversionForTosaLinalgPass());
}
#endif

88 changes: 54 additions & 34 deletions python/torch_mlir/fx_mw.py
@@ -4,6 +4,7 @@
# Also available under a BSD-style license. See LICENSE.

import torch
import torch_mlir
from .compiler_utils import OutputType

from .compiler_utils_mw import (
@@ -14,6 +15,7 @@
from . import fx
from torch._decomp import get_decompositions


def import_exported_model(
prog: torch.export.ExportedProgram,
output_type: str,
@@ -25,6 +27,30 @@ def import_exported_model(
)
prog = prog.run_decompositions(decomp_table)

mlir_module = fx.export_and_import(
prog,
output_type=OutputType.RAW,
experimental_support_mutation=experimental_support_mutation,
)

if output_type != "raw":
mlir_module = lower_module(mlir_module, output_type)

return mlir_module


def lower_module_from_file(mlir_file: str, output_type: str):
src = open(mlir_file, "r").read()
with torch_mlir.ir.Context() as ctx:
torch_mlir.dialects.torch.register_dialect(ctx)
with torch_mlir.ir.Location.unknown() as loc:
mlir_module = torch_mlir.ir.Module.parse(src)

return lower_module(mlir_module, output_type)


def lower_module(mlir_module, output_type: str):

backend_legal_ops = None

match output_type:
@@ -39,17 +65,18 @@ def import_exported_model(
"aten.adaptive_avg_pool2d",
"aten.adaptive_max_pool1d",
"aten.adaptive_max_pool2d",
"aten.linear"]
"aten.linear",
]
case "linalg_on_tensors":
output_type = OutputType.LINALG_ON_TENSORS
backend_legal_ops = [
"aten.flatten.using_ints",
"aten.adaptive_avg_pool1d",
"aten.adaptive_avg_pool2d",
"aten.adaptive_max_pool1d",
"aten.adaptive_max_pool2d",
"aten.adaptive_max_pool2d",
"aten.unflatten.int",
]
]
case "tosa_linalg":
output_type = OutputType.TOSA_LINALG
backend_legal_ops = [
@@ -60,41 +87,34 @@ def import_exported_model(
"aten.adaptive_max_pool1d",
"aten.adaptive_max_pool2d",
"aten.linear",
"aten.unflatten.int"]
"aten.unflatten.int",
]
case "raw":
output_type = OutputType.RAW
case _:
raise ValueError("Importing PyTorch model failed: Unsupported output type.")

mlir_module = fx.export_and_import(
prog,
output_type=OutputType.RAW,
experimental_support_mutation=experimental_support_mutation,
)
backend_legal_op_arg_str = ""
if backend_legal_ops is not None:
if not len(backend_legal_ops) == 0:
backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
backend_legal_ops
)

if output_type != OutputType.RAW:
backend_legal_op_arg_str = ""
if backend_legal_ops is not None:
if not len(backend_legal_ops) == 0:
backend_legal_op_arg_str = "backend-legal-ops=" + ",".join(
backend_legal_ops
)

extra_library_file_name = ""
option_string = (
"{"
+ backend_legal_op_arg_str
+ " extra-library="
+ extra_library_file_name
+ "}"
)
run_pipeline_mw(
mlir_module,
f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
enable_ir_printing=False,
)
verbose = False
mlir_module = lower_mlir_module_mw(verbose, output_type, mlir_module)
extra_library_file_name = ""
option_string = (
"{"
+ backend_legal_op_arg_str
+ " extra-library="
+ extra_library_file_name
+ "}"
)
run_pipeline_mw(
mlir_module,
f"builtin.module(func.func(torch-match-quantized-custom-ops), torchdynamo-export-to-torch-backend-pipeline{option_string})",
"Lowering TorchFX IR -> Torch Backend IR",
enable_ir_printing=False,
)

return mlir_module
verbose = False
return lower_mlir_module_mw(verbose, output_type, mlir_module)
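
A hedged usage sketch of the restructured Python entry points above. The toy model, the explicit experimental_support_mutation value, and the "model_raw.mlir" path are placeholders, not taken from this PR; the output_type strings match the cases handled in lower_module.

import torch
from torch_mlir import fx_mw

# Hypothetical toy model; any exportable nn.Module works the same way.
class Net(torch.nn.Module):
    def forward(self, x):
        return torch.nn.functional.relu(x)

# Export with torch.export, then import and lower to the tosa_linalg contract.
prog = torch.export.export(Net(), (torch.randn(2, 3),))
module = fx_mw.import_exported_model(
    prog, output_type="tosa_linalg", experimental_support_mutation=False
)

# Or re-lower a previously saved raw torch-dialect module from disk.
module_from_file = fx_mw.lower_module_from_file("model_raw.mlir", "tosa_linalg")
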
10 changes: 5 additions & 5 deletions test/Conversion/TorchToTosa/basic.mlir
@@ -2149,7 +2149,7 @@ func.func @torch.aten.diagonal$basic(%arg0: !torch.vtensor<[3,4,5,6], si32>) ->
// CHECK: %[[VAL_10:.*]] = tosa.greater %[[VAL_5]], %[[VAL_9]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi1>
// CHECK: %[[VAL_11:.*]] = tosa.const_shape {values = dense<1> : tensor<1xindex>} : () -> !tosa.shape<1>
// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_11]] : (tensor<i32>, !tosa.shape<1>) -> tensor<1xi32>
// CHECK: %[[VAL_13:.*]] = tosa.int_div %[[VAL_5]], %[[VAL_12]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
// CHECK: %[[VAL_13:.*]] = tosa.intdiv %[[VAL_5]], %[[VAL_12]] : (tensor<2xi32>, tensor<1xi32>) -> tensor<2xi32>
// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK: %[[VAL_15:.*]] = tosa.mul %[[VAL_13]], %[[VAL_13]], %[[VAL_14]] : (tensor<2xi32>, tensor<2xi32>, tensor<1xi8>) -> tensor<2xi32>
// CHECK: %[[VAL_16:.*]] = tosa.sub %[[VAL_5]], %[[VAL_15]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi32>
@@ -3920,11 +3920,11 @@ func.func @torch.aten.convolution$full_dim_indivisible_by_stride_without_sliced_
// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<32xf32>}> : () -> tensor<32xf32>
// CHECK: %[[VAL_12:.*]] = tosa.transpose %[[VAL_1]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<?x3x225x225xf32>) -> tensor<?x225x225x3xf32>
// CHECK: %[[VAL_13:.*]] = tosa.transpose %[[VAL_4]] {perms = array<i32: 0, 2, 3, 1>} : (tensor<32x3x3x3xf32>) -> tensor<32x3x3x3xf32>
// CHECK: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<[-1, 224, 225, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_12]], %[[VAL_14]], %[[VAL_15]] : (tensor<?x225x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x225x3xf32>
// CHECK: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_17:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[VAL_18:.*]] = tosa.const_shape {values = dense<[-1, 224, 224, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_16]], %[[VAL_17]], %[[VAL_18]] : (tensor<?x224x225x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x224x224x3xf32>
// CHECK: %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK: %[[VAL_21:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>