Skip to content

Commit 6fc092f

Browse files
authored
[mlir][bufferization] Let bufferization.tensor_layout be any layout attr (#138567)
The bufferization.tensor_layout attribute is unnecessarily restricted to affine map attributes when it could reasonably be any implementor of MemRefLayoutAttrInterface.
1 parent a10f6c1 commit 6fc092f

File tree

3 files changed

+28
-6
lines changed

3 files changed

+28
-6
lines changed

mlir/lib/Dialect/Bufferization/IR/BufferizationDialect.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,9 +122,9 @@ LogicalResult BufferizationDialect::verifyRegionArgAttribute(
122122
return success();
123123
}
124124
if (attr.getName() == kBufferLayoutAttrName) {
125-
if (!llvm::isa<AffineMapAttr>(attr.getValue())) {
125+
if (!llvm::isa<MemRefLayoutAttrInterface>(attr.getValue())) {
126126
return op->emitError() << "'" << kBufferLayoutAttrName
127-
<< "' is expected to be a affine map attribute";
127+
<< "' is expected to be a memref layout attribute";
128128
}
129129
if (!isa<FunctionOpInterface>(op))
130130
return op->emitError() << "expected '" << kBufferLayoutAttrName

mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,16 +63,16 @@ getBufferizedFunctionArgType(FuncOp funcOp, int64_t index,
6363
BaseMemRefType memrefType = options.functionArgTypeConverterFn(
6464
tensorType, *options.defaultMemorySpaceFn(tensorType), funcOp, options);
6565

66-
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
66+
auto layoutAttr = funcOp.getArgAttrOfType<MemRefLayoutAttrInterface>(
6767
index, BufferizationDialect::kBufferLayoutAttrName);
6868
if (!layoutAttr)
6969
return memrefType;
7070

7171
auto rankedMemrefType = dyn_cast<MemRefType>(memrefType);
7272
assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
73-
return MemRefType::get(
74-
rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
75-
layoutAttr.getValue(), rankedMemrefType.getMemorySpace());
73+
return MemRefType::get(rankedMemrefType.getShape(),
74+
rankedMemrefType.getElementType(), layoutAttr,
75+
rankedMemrefType.getMemorySpace());
7676
}
7777

7878
/// Return the FuncOp called by `callOp`.

mlir/test/Dialect/Tensor/one-shot-bufferize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,28 @@ func.func @cast_retains_buffer_layout(
353353

354354
// -----
355355

356+
// CHECK-LABEL: func.func @cast_retains_buffer_layout_strided(
357+
// CHECK-SAME: %[[t:.*]]: memref<?xf32, strided<[1], offset: 5>>, %[[sz:.*]]: index) -> memref<?xf32, strided<[1], offset: 7>> {
358+
// CHECK: %[[casted:.*]] = memref.cast %[[t]] : memref<?xf32, strided<[1], offset: 5>> to memref<10xf32, strided<[1], offset: 5>>
359+
// CHECK: %[[slice:.*]] = memref.subview %[[casted]][2] [%[[sz]]] [1] : memref<10xf32, strided<[1], offset: 5>> to memref<?xf32, strided<[1], offset: 7>>
360+
// CHECK: return %[[slice]]
361+
func.func @cast_retains_buffer_layout_strided(
362+
%t: tensor<?xf32>
363+
{bufferization.buffer_layout = strided<[1], offset: 5>},
364+
%sz: index)
365+
-> (tensor<10xf32>, tensor<?xf32>)
366+
{
367+
%casted = tensor.cast %t : tensor<?xf32> to tensor<10xf32>
368+
%slice = tensor.extract_slice %casted[2][%sz][1] : tensor<10xf32> to tensor<?xf32>
369+
370+
// Note: The %casted return type is folded away because both buffers are
371+
// equivalent. Therefore, we currently lose some static type information
372+
// in the caller.
373+
return %casted, %slice : tensor<10xf32>, tensor<?xf32>
374+
}
375+
376+
// -----
377+
356378
// CHECK-LABEL: func.func @parallel_insert_slice_source_out_of_place
357379
func.func @parallel_insert_slice_source_out_of_place(%in: tensor<1xf32>, %out: tensor<100xf32>, %f: f32) {
358380
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)