@@ -6,7 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 //
-// Module Bufferization is an extension of Comprehensive Bufferize that
+// Module Bufferization is an extension of One-Shot Bufferize that
 // bufferizes function boundaries. It provides `BufferizableOpInterface`
 // implementations for FuncOp, CallOp and ReturnOp.
 //
@@ -357,14 +357,27 @@ static func::FuncOp getCalledFunction(CallOpInterface callOp) {
 }
 
 /// Return the index-th bufferized function argument type. This assumes that the
-/// specified argument is a tensor.
+/// specified argument is a tensor. If the tensor is ranked, a layout map may be
+/// specified by the user. If no layout map is specified, a fully dynamic map is
+/// used.
 static BaseMemRefType
 getBufferizedFunctionArgType(func::FuncOp funcOp, int64_t index,
                              const BufferizationOptions &options) {
   auto tensorType =
       funcOp.getFunctionType().getInput(index).dyn_cast<TensorType>();
   assert(tensorType && "expected TensorType");
-  return getMemRefType(tensorType, options);
+  BaseMemRefType memrefType = getMemRefType(tensorType, options);
+
+  auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
+      index, BufferizableOpInterface::kBufferLayoutAttrName);
+  if (!layoutAttr)
+    return memrefType;
+
+  auto rankedMemrefType = memrefType.dyn_cast<MemRefType>();
+  assert(rankedMemrefType && "buffer layout not supported on unranked tensors");
+  return MemRefType::get(
+      rankedMemrefType.getShape(), rankedMemrefType.getElementType(),
+      layoutAttr.getValue(), rankedMemrefType.getMemorySpaceAsInt());
 }
 
 /// Gather equivalence info of CallOps.
@@ -451,103 +464,6 @@ getFuncOpsOrderedByCalls(ModuleOp moduleOp,
   return success();
 }
 
-static void foreachCaller(const FuncCallerMap &callerMap, func::FuncOp callee,
-                          llvm::function_ref<void(Operation *)> doit) {
-  auto itCallers = callerMap.find(callee);
-  if (itCallers == callerMap.end())
-    return;
-  for (Operation *caller : itCallers->second)
-    doit(caller);
-}
-
-/// Postprocess the linalg.buffer_layout annotation across function boundaries.
-/// This is a purely mechanical process that may later become part of a
-/// separate pass with its own layout assignment heuristic.
-static void layoutPostProcessing(ModuleOp moduleOp) {
-  SmallVector<func::FuncOp> orderedFuncOps;
-  DenseMap<func::FuncOp, DenseSet<Operation *>> callerMap;
-  auto res = getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, callerMap);
-  (void)res;
-  assert(succeeded(res) && "unexpected getFuncOpsOrderedByCalls failure");
-
-  for (func::FuncOp funcOp : orderedFuncOps) {
-    DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
-    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-      operandsPerCaller.try_emplace(caller, SmallVector<Value>());
-    });
-
-    SmallVector<Type> argumentTypes;
-    // Iterate on each function argument and check if it was marked with a
-    // desired layout.
-    for (const auto &it :
-         llvm::enumerate(funcOp.getFunctionType().getInputs())) {
-      int argNumber = it.index();
-      Type inputType = it.value();
-      auto memrefType = inputType.dyn_cast<MemRefType>();
-      auto layoutAttr = funcOp.getArgAttrOfType<AffineMapAttr>(
-          argNumber, BufferizableOpInterface::kBufferLayoutAttrName);
-      AffineMap desiredLayoutMap =
-          layoutAttr ? layoutAttr.getValue() : AffineMap();
-      AffineMap currentLayoutMap =
-          memrefType ? getStridedLinearLayoutMap(memrefType) : AffineMap();
-      if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
-        argumentTypes.push_back(inputType);
-        foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-          operandsPerCaller.find(caller)->getSecond().push_back(
-              caller->getOperand(argNumber));
-        });
-        continue;
-      }
-
-      // Compute the buffer type with desired layout and add to input argument
-      // types.
-      MemRefType desiredMemrefType = MemRefType::get(
-          memrefType.getShape(), memrefType.getElementType(), desiredLayoutMap);
-      argumentTypes.push_back(desiredMemrefType);
-
-      // If funcOp's body is not empty, change the bbArg type and propagate.
-      if (!funcOp.getBody().empty()) {
-        BlockArgument bbArg = funcOp.getArgument(argNumber);
-        bbArg.setType(desiredMemrefType);
-        OpBuilder b(bbArg.getContext());
-        b.setInsertionPointToStart(bbArg.getOwner());
-        assert(memref::CastOp::areCastCompatible(bbArg.getType(), memrefType) &&
-               "layoutPostProcessing: cast incompatible");
-        // Cast back to the original memrefType and let it canonicalize.
-        Value cast =
-            b.create<memref::CastOp>(funcOp.getLoc(), memrefType, bbArg);
-        bbArg.replaceAllUsesExcept(cast, cast.getDefiningOp());
-      }
-
-      // Cast to desired buffer type on all callers to `funcOp`.
-      // TODO: on the callee side, this may even have to trigger a copy to
-      // change the layout. For now let the memref::CastOp fail to verify in
-      // such cases.
-      auto castArg = [&](Operation *caller) {
-        OpBuilder b(caller);
-        assert(
-            memref::CastOp::areCastCompatible(
-                caller->getOperand(argNumber).getType(), desiredMemrefType) &&
-            "layoutPostProcessing.2: cast incompatible");
-        Value newOperand = b.create<memref::CastOp>(
-            funcOp.getLoc(), desiredMemrefType, caller->getOperand(argNumber));
-        operandsPerCaller.find(caller)->getSecond().push_back(newOperand);
-      };
-      foreachCaller(callerMap, funcOp, castArg);
-    }
-
-    // Set operands with cast buffer on all callers to `funcOp`.
-    foreachCaller(callerMap, funcOp, [&](Operation *caller) {
-      caller->setOperands(operandsPerCaller.lookup(caller));
-    });
-
-    // Finally set the funcOp type to update the arguments.
-    auto newFuncType = FunctionType::get(moduleOp.getContext(), argumentTypes,
-                                         funcOp.getFunctionType().getResults());
-    funcOp.setType(newFuncType);
-  }
-}
-
 namespace mlir {
 namespace linalg {
 namespace comprehensive_bufferize {
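The deleted `layoutPostProcessing` helper patched layouts after the fact: it rewrote each annotated argument type, cast back to the original type at the top of the callee, and cast the operand at every call site. A rough before/after sketch of the IR it used to produce (the function name, shape, and maps are illustrative assumptions):

```mlir
#strided = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
#identity = affine_map<(d0) -> (d0)>

// Before post-processing: the callee was bufferized with the default layout.
func.func @callee(%arg0: memref<16xf32, #strided>) {
  return
}

// After post-processing: the argument adopts the desired #identity layout and
// is immediately cast back so existing users keep their expected type.
func.func @callee(%arg0: memref<16xf32, #identity>) {
  %0 = memref.cast %arg0 : memref<16xf32, #identity> to memref<16xf32, #strided>
  return
}
```

With the change to `getBufferizedFunctionArgType` above, the desired layout is applied while computing the bufferized argument type, so neither the entry-block cast nor the call-site casts are needed.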
@@ -1111,10 +1027,6 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runModuleBufferize(
   if (failed(finalizeBuffers(moduleOp, options)))
     return failure();
 
-  // Perform a post-processing pass of layout modification at function boundary
-  // according to the kBufferLayoutAttrName.
-  layoutPostProcessing(moduleOp);
-
   // Post-pass cleanup of inplaceable and buffer_layout attributes.
   moduleOp.walk([&](func::FuncOp op) {
     for (BlockArgument bbArg : op.getArguments())
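For context on the cleanup walk above: after bufferization, function arguments may still carry the bufferization hint attributes, which the walk strips. A minimal sketch of the effect (the `linalg.inplaceable` spelling is an assumption; `linalg.buffer_layout` matches the constant referenced earlier in this file):

```mlir
// Before cleanup (sketch):
func.func @fn(%arg0: memref<?xf32> {linalg.inplaceable = true}) {
  return
}

// After cleanup, the argument carries no bufferization hint attributes:
func.func @fn(%arg0: memref<?xf32>) {
  return
}
```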