Skip to content

Commit cea6647

Browse files
nicolasvasilache authored and memfrob committed
[mlir][Linalg] Add layout specification support to bufferization.
Previously, linalg bufferization always had to be conservative at function boundaries and assume the most dynamic strided memref layout. This revision introduce the mechanism to specify a linalg.buffer_layout function argument attribute that carries an affine map used to set a less pessimistic layout. Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D105859
1 parent 5b4dad2 commit cea6647

File tree

3 files changed

+144
-3
lines changed

3 files changed

+144
-3
lines changed

mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def Linalg_Dialect : Dialect {
4848
constexpr const static ::llvm::StringLiteral
4949
kInplaceableAttrName = "linalg.inplaceable";
5050

51+
/// Attribute name used to mark the bufferization layout for region
52+
// arguments during linalg comprehensive bufferization.
53+
constexpr const static ::llvm::StringLiteral
54+
kBufferLayoutAttrName = "linalg.buffer_layout";
55+
5156
using RegionBuilderFunType =
5257
llvm::function_ref<void(ImplicitLocOpBuilder &b, Block &)>;
5358
RegionBuilderFunType getRegionBuilder(StringRef name) {

mlir/lib/Dialect/Linalg/Transforms/ComprehensiveBufferize.cpp

Lines changed: 99 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -324,8 +324,10 @@ setInPlaceFuncArgument(BlockArgument bbArg,
324324

325325
/// Remove the attribute that triggers inplace bufferization on a FuncOp
326326
/// argument `bbArg`.
327-
static void removeInPlaceFuncArgument(BlockArgument bbArg) {
327+
static void removeBufferizationFuncArguments(BlockArgument bbArg) {
328328
auto funcOp = cast<FuncOp>(bbArg.getOwner()->getParentOp());
329+
funcOp.removeArgAttr(bbArg.getArgNumber(),
330+
LinalgDialect::kBufferLayoutAttrName);
329331
funcOp.removeArgAttr(bbArg.getArgNumber(),
330332
LinalgDialect::kInplaceableAttrName);
331333
}
@@ -2608,6 +2610,96 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
26082610
(void)applyPatternsAndFoldGreedily(moduleOp, std::move(patterns));
26092611
}
26102612

2613+
static void
2614+
foreachCaller(const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
2615+
FuncOp callee, llvm::function_ref<void(Operation *)> doit) {
2616+
auto itCallers = callerMap.find(callee);
2617+
if (itCallers == callerMap.end())
2618+
return;
2619+
for (Operation *caller : itCallers->second)
2620+
doit(caller);
2621+
}
2622+
2623+
/// Postprocess the linalg.buffer_layout annotation across function boundaries.
2624+
/// This is a purely mechanical process that may later become part of a
2625+
/// separate pass with its own layout assignment heuristic.
2626+
static void layoutPostProcessing(ModuleOp moduleOp) {
2627+
SmallVector<FuncOp> orderedFuncOps;
2628+
DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
2629+
auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
2630+
assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
2631+
2632+
for (FuncOp funcOp : orderedFuncOps) {
2633+
DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
2634+
foreachCaller(callerMap, funcOp, [&](Operation *caller) {
2635+
operandsPerCaller.try_emplace(caller, SmallVector<Value>());
2636+
});
2637+
2638+
SmallVector<Type> argumentTypes;
2639+
// Iterate on each function argument and check it it was marked with a
2640+
// desired layout.
2641+
for (auto it : llvm::enumerate(funcOp.getType().getInputs())) {
2642+
int argNumber = it.index();
2643+
Type inputType = it.value();
2644+
auto memrefType = inputType.dyn_cast<MemRefType>();
2645+
auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
2646+
argNumber, LinalgDialect::kBufferLayoutAttrName);
2647+
AffineMap desiredLayoutMap =
2648+
layoutAttr ? layoutAttr.getValue() : AffineMap();
2649+
AffineMap currentLayoutMap =
2650+
memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
2651+
if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
2652+
argumentTypes.push_back(inputType);
2653+
foreachCaller(callerMap, funcOp, [&](Operation *caller) {
2654+
operandsPerCaller.find(caller)->getSecond().push_back(
2655+
caller->getOperand(argNumber));
2656+
});
2657+
continue;
2658+
}
2659+
2660+
// Compute the buffer type with desired layout and add to input argument
2661+
// types.
2662+
MemRefType desiredMemrefType = MemRefType::get(
2663+
memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
2664+
argumentTypes.push_back(desiredMemrefType);
2665+
2666+
// If funcOp's body is not empty, change the bbArg type and propagate.
2667+
if (!funcOp.body().empty()) {
2668+
BlockArgument bbArg = funcOp.getArgument(argNumber);
2669+
bbArg.setType(desiredMemrefType);
2670+
OpBuilder b(bbArg.getContext());
2671+
b.setInsertionPointToStart(bbArg.getOwner());
2672+
// Cast back to the original memrefType and let it canonicalize.
2673+
Value cast =
2674+
b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
2675+
bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
2676+
}
2677+
2678+
// Cast to desired buffer type on all callers to `funcOp`.
2679+
// TODO: on the callee side, this may even have to trigger a copy to
2680+
// change the layout. For now let the memref::CastOp fail to verify in
2681+
// such cases.
2682+
auto castArg = [&](Operation *caller) {
2683+
OpBuilder b(caller);
2684+
Value newOperand = b.create<memref::CastOp>(
2685+
funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
2686+
operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
2687+
};
2688+
foreachCaller(callerMap, funcOp, castArg);
2689+
}
2690+
2691+
// Set operands with cast buffer on all callers to `funcOp`.
2692+
foreachCaller(callerMap, funcOp, [&](Operation *caller) {
2693+
caller->setOperands(operandsPerCaller.lookup(caller));
2694+
});
2695+
2696+
// Finally set the funcOp type to update the arguments.
2697+
auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
2698+
funcOp.getType().getResults());
2699+
funcOp.setType(newFuncType);
2700+
}
2701+
}
2702+
26112703
void LinalgComprehensiveModuleBufferize::runOnOperation() {
26122704
ModuleOp moduleOp = getOperation();
26132705
applyEnablingTransformations(moduleOp);
@@ -2672,12 +2764,16 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
26722764
}
26732765
}
26742766

2675-
// Post-pass cleanup of inplaceable attributes.
2767+
// Perform a post-processing pass of layout modification at function boundary
2768+
// according to the kBufferLayoutAttrName.
2769+
layoutPostProcessing(moduleOp);
2770+
2771+
// Post-pass cleanup of inplaceable and buffer_layout attributes.
26762772
moduleOp.walk(
26772773
[&](Operation *op) { op->removeAttr(kInPlaceResultsAttrName); });
26782774
moduleOp.walk([&](FuncOp op) {
26792775
for (BlockArgument bbArg : op.getArguments())
2680-
removeInPlaceFuncArgument(bbArg);
2776+
removeBufferizationFuncArguments(bbArg);
26812777
});
26822778

26832779
OpPassManager cleanupPipeline(OpPassManager("module"));

mlir/test/Dialect/Linalg/comprehensive-module-bufferize.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,3 +555,43 @@ func @tiled_dot(%A: tensor<?xf32>, %B: tensor<?xf32>, %c: tensor<f32> {linalg.in
555555
// CHECK-NOT: tensor
556556
return %1 : tensor<f32>
557557
}
558+
559+
// -----
560+
561+
// CHECK: #[[$DYNAMIC:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
562+
563+
// CHECK: func private @external_func(memref<?xf32, #[[$DYNAMIC]]>)
564+
func private @external_func(tensor<?xf32>)
565+
566+
// CHECK: func @callee(
567+
// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
568+
// CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
569+
// CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
570+
func @callee(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
571+
%B : tensor<?xf32>,
572+
%C : tensor<?xf32>) {
573+
// CHECK-NEXT: %[[CASTED:.*]] = memref.cast %[[A]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
574+
// CHECK-NEXT: call @external_func(%[[CASTED]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
575+
call @external_func(%A) : (tensor<?xf32>) -> ()
576+
577+
// CHECK-NEXT: call @external_func(%[[B]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
578+
call @external_func(%B) : (tensor<?xf32>) -> ()
579+
580+
// CHECK-NEXT: call @external_func(%[[C]]) : (memref<?xf32, #[[$DYNAMIC]]>) -> ()
581+
call @external_func(%C) : (tensor<?xf32>) -> ()
582+
583+
return
584+
}
585+
586+
// CHECK: func @entry(
587+
// CHECK-SAME: %[[A:[0-9a-zA-Z]*]]: memref<?xf32>
588+
// CHECK-SAME: %[[B:[0-9a-zA-Z]*]]: memref<?xf32>
589+
// CHECK-SAME: %[[C:[0-9a-zA-Z]*]]: memref<?xf32, #[[$DYNAMIC]]>
590+
func @entry(%A : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
591+
%B : tensor<?xf32> {linalg.buffer_layout = affine_map<(i)[s0, s1] -> (i)>},
592+
%C : tensor<?xf32>) {
593+
// CHECK-NEXT: %[[CASTED_B:.*]] = memref.cast %[[B]] : memref<?xf32> to memref<?xf32, #[[$DYNAMIC]]>
594+
// CHECK-NEXT: call @callee(%[[A]], %[[CASTED_B]], %[[C]])
595+
call @callee(%A, %B, %C) : (tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> ()
596+
return
597+
}

0 commit comments

Comments
 (0)