Commit 45998aa (parent: 0f796b1)

Lower from tm_tensor to linalg in tosa_linalg pipeline.

6 files changed: +106 −26 lines
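
For reference, the new pipeline test added at the end of this commit exercises the change end to end; the same pipeline can be run by hand, mirroring that test's RUN line (input.mlir here is a placeholder file name):

    torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-linalg-backend-pipeline)' input.mlir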

lib/Conversion/TorchToTosaLinalg/TorchToTosaLinalg.cpp

Lines changed: 7 additions & 3 deletions
@@ -60,8 +60,6 @@ class ConvertTorchToTosaLinalg
                            tosa::TosaDialect, tensor::TensorDialect,
                            arith::ArithDialect, complex::ComplexDialect>();

-    target.addIllegalDialect<Torch::TorchDialect>();
-
     target.addLegalOp<TorchConversion::GetNextSeedOp>();
     torch::populateTorchToTosaConversionLegalOps(target);

@@ -71,7 +69,13 @@ class ConvertTorchToTosaLinalg

     RewritePatternSet patterns(context);

-    torch::populateTorchToTosaConversionPatterns(typeConverter, patterns);
+    auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps(
+        typeConverter, patterns);
+
+    // for (auto op : illegalOps) {
+    //   target.addIllegalOp(OperationName(op, context));
+    // }
+
     torch::populateTorchToLinalgOnTensorsPatternsAndLegality(typeConverter,
                                                              patterns, target);

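The new helper both registers the TorchToTosa patterns and reports which Torch ops they cover, presumably so legality can be driven per-op by the pattern sets rather than by the blanket addIllegalDialect<Torch::TorchDialect>() call dropped above. A sketch of what the commented-out loop would do once enabled, using only the calls already visible at this call site (the helper's exact return type is an assumption; the authoritative declaration lives in the TorchToTosa conversion header):

    // Assumed: the helper returns an iterable of op-name strings for the
    // Torch ops it registered TOSA patterns for.
    auto illegalOps = populateTorchToTosaConversionPatternsAndIllegalOps(
        typeConverter, patterns);
    for (auto opName : illegalOps)
      target.addIllegalOp(OperationName(opName, context)); // mark covered ops illegal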
lib/Dialect/TorchConversion/Transforms/Passes.cpp

Lines changed: 11 additions & 0 deletions
@@ -12,6 +12,7 @@
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/Passes.h"
+#include "torch-mlir-dialects/Dialect/TMTensor/Transforms/Passes.h"
 #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h"
 #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
 #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
@@ -192,6 +193,16 @@ void TorchConversion::createTorchBackendToTosaLinalgBackendPipeline(
   // The resolution of `dim` ops tends to create identical ops. CSE them.
   pm.addNestedPass<func::FuncOp>(createCSEPass());

+  // The `tm-tensor-to-loops` pass can convert TMTensor ops to Linalg and
+  // other MLIR core dialects only when the operands of the TMTensor ops have
+  // `memref` type, so run the `tm-tensor-bufferize` pass before running
+  // `tm-tensor-to-loops`. Unfortunately, we have to lower to `memref` types
+  // this early because of this `tm-tensor-to-loops` limitation. Ideally, we
+  // would keep `tensor` types at this stage and defer lowering to `memref`
+  // types as late as possible.
+  pm.addNestedPass<func::FuncOp>(TMTensor::createTMTensorBufferizePass());
+  pm.addNestedPass<func::FuncOp>(TMTensor::createTMTensorToLoopsPass());
+
   // Finish the type conversion from `torch` types to the types of the
   // TOSA backend contract.
   pm.addPass(TorchConversion::createFuncBackendTypeConversionPass());
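
The net effect on the pipeline, expressed in MLIR's textual pass-pipeline syntax (a sketch assembled from the pass names the comment above mentions; the real pipeline wraps many more passes around this fragment):

    builtin.module(func.func(tm-tensor-bufferize,tm-tensor-to-loops))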

lib/Dialect/TorchConversion/Transforms/VerifyTosaLinalgBackendContract.cpp

Lines changed: 14 additions & 6 deletions
@@ -10,18 +10,19 @@
 #include "PassDetail.h"

 #include "mlir/Dialect/Affine/IR/AffineOps.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Complex/IR/Complex.h"
 #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MLProgram/IR/MLProgram.h"
 #include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/SCF/IR/SCF.h"
 #include "mlir/Dialect/SparseTensor/IR/SparseTensor.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tosa/IR/TosaOps.h"
 #include "mlir/Transforms/DialectConversion.h"
-#include "torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorDialect.h"
 #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h"
 #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h"

@@ -30,7 +31,6 @@
 using namespace mlir;
 using namespace mlir::torch;
 using namespace mlir::torch::TorchConversion;
-using namespace TMTensor;

 namespace {
 class VerifyTosaLinalgBackendContractPass
@@ -40,8 +40,13 @@ class VerifyTosaLinalgBackendContractPass
     MLIRContext *context = &getContext();
     auto module = getOperation();
     TypeConverter converter;
-    converter.addConversion([](RankedTensorType type) -> Type {
-      if (BaseMemRefType::isValidElementType(type.getElementType()))
+    converter.addConversion([](Type type) -> Type {
+      auto elemTy = type;
+      if (isa<TensorType>(type))
+        elemTy = cast<TensorType>(type).getElementType();
+      if (isa<MemRefType>(type))
+        elemTy = cast<MemRefType>(type).getElementType();
+      if (BaseMemRefType::isValidElementType(elemTy))
         return type;
       return nullptr;
     });
@@ -72,6 +77,8 @@ class VerifyTosaLinalgBackendContractPass
     target.addDynamicallyLegalDialect<func::FuncDialect>(isLegalScalarOp);
     target.addDynamicallyLegalDialect<math::MathDialect>(isLegalScalarOp);
     target.addDynamicallyLegalDialect<arith::ArithDialect>(isLegalScalarOp);
+    target.addDynamicallyLegalDialect<bufferization::BufferizationDialect>(
+        opHasLegalTypes);
     target.addDynamicallyLegalDialect<complex::ComplexDialect>(isLegalScalarOp);

     // Tensor operations should go through linalg and the tensor dialect.
@@ -83,8 +90,8 @@ class VerifyTosaLinalgBackendContractPass
     target.addDynamicallyLegalDialect<tosa::TosaDialect>(opHasLegalTypes);
     target.addDynamicallyLegalDialect<affine::AffineDialect>(opHasLegalTypes);
     target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(opHasLegalTypes);
-    target.addDynamicallyLegalDialect<TMTensorDialect>(opHasLegalTypes);
     target.addDynamicallyLegalDialect<scf::SCFDialect>(opHasLegalTypes);
+    target.addDynamicallyLegalDialect<memref::MemRefDialect>(opHasLegalTypes);
     target.addDynamicallyLegalDialect<ml_program::MLProgramDialect>(
         opHasLegalTypes);

@@ -97,7 +104,8 @@ class VerifyTosaLinalgBackendContractPass
       // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics
       // doesn't unnecessarily spew out the entire module.
      emitError(module.getLoc())
-          << "Module does not conform to the linalg-on-tensors backend "
+          << "Module does not conform to the "
+             "torch-backend-to-tosa-linalg-backend "
             "contract. "
             "See dialect conversion legality information above.";
      return signalPassFailure();
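
Not part of this commit, but a possible tightening: TensorType and MemRefType both implement the ShapedType interface, so the two isa/cast pairs in the converter above could collapse into one dyn_cast. A sketch with the same accept/reject behavior for the types this pass sees:

    converter.addConversion([](Type type) -> Type {
      // ShapedType covers tensors and memrefs (and vectors), so a single
      // dyn_cast replaces the two isa/cast pairs.
      Type elemTy = type;
      if (auto shapedTy = dyn_cast<ShapedType>(type))
        elemTy = shapedTy.getElementType();
      if (BaseMemRefType::isValidElementType(elemTy))
        return type;
      return nullptr;
    });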

test/Conversion/TorchToTosa/torch-backend-to-tosa-linalg-backend-pipeline.mlir

Lines changed: 0 additions & 17 deletions
This file was deleted.

New file

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+// RUN: torch-mlir-opt <%s -convert-torch-to-tosa-linalg -split-input-file -verify-diagnostics | FileCheck %s
+
+// CHECK-LABEL: func.func @torch.aten.avg_pool2d.divisor_override
+// CHECK: linalg.pooling_nchw_sum
+// CHECK-NOT: torch.aten.avg_pool2d
+func.func @torch.aten.avg_pool2d.divisor_override(%arg0: !torch.vtensor<[1,192,35,35],f32>) -> !torch.vtensor<[1,192,35,35],f32> {
+  %int0 = torch.constant.int 0
+  %int1 = torch.constant.int 1
+  %int3 = torch.constant.int 3
+  %false = torch.constant.bool false
+  %count_include_pad = torch.constant.bool false
+  %divisor_override = torch.constant.int 9
+
+  %0 = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list<int>
+  %1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %2 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list<int>
+  %3 = torch.aten.avg_pool2d %arg0, %0, %1, %2, %false, %count_include_pad, %divisor_override : !torch.vtensor<[1,192,35,35],f32>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.bool, !torch.int -> !torch.vtensor<[1,192,35,35],f32>
+  return %3 : !torch.vtensor<[1,192,35,35],f32>
+}
New file

Lines changed: 55 additions & 0 deletions

@@ -0,0 +1,55 @@
+// RUN: torch-mlir-opt -pass-pipeline='builtin.module(torch-backend-to-tosa-linalg-backend-pipeline)' -split-input-file -verify-diagnostics %s | FileCheck %s
+
+// -----
+
+// CHECK-LABEL: func.func @torch.aten.size.int(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<4x2xf32>) -> i64 {
+// CHECK: %[[VAL_0:.*]] = arith.constant false
+// CHECK: %[[VAL_1:.*]] = arith.constant 2 : index
+// CHECK: cf.assert %[[VAL_0]], "dim must be smaller than inputRank"
+// CHECK: %[[VAL_2:.*]] = tensor.dim %[[ARG0]], %[[VAL_1]] : tensor<4x2xf32>
+// CHECK: %[[VAL_3:.*]] = arith.index_cast %[[VAL_2]] : index to i64
+// CHECK: return %[[VAL_3]] : i64
+func.func @torch.aten.size.int(%arg0: !torch.vtensor<[4,2],f32>) -> !torch.int {
+  %c2 = torch.constant.int 2
+  %0 = torch.aten.size.int %arg0, %c2 : !torch.vtensor<[4,2],f32>, !torch.int -> !torch.int
+  return %0 : !torch.int
+}
+
+// -----
+
+// CHECK-LABEL: func.func @tm_scan(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<1x512xi64>) -> (tensor<1x512xi64>, tensor<1xi64>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 512 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 0 : index
+// CHECK: %[[VAL_4:.*]] = memref.alloc() : memref<1x512xi64>
+// CHECK: %[[VAL_5:.*]] = bufferization.to_tensor %[[VAL_4]] : memref<1x512xi64>
+// CHECK: %[[VAL_6:.*]] = memref.alloc() : memref<1xi64>
+// CHECK: %[[VAL_7:.*]] = bufferization.to_tensor %[[VAL_6]] : memref<1xi64>
+// CHECK: scf.for %[[VAL_8:.*]] = %[[VAL_3]] to %[[VAL_1]] step %[[VAL_2]] {
+// CHECK: %[[VAL_9:.*]] = arith.cmpi eq, %[[VAL_8]], %[[VAL_3]] : index
+// CHECK: scf.if %[[VAL_9]] {
+// CHECK: %[[VAL_10:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[VAL_3]], %[[VAL_8]]] : tensor<1x512xi64>
+// CHECK: memref.store %[[VAL_10]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_8]]] : memref<1x512xi64>
+// CHECK: } else {
+// CHECK: %[[VAL_11:.*]] = arith.subi %[[VAL_8]], %[[VAL_2]] : index
+// CHECK: %[[VAL_12:.*]] = memref.load %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_11]]] : memref<1x512xi64>
+// CHECK: %[[VAL_13:.*]] = tensor.extract %[[ARG0]]{{\[}}%[[VAL_3]], %[[VAL_8]]] : tensor<1x512xi64>
+// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_12]], %[[VAL_13]] : i64
+// CHECK: memref.store %[[VAL_14]], %[[VAL_4]]{{\[}}%[[VAL_3]], %[[VAL_8]]] : memref<1x512xi64>
+// CHECK: memref.store %[[VAL_14]], %[[VAL_6]]{{\[}}%[[VAL_3]]] : memref<1xi64>
+// CHECK: }
+// CHECK: }
+// CHECK: return %[[VAL_5]], %[[VAL_7]] : tensor<1x512xi64>, tensor<1xi64>
+// CHECK: }
+func.func @tm_scan(%arg0: tensor<1x512xi64>) -> (tensor<1x512xi64>, tensor<1xi64>) {
+  %0 = tensor.empty() : tensor<1x512xi64>
+  %1 = tensor.empty() : tensor<1xi64>
+  %2:2 = tm_tensor.scan dimension(1) inclusive(true) ins(%arg0 : tensor<1x512xi64>) outs(%0, %1 : tensor<1x512xi64>, tensor<1xi64>) {
+  ^bb0(%arg1: i64, %arg2: i64):
+    %3 = arith.addi %arg1, %arg2 : i64
+    tm_tensor.yield %3 : i64
+  } -> tensor<1x512xi64>, tensor<1xi64>
+  return %2#0, %2#1 : tensor<1x512xi64>, tensor<1xi64>
+}
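
Both new tests run under the project's lit suite; assuming a standard torch-mlir CMake setup (the target name is the usual one, not stated in this commit), something like the following picks them up:

    cmake --build build --target check-torch-mlir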
