
Commit bd1d87e

[mlir][bufferization][NFC] Remove layout post processing step
The layout postprocessing step was removed and is now part of the FuncOp bufferization. If the user specified a certain layout map for a tensor function arg, use that layout map directly when bufferizing the function signature. Previously, the bufferization used a generic layout map for every tensor function arg and then updated function signatures and CallOps in a separate step.

Differential Revision: https://reviews.llvm.org/D122228
1 parent 70777d9 commit bd1d87e
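
For illustration, a hedged sketch of the new behavior in MLIR IR (the function, shapes, and layout map are hypothetical; the attribute is spelled linalg.buffer_layout here, per the annotation name used in this diff):

// A tensor function arg annotated with a desired buffer layout map.
func.func @callee(%arg0: tensor<4x8xf32>
    {linalg.buffer_layout = affine_map<(d0, d1) -> (d1 * 4 + d0)>}) {
  return
}

// The bufferized signature now uses that map directly, e.g.:
// func.func @callee(%arg0: memref<4x8xf32, affine_map<(d0, d1) -> (d1 * 4 + d0)>>)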

File tree

1 file changed (+16, -104 lines)

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Lines changed: 16 additions & 104 deletions
@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Module Bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of One-Shot Bufferize that
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
 // implementations for FuncOp, CallOp and ReturnOp.
 //
@@ -357,14 +357,27 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
 }
 
 /// Return the index-th bufferized function argument type. This assumes that the
-/// specified argument is a tensor.
+/// specified argument is a tensor. If the tensor is ranked, a layout map may be
+/// specified by the user. If no layout map is specified, a fully dynamic map is
+/// used.
 static BaseMemRefType
 getBufferizedFunctionArgType(func::FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
   auto tensorType =
       funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
   assert(tensorType && "expected TensorType");
-  return getMemRefType(tensorType, options);
+  BaseMemRefType memrefType = getMemRefType(tensorType, options);
+
+  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+      index, BufferizableOpInterface::kBufferLayoutAttrName);
+  if (!layoutAttr)
+    return memrefType;
+
+  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
+  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
+  return MemRefType::get(
+      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
 }
 
 /// Gather equivalence info of CallOps.
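
The added code plugs the user-specified map straight into MemRefType::get. When no layout attribute is present, it keeps the result of getMemRefType, which, per the doc comment above, is a fully dynamic map for ranked tensors. A hedged sketch of that default case in MLIR IR (hypothetical function; the exact printed form may differ):

// Unannotated arg: the bufferized type gets a fully dynamic strided layout,
// with unknown offset s0 and strides s1, s2.
func.func @g(%arg0: tensor<4x8xf32>) { return }
// bufferizes to roughly:
// func.func @g(%arg0: memref<4x8xf32,
//     affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>)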
@@ -451,103 +464,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   return success();
 }
 
-static void foreachCaller(const FuncCallerMap &callerMap, func::FuncOp callee,
-                          llvm::function_ref<void(Operation *)> doit) {
-  auto itCallers = callerMap.find(callee);
-  if (itCallers == callerMap.end())
-    return;
-  for (Operation *caller : itCallers->second)
-    doit(caller);
-}
-
-/// Postprocess the linalg.buffer_layout annotation across function boundaries.
-/// This is a purely mechanical process that may later become part of a
-/// separate pass with its own layout assignment heuristic.
-static void layoutPostProcessing(ModuleOp moduleOp) {
-  SmallVector<func::FuncOp> orderedFuncOps;
-  DenseMap<func::FuncOp, DenseSet<Operation *>> callerMap;
-  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
-  (void)res;
-  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
-
-  for (func::FuncOp funcOp : orderedFuncOps) {
-    DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
-    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-      operandsPerCaller.try_emplace(caller, SmallVector<Value>());
-    });
-
-    SmallVector<Type> argumentTypes;
-    // Iterate on each function argument and check if it was marked with a
-    // desired layout.
-    for (const auto &it :
-         llvm::enumerate(funcOp.getFunctionType().getInputs())) {
-      int argNumber = it.index();
-      Type inputType = it.value();
-      auto memrefType = inputType.dyn_cast<MemRefType>();
-      auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
-          argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
-      AffineMap desiredLayoutMap =
-          layoutAttr ? layoutAttr.getValue() : AffineMap();
-      AffineMap currentLayoutMap =
-          memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
-      if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
-        argumentTypes.push_back(inputType);
-        foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-          operandsPerCaller.find(caller)->getSecond().push_back(
-              caller->getOperand(argNumber));
-        });
-        continue;
-      }
-
-      // Compute the buffer type with desired layout and add to input argument
-      // types.
-      MemRefType desiredMemrefType = MemRefType::get(
-          memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
-      argumentTypes.push_back(desiredMemrefType);
-
-      // If funcOp's body is not empty, change the bbArg type and propagate.
-      if (!funcOp.getBody().empty()) {
-        BlockArgument bbArg = funcOp.getArgument(argNumber);
-        bbArg.setType(desiredMemrefType);
-        OpBuilder b(bbArg.getContext());
-        b.setInsertionPointToStart(bbArg.getOwner());
-        assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
-               "layoutPostProcessing: cast incompatible");
-        // Cast back to the original memrefType and let it canonicalize.
-        Value cast =
-            b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
-        bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
-      }
-
-      // Cast to desired buffer type on all callers to `funcOp`.
-      // TODO: on the callee side, this may even have to trigger a copy to
-      // change the layout. For now let the memref::CastOp fail to verify in
-      // such cases.
-      auto castArg = [&](Operation *caller) {
-        OpBuilder b(caller);
-        assert(
-            memref::CastOp::areCastCompatible(
-                caller->getOperand(argNumber).getType(), desiredMemrefType) &&
-            "layoutPostProcessing.2: cast incompatible");
-        Value newOperand = b.create<memref::CastOp>(
-            funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
-        operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
-      };
-      foreachCaller(callerMap, funcOp, castArg);
-    }
-
-    // Set operands with cast buffer on all callers to `funcOp`.
-    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-      caller->setOperands(operandsPerCaller.lookup(caller));
-    });
-
-    // Finally set the funcOp type to update the arguments.
-    auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
-                                         funcOp.getFunctionType().getResults());
-    funcOp.setType(newFuncType);
-  }
-}
-
 namespace mlir {
 namespace linalg {
 namespace comprehensive_bufferize {
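
For contrast, the deleted layoutPostProcessing helper reached the same signatures after the fact: it rewrote each already-bufferized argument type to the desired layout, then reconciled the function body and all callers with memref.cast ops. A hedged sketch of the callee-side IR it used to produce (hypothetical function; the new approach emits the annotated layout up front and needs no such casts):

// Old approach (sketch): signature rewritten to the desired layout; uses in
// the body go through a cast back to the generic fully dynamic layout.
func.func @callee(%arg0: memref<4x8xf32, affine_map<(d0, d1) -> (d1 * 4 + d0)>>) {
  %0 = memref.cast %arg0
      : memref<4x8xf32, affine_map<(d0, d1) -> (d1 * 4 + d0)>>
        to memref<4x8xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
  // ... original uses of %arg0 now go through %0 ...
  return
}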
@@ -1111,10 +1027,6 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
   if (failed(finalizeBuffers(moduleOp, options)))
     return failure();
 
-  // Perform a post-processing pass of layout modification at function boundary
-  // according to the kBufferLayoutAttrName.
-  layoutPostProcessing(moduleOp);
-
   // Post-pass cleanup of inplaceable and buffer_layout attributes.
   moduleOp.walk([&](func::FuncOp op) {
     for (BlockArgument bbArg : op.getArguments())
