Skip to content

Commit fa9d385

Browse files
committed
Add a pass pipeline to convert Torch to a mix of Tosa, Linalg, Tensor and other standard MLIR dialects.
1 parent 4f3a60b commit fa9d385

File tree

18 files changed

+424
-19
lines changed

18 files changed

+424
-19
lines changed

include/torch-mlir/Conversion/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,14 @@ def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> {
132132
}
133133
#endif
134134

135+
def ConvertTorchToTosaLinalg : Pass<"convert-torch-to-tosa-linalg", "func::FuncOp"> {
136+
let summary = "Convert Torch ops to a mix of TOSA ops and LINALG_ON_TENSORS ops";
137+
let description = [{
138+
This pass tries to lower torch ops to tosa ops if possible. Otherwise lowers to a mix of linalg, tensor, scf, and other standard mlir dialects.
139+
}];
140+
let constructor = "mlir::torch::createConvertTorchToTosaLinalgPass()";
141+
}
142+
135143
def ConvertTorchToTMTensor : Pass<"convert-torch-to-tmtensor", "func::FuncOp"> {
136144
let summary = "Convert recognized Torch ops to TMTensor/Linalg ops";
137145
let description = [{

include/torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,18 @@
1313
#include "mlir/Dialect/Func/IR/FuncOps.h"
1414
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1515
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Transforms/DialectConversion.h"
1617
#include <memory>
1718

1819
namespace mlir {
1920
namespace torch {
21+
22+
void populateTorchToLinalgOnTensorsPatternsAndLegality(
23+
TypeConverter &typeConverter, RewritePatternSet &patterns,
24+
ConversionTarget &target);
25+
2026
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToLinalgPass();
21-
}
27+
} // namespace torch
2228
} // namespace mlir
2329

2430
#endif // TORCHMLIR_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
//===------------------------------------------------------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#ifndef TORCHMLIR_CONVERSION_ATENTOTOSALINALG_ATENTOTOSALINALG_H
11+
#define TORCHMLIR_CONVERSION_ATENTOTOSALINALG_ATENTOTOSALINALG_H
12+
13+
#include "mlir/Dialect/Func/IR/FuncOps.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include <memory>
16+
17+
namespace mlir {
18+
namespace torch {
19+
std::unique_ptr<OperationPass<func::FuncOp>>
20+
createConvertTorchToTosaLinalgPass();
21+
}
22+
} // namespace mlir
23+
24+
#endif // TORCHMLIR_CONVERSION_ATENTOTOSALINALG_ATENTOTOSALINALG_H

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,10 @@ void createTorchBackendToTosaBackendPipeline(
4545
std::unique_ptr<OperationPass<ModuleOp>> createVerifyTosaBackendContractPass();
4646
#endif // TORCH_MLIR_ENABLE_TOSA
4747

48+
/// Creates a pipeline that lowers from the torch backend contract to the
49+
/// TOSA + linalg backend contract.
50+
void createTorchBackendToTosaLinalgBackendPipeline(OpPassManager &pm);
51+
4852
// Do not register the stablehlo options if the stablehlo target is disabled
4953
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
5054
struct StablehloBackendPipelineOptions
@@ -92,6 +96,9 @@ createConvertCustomQuantOpPass();
9296
std::unique_ptr<OperationPass<ModuleOp>>
9397
createVerifyLinalgOnTensorsBackendContractPass();
9498

99+
std::unique_ptr<OperationPass<ModuleOp>>
100+
createVerifyTosaLinalgBackendContractPass();
101+
95102
} // namespace TorchConversion
96103

97104
/// Registers all Torch transformation passes.

include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,11 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu
6868
}
6969
#endif
7070

71+
def VerifyTosaLinalgBackendContract : Pass<"torch-verify-tosa-linalg-backend-contract", "ModuleOp"> {
72+
let summary = "Verifies conformity to the tosa + linalg-on-tensors backend contract";
73+
let constructor = "mlir::torch::TorchConversion::createVerifyTosaLinalgBackendContractPass()";
74+
}
75+
7176
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
7277
def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> {
7378
let summary = "Verifies conformity to the stablehlo backend contract";

lib/Conversion/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ add_subdirectory(TorchToSCF)
55
add_subdirectory(TorchToTensor)
66
if(TORCH_MLIR_ENABLE_TOSA)
77
add_subdirectory(TorchToTosa)
8+
add_subdirectory(TorchToTosaLinalg)
89
endif()
910
if(TORCH_MLIR_ENABLE_STABLEHLO)
1011
add_subdirectory(TorchToStablehlo)
@@ -25,7 +26,7 @@ if(TORCH_MLIR_ENABLE_STABLEHLO)
2526
list(APPEND linked_libs TorchMLIRTorchToStablehlo)
2627
endif()
2728
if(TORCH_MLIR_ENABLE_TOSA)
28-
list(APPEND linked_libs TorchMLIRTorchToTosa)
29+
list(APPEND linked_libs TorchMLIRTorchToTosa TorchMLIRTorchToTosaLinalg)
2930
endif()
3031

3132
add_mlir_library(TorchMLIRConversionPasses

lib/Conversion/Passes.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#ifdef TORCH_MLIR_ENABLE_TOSA
2424
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
25+
#include "torch-mlir/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.h"
2526
#endif // TORCH_MLIR_ENABLE_TOSA
2627

2728
//===----------------------------------------------------------------------===//

lib/Conversion/TorchToLinalg/TorchToLinalg.cpp

Lines changed: 24 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -63,24 +63,8 @@ class ConvertTorchToLinalg
6363

6464
RewritePatternSet patterns(context);
6565

66-
torch_to_linalg::populateTensorScalarInteropPatternsAndLegality(
67-
typeConverter, patterns, target);
68-
torch_to_linalg::populateLinearPatternsAndLegality(typeConverter, patterns,
69-
target);
70-
torch_to_linalg::populatePoolingPatternsAndLegality(typeConverter, patterns,
71-
target);
72-
torch_to_linalg::populateRandomPatternsAndLegality(typeConverter, patterns,
73-
target);
74-
torch_to_linalg::populateUncategorizedPatternsAndLegality(typeConverter,
75-
patterns, target);
76-
torch_to_linalg::populateReductionPatternsAndLegality(typeConverter,
77-
patterns, target);
78-
torch_to_linalg::populateDataMovementPatternsAndLegality(typeConverter,
66+
torch::populateTorchToLinalgOnTensorsPatternsAndLegality(typeConverter,
7967
patterns, target);
80-
torch_to_linalg::populateIndirectDataMovementPatternsAndLegality(
81-
typeConverter, patterns, target);
82-
torch_to_linalg::populateTensorConstructorsPatternsAndLegality(
83-
typeConverter, patterns, target);
8468

8569
if (failed(applyPartialConversion(getOperation(), target,
8670
std::move(patterns))))
@@ -89,6 +73,29 @@ class ConvertTorchToLinalg
8973
};
9074
} // namespace
9175

76+
void mlir::torch::populateTorchToLinalgOnTensorsPatternsAndLegality(
77+
TypeConverter &typeConverter, RewritePatternSet &patterns,
78+
ConversionTarget &target) {
79+
torch_to_linalg::populateTensorScalarInteropPatternsAndLegality(
80+
typeConverter, patterns, target);
81+
torch_to_linalg::populateLinearPatternsAndLegality(typeConverter, patterns,
82+
target);
83+
torch_to_linalg::populatePoolingPatternsAndLegality(typeConverter, patterns,
84+
target);
85+
torch_to_linalg::populateRandomPatternsAndLegality(typeConverter, patterns,
86+
target);
87+
torch_to_linalg::populateUncategorizedPatternsAndLegality(typeConverter,
88+
patterns, target);
89+
torch_to_linalg::populateReductionPatternsAndLegality(typeConverter, patterns,
90+
target);
91+
torch_to_linalg::populateDataMovementPatternsAndLegality(typeConverter,
92+
patterns, target);
93+
torch_to_linalg::populateIndirectDataMovementPatternsAndLegality(
94+
typeConverter, patterns, target);
95+
torch_to_linalg::populateTensorConstructorsPatternsAndLegality(
96+
typeConverter, patterns, target);
97+
}
98+
9299
std::unique_ptr<OperationPass<func::FuncOp>>
93100
mlir::torch::createConvertTorchToLinalgPass() {
94101
return std::make_unique<ConvertTorchToLinalg>();
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
add_mlir_conversion_library(TorchMLIRTorchToTosaLinalg
2+
TorchToTosaLinalg.cpp
3+
4+
ADDITIONAL_HEADER_DIRS
5+
${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTosaLinalg
6+
7+
DEPENDS
8+
TorchMLIRConversionPassIncGen
9+
10+
LINK_LIBS PUBLIC
11+
MLIRIR
12+
MLIRPass
13+
MLIRTosaDialect
14+
MLIRLinalgDialect
15+
TorchMLIRConversionUtils
16+
TorchMLIRTorchDialect
17+
)
18+
19+
torch_mlir_target_includes(TorchMLIRTorchToTosaLinalg)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
// Also available under a BSD-style license. See LICENSE.
7+
//
8+
//===----------------------------------------------------------------------===//
9+
10+
#include "torch-mlir/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.h"
11+
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
12+
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
13+
14+
#include "../PassDetail.h"
15+
#include "mlir/Dialect/Arith/IR/Arith.h"
16+
#include "mlir/Dialect/Complex/IR/Complex.h"
17+
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
18+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
19+
#include "mlir/Dialect/Math/IR/Math.h"
20+
#include "mlir/Dialect/SCF/IR/SCF.h"
21+
#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
22+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
23+
#include "mlir/Dialect/Tosa/IR/TosaOps.h"
24+
#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h"
25+
#include "torch-mlir/Dialect/Torch/IR/TorchOps.h"
26+
#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
27+
#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h"
28+
29+
using namespace mlir;
30+
using namespace mlir::torch;
31+
using namespace mlir::torch::Torch;
32+
33+
// -----------------------------------------------------------------------------
34+
// The pass
35+
// -----------------------------------------------------------------------------
36+
37+
namespace {
38+
class ConvertTorchToTosaLinalg
39+
: public ConvertTorchToTosaLinalgBase<ConvertTorchToTosaLinalg> {
40+
public:
41+
void getDependentDialects(DialectRegistry &registry) const override {
42+
registry.insert<linalg::LinalgDialect>();
43+
registry.insert<math::MathDialect>();
44+
registry.insert<func::FuncDialect>();
45+
registry.insert<tensor::TensorDialect>();
46+
registry.insert<tosa::TosaDialect>();
47+
registry.insert<arith::ArithDialect>();
48+
registry.insert<cf::ControlFlowDialect>();
49+
registry.insert<scf::SCFDialect>();
50+
registry.insert<complex::ComplexDialect>();
51+
TorchConversion::getBackendTypeConversionDependentDialects(registry);
52+
}
53+
54+
void runOnOperation() override {
55+
MLIRContext *context = &getContext();
56+
ConversionTarget target(*context);
57+
target.addLegalDialect<linalg::LinalgDialect, func::FuncDialect,
58+
cf::ControlFlowDialect, math::MathDialect,
59+
scf::SCFDialect, sparse_tensor::SparseTensorDialect,
60+
tosa::TosaDialect, tensor::TensorDialect,
61+
arith::ArithDialect, complex::ComplexDialect>();
62+
target.addLegalOp<TorchConversion::GetNextSeedOp>();
63+
torch::populateTorchToTosaConversionLegalOps(target);
64+
65+
TypeConverter typeConverter;
66+
typeConverter.addConversion([](Type type) { return type; });
67+
TorchConversion::setupBackendTypeConversion(target, typeConverter);
68+
69+
RewritePatternSet patterns(context);
70+
71+
torch::populateTorchToTosaConversionPatterns(typeConverter, patterns);
72+
torch::populateTorchToLinalgOnTensorsPatternsAndLegality(typeConverter,
73+
patterns, target);
74+
75+
if (failed(applyPartialConversion(getOperation(), target,
76+
std::move(patterns))))
77+
return signalPassFailure();
78+
}
79+
};
80+
81+
} // namespace
82+
83+
std::unique_ptr<OperationPass<func::FuncOp>>
84+
mlir::torch::createConvertTorchToTosaLinalgPass() {
85+
return std::make_unique<ConvertTorchToTosaLinalg>();
86+
}

0 commit comments

Comments
 (0)