Skip to content

Commit e35c80c

Browse files
committed
[mlir] Generalize OneShotModuleBufferize to operate on any Operation
1 parent aec3016 commit e35c80c

19 files changed

+207
-89
lines changed
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,32 @@
1-
//===- OneShotModuleBufferize.h - Bufferization across Func. Boundaries ---===//
1+
//===- OneShotRootBufferize.h - Bufferization across Func. Boundaries ---===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
88

9-
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
10-
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
9+
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
10+
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
1111

1212
namespace llvm {
1313
struct LogicalResult;
1414
} // namespace llvm
1515

1616
namespace mlir {
17-
class ModuleOp;
17+
class Operation;
1818

1919
namespace bufferization {
2020
struct BufferizationStatistics;
2121
class OneShotAnalysisState;
2222
struct OneShotBufferizationOptions;
2323
class BufferizationState;
2424

25-
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
25+
/// Analyze `rootOp` and its nested ops. Bufferization decisions are stored in
2626
/// `state`.
2727
llvm::LogicalResult
28-
analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
29-
BufferizationStatistics *statistics = nullptr);
28+
analyzeRootOp(Operation *rootOp, OneShotAnalysisState &state,
29+
BufferizationStatistics *statistics = nullptr);
3030

3131
/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
3232
///
@@ -38,23 +38,23 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
3838
/// is not empty. The FuncOps it contains were not analyzed. Buffer copies
3939
/// will be inserted only to these FuncOps.
4040
llvm::LogicalResult
41-
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
42-
BufferizationState &state,
43-
BufferizationStatistics *statistics = nullptr);
41+
bufferizeRootOp(Operation *rootOp, const OneShotBufferizationOptions &options,
42+
BufferizationState &state,
43+
BufferizationStatistics *statistics = nullptr);
4444

45-
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
46-
void removeBufferizationAttributesInModule(ModuleOp moduleOp);
45+
/// Remove bufferization attributes on every FuncOp arguments in the RootOp.
46+
void removeBufferizationAttributesInRoot(Operation *rootOp);
4747

48-
/// Run One-Shot Module Bufferization on the given module. Performs a simple
48+
/// Run One-Shot Root Bufferization on the given root op. Performs a simple
4949
/// function call analysis to determine which function arguments are
5050
/// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
5151
/// Bufferize.
52-
llvm::LogicalResult runOneShotModuleBufferize(
53-
ModuleOp moduleOp,
52+
llvm::LogicalResult runOneShotRootBufferize(
53+
Operation *rootOp,
5454
const bufferization::OneShotBufferizationOptions &options,
5555
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
5656

5757
} // namespace bufferization
5858
} // namespace mlir
5959

60-
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
60+
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H

mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1212
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
13-
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
13+
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
1414
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
1515
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1616
#include "mlir/Dialect/Linalg/IR/Linalg.h"
@@ -92,8 +92,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
9292
if (options.bufferizeFunctionBoundaries) {
9393
if (!moduleOp)
9494
return emitSilenceableError() << "expected module target";
95-
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
96-
bufferizationState)))
95+
if (failed(bufferization::runOneShotRootBufferize(moduleOp, options,
96+
bufferizationState)))
9797
return emitSilenceableError() << "bufferization failed";
9898
} else {
9999
if (failed(bufferization::runOneShotBufferize(target, options,

mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1313
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
1414
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15-
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
15+
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
1616
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1717
#include "mlir/Dialect/MemRef/IR/MemRef.h"
1818
#include "mlir/IR/Diagnostics.h"
@@ -163,8 +163,7 @@ struct OneShotBufferizePass
163163
BufferizationStatistics statistics;
164164
ModuleOp moduleOp = getOperation();
165165
if (opt.bufferizeFunctionBoundaries) {
166-
if (failed(
167-
runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
166+
if (failed(runOneShotRootBufferize(moduleOp, opt, state, &statistics))) {
168167
signalPassFailure();
169168
return;
170169
}

mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
1111
FuncBufferizableOpInterfaceImpl.cpp
1212
LowerDeallocations.cpp
1313
OneShotAnalysis.cpp
14-
OneShotModuleBufferize.cpp
14+
OneShotRootBufferize.cpp
1515
OwnershipBasedBufferDeallocation.cpp
1616
TensorCopyInsertion.cpp
1717
OptimizeAllocationLiveness.cpp

mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
1212
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1313
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
14-
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
14+
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
1515
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1616
#include "mlir/Dialect/Tensor/IR/Tensor.h"
1717
#include "mlir/IR/Dominance.h"
@@ -209,7 +209,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
209209
OneShotAnalysisState state(op, options);
210210
if (moduleOp) {
211211
// Module analysis takes into account function boundaries.
212-
if (failed(analyzeModuleOp(moduleOp, state)))
212+
if (failed(analyzeRootOp(moduleOp, state)))
213213
return failure();
214214
} else {
215215
// Regular One-Shot Bufferize ignores func.func block arguments, func.call,

mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp renamed to mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp

Lines changed: 67 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,18 @@
1-
//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
1+
//===- OneShotRootBufferize.cpp - Bufferization across Func. Boundaries
2+
//----===//
23
//
34
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
45
// See https://llvm.org/LICENSE.txt for license information.
56
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
67
//
78
//===----------------------------------------------------------------------===//
89
//
9-
// Module Bufferization is an extension of One-Shot Bufferize that
10+
// Root Bufferization is an extension of One-Shot Bufferize that
1011
// bufferizes function boundaries. It provides `BufferizableOpInterface`
1112
// implementations for FuncOp, CallOp and ReturnOp.
1213
//
13-
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
14-
// This function analyzes the given module and determines the order of analysis
14+
// Root Bufferization is run via `runOneShotRootBufferize(RootOp, ...)`.
15+
// This function analyzes the given op and determines the order of analysis
1516
// and bufferization: Functions that are called are processed before their
1617
// respective callers.
1718
//
@@ -24,7 +25,7 @@
2425
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
2526
// read/written.
2627
//
27-
// Module Bufferization implements the following calling convention.
28+
// Root Bufferization implements the following calling convention.
2829
//
2930
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
3031
// be written to in-place.
@@ -57,7 +58,7 @@
5758
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
5859
// as "not reading" and/or "not writing".
5960

60-
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
61+
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
6162

6263
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
6364
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
@@ -299,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
299300
llvm::IsaPred<TensorType>);
300301
}
301302

302-
/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
303+
/// Store all functions of the `rootOp` in `orderedFuncOps`, sorted by
303304
/// callee-caller order (i.e., callees without callers first). Store all
304305
/// remaining functions (i.e., the ones that call each other recursively) in
305306
/// `remainingFuncOps`. Does not traverse nested symbol tables.
@@ -309,34 +310,37 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
309310
/// Return `failure()` if we are unable to retrieve the called FuncOp from
310311
/// any func::CallOp.
311312
static LogicalResult getFuncOpsOrderedByCalls(
312-
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313+
Operation *rootOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
313314
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
314315
SymbolTableCollection &symbolTables) {
315316
// For each FuncOp, the set of functions called by it (i.e. the union of
316317
// symbols of all nested func::CallOp).
317318
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
318319
// For each FuncOp, the number of func::CallOp it contains.
319320
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;
320-
321-
for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
322-
// Collect function calls and populate the caller map.
323-
numberCallOpsContainedInFuncOp[funcOp] = 0;
324-
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
325-
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
326-
assert(calledFunction && "could not retrieved called func::FuncOp");
327-
// If the called function does not have any tensors in its signature, then
328-
// it is not necessary to bufferize the callee before the caller.
329-
if (!hasTensorSignature(calledFunction))
330-
return WalkResult::skip();
331-
332-
callerMap[calledFunction].insert(callOp);
333-
if (calledBy[calledFunction].insert(funcOp).second) {
334-
numberCallOpsContainedInFuncOp[funcOp]++;
321+
for (mlir::Region &region : rootOp->getRegions()) {
322+
for (mlir::Block &block : region.getBlocks()) {
323+
for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
324+
// Collect function calls and populate the caller map.
325+
numberCallOpsContainedInFuncOp[funcOp] = 0;
326+
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
327+
func::FuncOp calledFunction = getCalledFunction(callOp);
328+
assert(calledFunction && "could not retrieved called func::FuncOp");
329+
// If the called function does not have any tensors in its signature,
330+
// then it is not necessary to bufferize the callee before the caller.
331+
if (!hasTensorSignature(calledFunction))
332+
return WalkResult::skip();
333+
334+
callerMap[calledFunction].insert(callOp);
335+
if (calledBy[calledFunction].insert(funcOp).second) {
336+
numberCallOpsContainedInFuncOp[funcOp]++;
337+
}
338+
return WalkResult::advance();
339+
});
340+
if (res.wasInterrupted())
341+
return failure();
335342
}
336-
return WalkResult::advance();
337-
});
338-
if (res.wasInterrupted())
339-
return failure();
343+
}
340344
}
341345

342346
// Iteratively remove function operations that do not call any of the
@@ -447,9 +451,9 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
447451
}
448452

449453
LogicalResult
450-
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
451-
OneShotAnalysisState &state,
452-
BufferizationStatistics *statistics) {
454+
mlir::bufferization::analyzeRootOp(Operation *rootOp,
455+
OneShotAnalysisState &state,
456+
BufferizationStatistics *statistics) {
453457
assert(state.getOptions().bufferizeFunctionBoundaries &&
454458
"expected that function boundary bufferization is activated");
455459
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
@@ -465,9 +469,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
465469
// A mapping of FuncOps to their callers.
466470
FuncCallerMap callerMap;
467471

468-
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
469-
remainingFuncOps, callerMap,
470-
funcState.symbolTables)))
472+
if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
473+
callerMap, funcState.symbolTables)))
471474
return failure();
472475

