Skip to content

Commit 85f843b

Browse files
nicolasvasilachememfrob
authored andcommitted
[mlir][Linalg] Better support for bufferizing non-tensor results.
Clean up corner cases related to elemental tensor / buffer type return values that would previously fail. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D105857
1 parent 535a270 commit 85f843b

File tree

2 files changed

+22
-0
lines changed

2 files changed

+22
-0
lines changed

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2379,6 +2379,8 @@ static Operation *getEquivalentAlloc(Value value,
23792379
static BlockArgument
23802380
getEquivalentEnclosingFuncBBArg(Value v,
23812381
const BufferizationAliasInfo &aliasInfo) {
2382+
if (!v.getType().isa<RankedTensorType>())
2383+
return nullptr;
23822384
Operation *op = v.getParentBlock()->getParentOp();
23832385
FuncOp funcOp = dyn_cast<FuncOp>(op);
23842386
if (!funcOp)
@@ -2455,6 +2457,12 @@ static LogicalResult bufferizeFuncOpBoundary(
24552457
// 1. For each FuncOp result, keep track of which inplace argument it reuses.
24562458
SmallVector<Value> returnValues;
24572459
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
2460+
// If not a renturn tensor type just forward it.
2461+
if (!returnOperand.get().getType().isa<RankedTensorType>()) {
2462+
returnValues.push_back(returnOperand.get());
2463+
continue;
2464+
}
2465+
24582466
// If return operand is equivalent to some bbArg, no need to return it.
24592467
Value returnVal = returnOperand.get();
24602468
if (getEquivalentEnclosingFuncBBArg(returnVal, aliasInfo))

mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,19 @@
11
// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize -split-input-file | FileCheck %s
22

3+
// CHECK-LABEL: func @transfer_read(%{{.*}}: memref<?xf32, #map>) -> vector<4xf32> {
4+
func @transfer_read(%A : tensor<?xf32>) -> (vector<4xf32>) {
5+
%c0 = constant 0 : index
6+
%f0 = constant 0.0 : f32
7+
8+
// CHECK: %[[RES:.*]] = vector.transfer_read {{.*}} : memref<?xf32, #{{.*}}>, vector<4xf32>
9+
%0 = vector.transfer_read %A[%c0], %f0 : tensor<?xf32>, vector<4xf32>
10+
11+
// CHECK: return %[[RES]] : vector<4xf32>
12+
return %0 : vector<4xf32>
13+
}
14+
15+
// -----
16+
317
// CHECK-DAG: #[[$map_1d_dyn:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
418

519
// CHECK-LABEL: func @fill_inplace(

0 commit comments

Comments
 (0)