@@ -324,8 +324,10 @@ setInPlaceFuncArgument(BlockArgument bbArg,
324
324
325
325
// / Remove the attribute that triggers inplace bufferization on a FuncOp
326
326
// / argument `bbArg`.
327
- static void removeInPlaceFuncArgument (BlockArgument bbArg) {
327
+ static void removeBufferizationFuncArguments (BlockArgument bbArg) {
328
328
auto funcOp = cast<FuncOp>(bbArg.getOwner ()->getParentOp ());
329
+ funcOp.removeArgAttr (bbArg.getArgNumber (),
330
+ LinalgDialect::kBufferLayoutAttrName );
329
331
funcOp.removeArgAttr (bbArg.getArgNumber (),
330
332
LinalgDialect::kInplaceableAttrName );
331
333
}
@@ -2608,6 +2610,96 @@ static void applyEnablingTransformations(ModuleOp moduleOp) {
2608
2610
(void )applyPatternsAndFoldGreedily (moduleOp, std::move (patterns));
2609
2611
}
2610
2612
2613
+ static void
2614
+ foreachCaller (const DenseMap<FuncOp, DenseSet<Operation *>> &callerMap,
2615
+ FuncOp callee, llvm::function_ref<void (Operation *)> doit) {
2616
+ auto itCallers = callerMap.find (callee);
2617
+ if (itCallers == callerMap.end ())
2618
+ return ;
2619
+ for (Operation *caller : itCallers->second )
2620
+ doit (caller);
2621
+ }
2622
+
2623
+ // / Postprocess the linalg.buffer_layout annotation across function boundaries.
2624
+ // / This is a purely mechanical process that may later become part of a
2625
+ // / separate pass with its own layout assignment heuristic.
2626
+ static void layoutPostProcessing (ModuleOp moduleOp) {
2627
+ SmallVector<FuncOp> orderedFuncOps;
2628
+ DenseMap<FuncOp, DenseSet<Operation *>> callerMap;
2629
+ auto res = getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps, callerMap);
2630
+ assert (succeeded (res) && " unexpected getFuncOpsOrderedByCalls failure" );
2631
+
2632
+ for (FuncOp funcOp : orderedFuncOps) {
2633
+ DenseMap<Operation *, SmallVector<Value>> operandsPerCaller;
2634
+ foreachCaller (callerMap, funcOp, [&](Operation *caller) {
2635
+ operandsPerCaller.try_emplace (caller, SmallVector<Value>());
2636
+ });
2637
+
2638
+ SmallVector<Type> argumentTypes;
2639
+ // Iterate on each function argument and check it it was marked with a
2640
+ // desired layout.
2641
+ for (auto it : llvm::enumerate (funcOp.getType ().getInputs ())) {
2642
+ int argNumber = it.index ();
2643
+ Type inputType = it.value ();
2644
+ auto memrefType = inputType.dyn_cast <MemRefType>();
2645
+ auto layoutAttr = funcOp.getArgAttrOfType <AffineMapAttr>(
2646
+ argNumber, LinalgDialect::kBufferLayoutAttrName );
2647
+ AffineMap desiredLayoutMap =
2648
+ layoutAttr ? layoutAttr.getValue () : AffineMap ();
2649
+ AffineMap currentLayoutMap =
2650
+ memrefType ? getStridedLinearLayoutMap (memrefType) : AffineMap ();
2651
+ if (!memrefType || !layoutAttr || desiredLayoutMap == currentLayoutMap) {
2652
+ argumentTypes.push_back (inputType);
2653
+ foreachCaller (callerMap, funcOp, [&](Operation *caller) {
2654
+ operandsPerCaller.find (caller)->getSecond ().push_back (
2655
+ caller->getOperand (argNumber));
2656
+ });
2657
+ continue ;
2658
+ }
2659
+
2660
+ // Compute the buffer type with desired layout and add to input argument
2661
+ // types.
2662
+ MemRefType desiredMemrefType = MemRefType::get (
2663
+ memrefType.getShape (), memrefType.getElementType (), desiredLayoutMap);
2664
+ argumentTypes.push_back (desiredMemrefType);
2665
+
2666
+ // If funcOp's body is not empty, change the bbArg type and propagate.
2667
+ if (!funcOp.body ().empty ()) {
2668
+ BlockArgument bbArg = funcOp.getArgument (argNumber);
2669
+ bbArg.setType (desiredMemrefType);
2670
+ OpBuilder b (bbArg.getContext ());
2671
+ b.setInsertionPointToStart (bbArg.getOwner ());
2672
+ // Cast back to the original memrefType and let it canonicalize.
2673
+ Value cast =
2674
+ b.create <memref::CastOp>(funcOp.getLoc (), memrefType, bbArg);
2675
+ bbArg.replaceAllUsesExcept (cast, cast.getDefiningOp ());
2676
+ }
2677
+
2678
+ // Cast to desired buffer type on all callers to `funcOp`.
2679
+ // TODO: on the callee side, this may even have to trigger a copy to
2680
+ // change the layout. For now let the memref::CastOp fail to verify in
2681
+ // such cases.
2682
+ auto castArg = [&](Operation *caller) {
2683
+ OpBuilder b (caller);
2684
+ Value newOperand = b.create <memref::CastOp>(
2685
+ funcOp.getLoc (), desiredMemrefType, caller->getOperand (argNumber));
2686
+ operandsPerCaller.find (caller)->getSecond ().push_back (newOperand);
2687
+ };
2688
+ foreachCaller (callerMap, funcOp, castArg);
2689
+ }
2690
+
2691
+ // Set operands with cast buffer on all callers to `funcOp`.
2692
+ foreachCaller (callerMap, funcOp, [&](Operation *caller) {
2693
+ caller->setOperands (operandsPerCaller.lookup (caller));
2694
+ });
2695
+
2696
+ // Finally set the funcOp type to update the arguments.
2697
+ auto newFuncType = FunctionType::get (moduleOp.getContext (), argumentTypes,
2698
+ funcOp.getType ().getResults ());
2699
+ funcOp.setType (newFuncType);
2700
+ }
2701
+ }
2702
+
2611
2703
void LinalgComprehensiveModuleBufferize::runOnOperation () {
2612
2704
ModuleOp moduleOp = getOperation ();
2613
2705
applyEnablingTransformations (moduleOp);
@@ -2672,12 +2764,16 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
2672
2764
}
2673
2765
}
2674
2766
2675
- // Post-pass cleanup of inplaceable attributes.
2767
+ // Perform a post-processing pass of layout modification at function boundary
2768
+ // according to the kBufferLayoutAttrName.
2769
+ layoutPostProcessing (moduleOp);
2770
+
2771
+ // Post-pass cleanup of inplaceable and buffer_layout attributes.
2676
2772
moduleOp.walk (
2677
2773
[&](Operation *op) { op->removeAttr (kInPlaceResultsAttrName ); });
2678
2774
moduleOp.walk ([&](FuncOp op) {
2679
2775
for (BlockArgument bbArg : op.getArguments ())
2680
- removeInPlaceFuncArgument (bbArg);
2776
+ removeBufferizationFuncArguments (bbArg);
2681
2777
});
2682
2778
2683
2779
OpPassManager cleanupPipeline (OpPassManager (" module" ));
0 commit comments