473476
// Analyze functions in order. Starting with functions that are not calling
@@ -511,20 +514,24 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
511514
return success();
512515
}
513516

514-
void mlir::bufferization::removeBufferizationAttributesInModule(
515-
ModuleOp moduleOp) {
516-
for (auto op : moduleOp.getOps<func::FuncOp>()) {
517-
for (BlockArgument bbArg : op.getArguments())
518-
removeBufferizationAttributes(bbArg);
517+
void mlir::bufferization::removeBufferizationAttributesInRoot(
518+
Operation *rootOp) {
519+
for (mlir::Region &region : rootOp->getRegions()) {
520+
for (mlir::Block &block : region.getBlocks()) {
521+
for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
522+
for (BlockArgument bbArg : funcOp.getArguments())
523+
removeBufferizationAttributes(bbArg);
524+
}
525+
}
519526
}
520527
}
521528

522-
LogicalResult mlir::bufferization::bufferizeModuleOp(
523-
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
529+
LogicalResult mlir::bufferization::bufferizeRootOp(
530+
Operation *rootOp, const OneShotBufferizationOptions &options,
524531
BufferizationState &state, BufferizationStatistics *statistics) {
525532
assert(options.bufferizeFunctionBoundaries &&
526533
"expected that function boundary bufferization is activated");
527-
IRRewriter rewriter(moduleOp.getContext());
534+
IRRewriter rewriter(rootOp->getContext());
528535

529536
// A list of non-circular functions in the order in which they are analyzed
530537
// and bufferized.
@@ -542,9 +549,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
542549
// accurate buffer types for function return values. Functions that call
543550
// each other recursively are bufferized in an unspecified order at the end.
544551
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
545-
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
546-
remainingFuncOps, callerMap,
547-
state.getSymbolTables())))
552+
if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
553+
callerMap, state.getSymbolTables())))
548554
return failure();
549555
llvm::append_range(orderedFuncOps, remainingFuncOps);
550556

@@ -571,30 +577,35 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
571577
}
572578

573579
// Bufferize all other ops.
574-
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
575-
// Functions were already bufferized.
576-
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
577-
continue;
578-
if (failed(bufferizeOp(&op, options, state, statistics)))
579-
return failure();
580+
for (mlir::Region &region : rootOp->getRegions()) {
581+
for (mlir::Block &block : region.getBlocks()) {
582+
for (mlir::Operation &op :
583+
llvm::make_early_inc_range(block.getOperations())) {
584+
// Functions were already bufferized.
585+
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
586+
continue;
587+
if (failed(bufferizeOp(&op, options, state, statistics)))
588+
return failure();
589+
}
590+
}
580591
}
581592

582593
// Post-pass cleanup of function argument attributes.
583-
removeBufferizationAttributesInModule(moduleOp);
594+
removeBufferizationAttributesInRoot(rootOp);
584595

585596
return success();
586597
}
587598

588-
LogicalResult mlir::bufferization::runOneShotModuleBufferize(
589-
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
599+
LogicalResult mlir::bufferization::runOneShotRootBufferize(
600+
Operation *rootOp, const OneShotBufferizationOptions &options,
590601
BufferizationState &state, BufferizationStatistics *statistics) {
591602
assert(options.bufferizeFunctionBoundaries &&
592603
"expected that function boundary bufferization is activated");
593604
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
594605
"invalid combination of bufferization flags");
595606
if (!options.copyBeforeWrite) {
596607
if (options.noAnalysisFuncFilter.empty()) {
597-
if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
608+
if (failed(insertTensorCopies(rootOp, options, state, statistics)))
598609
return failure();
599610
} else {
600611
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
@@ -610,14 +621,13 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
610621
};
611622
OneShotBufferizationOptions updatedOptions(options);
612623
updatedOptions.opFilter.denyOperation(analysisFilterFn);
613-
if (failed(
614-
insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
624+
if (failed(insertTensorCopies(rootOp, updatedOptions, state, statistics)))
615625
return failure();
616626
}
617627
}
618628
if (options.testAnalysisOnly)
619629
return success();
620-
if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
630+
if (failed(bufferizeRootOp(moduleOp, options, state, statistics)))
621631
return failure();
622632
return success();
623633
}

mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
1313
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
1414
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
15-
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
15+
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
1616
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
1717
#include "mlir/Dialect/Func/IR/FuncOps.h"
1818

@@ -35,7 +35,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
3535
// analysis depending on whether function boundary bufferization is enabled or
3636
// not.
3737
if (options.bufferizeFunctionBoundaries) {
38-
if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
38+
if (failed(analyzeRootOp(op, analysisState, statistics)))
3939
return failure();
4040
} else {
4141
if (failed(analyzeOp(op, analysisState, statistics)))

0 commit comments

Comments
 (0)