1
- // ===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
1
+ // ===- OneShotRootBufferize.cpp - Bufferization across Func. Boundaries
2
+ // ----===//
2
3
//
3
4
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4
5
// See https://llvm.org/LICENSE.txt for license information.
5
6
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6
7
//
7
8
// ===----------------------------------------------------------------------===//
8
9
//
9
- // Module Bufferization is an extension of One-Shot Bufferize that
10
+ // Root Bufferization is an extension of One-Shot Bufferize that
10
11
// bufferizes function boundaries. It provides `BufferizableOpInterface`
11
12
// implementations for FuncOp, CallOp and ReturnOp.
12
13
//
13
- // Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp , ...)`.
14
- // This function analyzes the given module and determines the order of analysis
14
+ // Root Bufferization is run via `runOneShotRootBufferize(RootOp , ...)`.
15
+ // This function analyzes the given op and determines the order of analysis
15
16
// and bufferization: Functions that are called are processed before their
16
17
// respective callers.
17
18
//
24
25
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
25
26
// read/written.
26
27
//
27
- // Module Bufferization implements the following calling convention.
28
+ // Root Bufferization implements the following calling convention.
28
29
//
29
30
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
30
31
// be written to in-place.
57
58
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
58
59
// as "not reading" and/or "not writing".
59
60
60
- #include " mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize .h"
61
+ #include " mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize .h"
61
62
62
63
#include " mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
63
64
#include " mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -299,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
299
300
llvm::IsaPred<TensorType>);
300
301
}
301
302
302
- // / Store all functions of the `moduleOp ` in `orderedFuncOps`, sorted by
303
+ // / Store all functions of the `rootOp ` in `orderedFuncOps`, sorted by
303
304
// / callee-caller order (i.e., callees without callers first). Store all
304
305
// / remaining functions (i.e., the ones that call each other recursively) in
305
306
// / `remainingFuncOps`. Does not traverse nested symbol tables.
@@ -309,34 +310,37 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
309
310
// / Return `failure()` if we are unable to retrieve the called FuncOp from
310
311
// / any func::CallOp.
311
312
static LogicalResult getFuncOpsOrderedByCalls (
312
- ModuleOp moduleOp , SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313
+ Operation *rootOp , SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313
314
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314
315
SymbolTableCollection &symbolTables) {
315
316
// For each FuncOp, the set of functions called by it (i.e. the union of
316
317
// symbols of all nested func::CallOp).
317
318
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
318
319
// For each FuncOp, the number of func::CallOp it contains.
319
320
DenseMap<func::FuncOp, unsigned > numberCallOpsContainedInFuncOp;
320
-
321
- for (func::FuncOp funcOp : moduleOp.getOps <func::FuncOp>()) {
322
- // Collect function calls and populate the caller map.
323
- numberCallOpsContainedInFuncOp[funcOp] = 0 ;
324
- WalkResult res = funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
325
- func::FuncOp calledFunction = getCalledFunction (callOp, symbolTables);
326
- assert (calledFunction && " could not retrieved called func::FuncOp" );
327
- // If the called function does not have any tensors in its signature, then
328
- // it is not necessary to bufferize the callee before the caller.
329
- if (!hasTensorSignature (calledFunction))
330
- return WalkResult::skip ();
331
-
332
- callerMap[calledFunction].insert (callOp);
333
- if (calledBy[calledFunction].insert (funcOp).second ) {
334
- numberCallOpsContainedInFuncOp[funcOp]++;
321
+ for (mlir::Region ®ion : rootOp->getRegions ()) {
322
+ for (mlir::Block &block : region.getBlocks ()) {
323
+ for (func::FuncOp funcOp : block.getOps <func::FuncOp>()) {
324
+ // Collect function calls and populate the caller map.
325
+ numberCallOpsContainedInFuncOp[funcOp] = 0 ;
326
+ WalkResult res = funcOp.walk ([&](func::CallOp callOp) -> WalkResult {
327
+ func::FuncOp calledFunction = getCalledFunction (callOp);
328
+ assert (calledFunction && " could not retrieved called func::FuncOp" );
329
+ // If the called function does not have any tensors in its signature,
330
+ // then it is not necessary to bufferize the callee before the caller.
331
+ if (!hasTensorSignature (calledFunction))
332
+ return WalkResult::skip ();
333
+
334
+ callerMap[calledFunction].insert (callOp);
335
+ if (calledBy[calledFunction].insert (funcOp).second ) {
336
+ numberCallOpsContainedInFuncOp[funcOp]++;
337
+ }
338
+ return WalkResult::advance ();
339
+ });
340
+ if (res.wasInterrupted ())
341
+ return failure ();
335
342
}
336
- return WalkResult::advance ();
337
- });
338
- if (res.wasInterrupted ())
339
- return failure ();
343
+ }
340
344
}
341
345
342
346
// Iteratively remove function operations that do not call any of the
@@ -447,9 +451,9 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
447
451
}
448
452
449
453
LogicalResult
450
- mlir::bufferization::analyzeModuleOp (ModuleOp moduleOp ,
451
- OneShotAnalysisState &state,
452
- BufferizationStatistics *statistics) {
454
+ mlir::bufferization::analyzeRootOp (Operation *rootOp ,
455
+ OneShotAnalysisState &state,
456
+ BufferizationStatistics *statistics) {
453
457
assert (state.getOptions ().bufferizeFunctionBoundaries &&
454
458
" expected that function boundary bufferization is activated" );
455
459
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState (state);
@@ -465,9 +469,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
465
469
// A mapping of FuncOps to their callers.
466
470
FuncCallerMap callerMap;
467
471
468
- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
469
- remainingFuncOps, callerMap,
470
- funcState.symbolTables )))
472
+ if (failed (getFuncOpsOrderedByCalls (rootOp, orderedFuncOps, remainingFuncOps,
473
+ callerMap, funcState.symbolTables )))
471
474
return failure ();
472
475
473
476
// Analyze functions in order. Starting with functions that are not calling
@@ -511,20 +514,24 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
511
514
return success ();
512
515
}
513
516
514
- void mlir::bufferization::removeBufferizationAttributesInModule (
515
- ModuleOp moduleOp) {
516
- for (auto op : moduleOp.getOps <func::FuncOp>()) {
517
- for (BlockArgument bbArg : op.getArguments ())
518
- removeBufferizationAttributes (bbArg);
517
+ void mlir::bufferization::removeBufferizationAttributesInRoot (
518
+ Operation *rootOp) {
519
+ for (mlir::Region ®ion : rootOp->getRegions ()) {
520
+ for (mlir::Block &block : region.getBlocks ()) {
521
+ for (func::FuncOp funcOp : block.getOps <func::FuncOp>()) {
522
+ for (BlockArgument bbArg : funcOp.getArguments ())
523
+ removeBufferizationAttributes (bbArg);
524
+ }
525
+ }
519
526
}
520
527
}
521
528
522
- LogicalResult mlir::bufferization::bufferizeModuleOp (
523
- ModuleOp moduleOp , const OneShotBufferizationOptions &options,
529
+ LogicalResult mlir::bufferization::bufferizeRootOp (
530
+ Operation *rootOp , const OneShotBufferizationOptions &options,
524
531
BufferizationState &state, BufferizationStatistics *statistics) {
525
532
assert (options.bufferizeFunctionBoundaries &&
526
533
" expected that function boundary bufferization is activated" );
527
- IRRewriter rewriter (moduleOp. getContext ());
534
+ IRRewriter rewriter (rootOp-> getContext ());
528
535
529
536
// A list of non-circular functions in the order in which they are analyzed
530
537
// and bufferized.
@@ -542,9 +549,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
542
549
// accurate buffer types for function return values. Functions that call
543
550
// each other recursively are bufferized in an unspecified order at the end.
544
551
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
545
- if (failed (getFuncOpsOrderedByCalls (moduleOp, orderedFuncOps,
546
- remainingFuncOps, callerMap,
547
- state.getSymbolTables ())))
552
+ if (failed (getFuncOpsOrderedByCalls (rootOp, orderedFuncOps, remainingFuncOps,
553
+ callerMap, state.getSymbolTables ())))
548
554
return failure ();
549
555
llvm::append_range (orderedFuncOps, remainingFuncOps);
550
556
@@ -571,30 +577,35 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
571
577
}
572
578
573
579
// Bufferize all other ops.
574
- for (Operation &op : llvm::make_early_inc_range (moduleOp.getOps ())) {
575
- // Functions were already bufferized.
576
- if (isa<func::FuncOp>(&op) || op.hasTrait <OpTrait::SymbolTable>())
577
- continue ;
578
- if (failed (bufferizeOp (&op, options, state, statistics)))
579
- return failure ();
580
+ for (mlir::Region ®ion : rootOp->getRegions ()) {
581
+ for (mlir::Block &block : region.getBlocks ()) {
582
+ for (mlir::Operation &op :
583
+ llvm::make_early_inc_range (block.getOperations ())) {
584
+ // Functions were already bufferized.
585
+ if (isa<func::FuncOp>(&op) || op.hasTrait <OpTrait::SymbolTable>())
586
+ continue ;
587
+ if (failed (bufferizeOp (&op, options, state, statistics)))
588
+ return failure ();
589
+ }
590
+ }
580
591
}
581
592
582
593
// Post-pass cleanup of function argument attributes.
583
- removeBufferizationAttributesInModule (moduleOp );
594
+ removeBufferizationAttributesInRoot (rootOp );
584
595
585
596
return success ();
586
597
}
587
598
588
- LogicalResult mlir::bufferization::runOneShotModuleBufferize (
589
- ModuleOp moduleOp , const OneShotBufferizationOptions &options,
599
+ LogicalResult mlir::bufferization::runOneShotRootBufferize (
600
+ Operation *rootOp , const OneShotBufferizationOptions &options,
590
601
BufferizationState &state, BufferizationStatistics *statistics) {
591
602
assert (options.bufferizeFunctionBoundaries &&
592
603
" expected that function boundary bufferization is activated" );
593
604
assert (!(options.copyBeforeWrite && options.testAnalysisOnly ) &&
594
605
" invalid combination of bufferization flags" );
595
606
if (!options.copyBeforeWrite ) {
596
607
if (options.noAnalysisFuncFilter .empty ()) {
597
- if (failed (insertTensorCopies (moduleOp , options, state, statistics)))
608
+ if (failed (insertTensorCopies (rootOp , options, state, statistics)))
598
609
return failure ();
599
610
} else {
600
611
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
@@ -610,14 +621,13 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
610
621
};
611
622
OneShotBufferizationOptions updatedOptions (options);
612
623
updatedOptions.opFilter .denyOperation (analysisFilterFn);
613
- if (failed (
614
- insertTensorCopies (moduleOp, updatedOptions, state, statistics)))
624
+ if (failed (insertTensorCopies (rootOp, updatedOptions, state, statistics)))
615
625
return failure ();
616
626
}
617
627
}
618
628
if (options.testAnalysisOnly )
619
629
return success ();
620
- if (failed (bufferizeModuleOp (moduleOp, options, state, statistics)))
630
+ if (failed (bufferizeRootOp (moduleOp, options, state, statistics)))
621
631
return failure ();
622
632
return success ();
623
633
}
0 commit comments