Skip to content

[mlir] Generalize OneShotModuleBufferize to operate on any Operation #148327

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,32 +1,32 @@
//===- OneShotModuleBufferize.h - Bufferization across Func. Boundaries ---===//
//===- OneShotRootBufferize.h - Bufferization across Func. Boundaries ---===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H

namespace llvm {
struct LogicalResult;
} // namespace llvm

namespace mlir {
class ModuleOp;
class Operation;

namespace bufferization {
struct BufferizationStatistics;
class OneShotAnalysisState;
struct OneShotBufferizationOptions;
class BufferizationState;

/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
/// Analyze `rootOp` and its nested ops. Bufferization decisions are stored in
/// `state`.
llvm::LogicalResult
analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
BufferizationStatistics *statistics = nullptr);
analyzeRootOp(Operation *rootOp, OneShotAnalysisState &state,
BufferizationStatistics *statistics = nullptr);

/// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`.
///
Expand All @@ -38,23 +38,23 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
/// is not empty. The FuncOps it contains were not analyzed. Buffer copies
/// will be inserted only to these FuncOps.
llvm::LogicalResult
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
BufferizationState &state,
BufferizationStatistics *statistics = nullptr);
bufferizeRootOp(Operation *rootOp, const OneShotBufferizationOptions &options,
BufferizationState &state,
BufferizationStatistics *statistics = nullptr);

/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
void removeBufferizationAttributesInModule(ModuleOp moduleOp);
/// Remove bufferization attributes on every FuncOp arguments in the RootOp.
void removeBufferizationAttributesInRoot(Operation *rootOp);

/// Run One-Shot Module Bufferization on the given module. Performs a simple
/// Run One-Shot Root Bufferization on the given root op. Performs a simple
/// function call analysis to determine which function arguments are
/// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot
/// Bufferize.
llvm::LogicalResult runOneShotModuleBufferize(
ModuleOp moduleOp,
llvm::LogicalResult runOneShotRootBufferize(
Operation *rootOp,
const bufferization::OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics = nullptr);

} // namespace bufferization
} // namespace mlir

#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H
#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Passes.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
Expand Down Expand Up @@ -92,8 +92,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter,
if (options.bufferizeFunctionBoundaries) {
if (!moduleOp)
return emitSilenceableError() << "expected module target";
if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options,
bufferizationState)))
if (failed(bufferization::runOneShotRootBufferize(moduleOp, options,
bufferizationState)))
return emitSilenceableError() << "bufferization failed";
} else {
if (failed(bufferization::runOneShotBufferize(target, options,
Expand Down
5 changes: 2 additions & 3 deletions mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/IR/Diagnostics.h"
Expand Down Expand Up @@ -163,8 +163,7 @@ struct OneShotBufferizePass
BufferizationStatistics statistics;
ModuleOp moduleOp = getOperation();
if (opt.bufferizeFunctionBoundaries) {
if (failed(
runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) {
if (failed(runOneShotRootBufferize(moduleOp, opt, state, &statistics))) {
signalPassFailure();
return;
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms
FuncBufferizableOpInterfaceImpl.cpp
LowerDeallocations.cpp
OneShotAnalysis.cpp
OneShotModuleBufferize.cpp
OneShotRootBufferize.cpp
OwnershipBasedBufferDeallocation.cpp
TensorCopyInsertion.cpp
OptimizeAllocationLiveness.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/Dominance.h"
Expand Down Expand Up @@ -209,7 +209,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter,
OneShotAnalysisState state(op, options);
if (moduleOp) {
// Module analysis takes into account function boundaries.
if (failed(analyzeModuleOp(moduleOp, state)))
if (failed(analyzeRootOp(moduleOp, state)))
return failure();
} else {
// Regular One-Shot Bufferize ignores func.func block arguments, func.call,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===//
//===- OneShotRootBufferize.cpp - Bufferization across Func. Boundaries
//----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Module Bufferization is an extension of One-Shot Bufferize that
// Root Bufferization is an extension of One-Shot Bufferize that
// bufferizes function boundaries. It provides `BufferizableOpInterface`
// implementations for FuncOp, CallOp and ReturnOp.
//
// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`.
// This function analyzes the given module and determines the order of analysis
// Root Bufferization is run via `runOneShotRootBufferize(RootOp, ...)`.
// This function analyzes the given op and determines the order of analysis
// and bufferization: Functions that are called are processed before their
// respective callers.
//
Expand All @@ -24,7 +25,7 @@
// * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is
// read/written.
//
// Module Bufferization implements the following calling convention.
// Root Bufferization implements the following calling convention.
//
// * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always
// be written to in-place.
Expand Down Expand Up @@ -57,7 +58,7 @@
// TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked
// as "not reading" and/or "not writing".

#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"

#include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h"
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
Expand Down Expand Up @@ -299,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
llvm::IsaPred<TensorType>);
}

/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by
/// Store all functions of the `rootOp` in `orderedFuncOps`, sorted by
/// callee-caller order (i.e., callees without callers first). Store all
/// remaining functions (i.e., the ones that call each other recursively) in
/// `remainingFuncOps`. Does not traverse nested symbol tables.
Expand All @@ -309,34 +310,37 @@ static bool hasTensorSignature(func::FuncOp funcOp) {
/// Return `failure()` if we are unable to retrieve the called FuncOp from
/// any func::CallOp.
static LogicalResult getFuncOpsOrderedByCalls(
ModuleOp moduleOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
Operation *rootOp, SmallVectorImpl<func::FuncOp> &orderedFuncOps,
SmallVectorImpl<func::FuncOp> &remainingFuncOps, FuncCallerMap &callerMap,
SymbolTableCollection &symbolTables) {
// For each FuncOp, the set of functions called by it (i.e. the union of
// symbols of all nested func::CallOp).
DenseMap<func::FuncOp, DenseSet<func::FuncOp>> calledBy;
// For each FuncOp, the number of func::CallOp it contains.
DenseMap<func::FuncOp, unsigned> numberCallOpsContainedInFuncOp;

for (func::FuncOp funcOp : moduleOp.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature, then
// it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
return WalkResult::skip();

callerMap[calledFunction].insert(callOp);
if (calledBy[calledFunction].insert(funcOp).second) {
numberCallOpsContainedInFuncOp[funcOp]++;
for (mlir::Region &region : rootOp->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
// Collect function calls and populate the caller map.
numberCallOpsContainedInFuncOp[funcOp] = 0;
WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult {
func::FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called func::FuncOp");
// If the called function does not have any tensors in its signature,
// then it is not necessary to bufferize the callee before the caller.
if (!hasTensorSignature(calledFunction))
return WalkResult::skip();

callerMap[calledFunction].insert(callOp);
if (calledBy[calledFunction].insert(funcOp).second) {
numberCallOpsContainedInFuncOp[funcOp]++;
}
return WalkResult::advance();
});
if (res.wasInterrupted())
return failure();
}
return WalkResult::advance();
});
if (res.wasInterrupted())
return failure();
}
}

// Iteratively remove function operations that do not call any of the
Expand Down Expand Up @@ -447,9 +451,9 @@ static void foldMemRefCasts(func::FuncOp funcOp) {
}

LogicalResult
mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
mlir::bufferization::analyzeRootOp(Operation *rootOp,
OneShotAnalysisState &state,
BufferizationStatistics *statistics) {
assert(state.getOptions().bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state);
Expand All @@ -465,9 +469,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
// A mapping of FuncOps to their callers.
FuncCallerMap callerMap;

if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
remainingFuncOps, callerMap,
funcState.symbolTables)))
if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
callerMap, funcState.symbolTables)))
return failure();

// Analyze functions in order. Starting with functions that are not calling
Expand Down Expand Up @@ -511,20 +514,24 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp,
return success();
}

void mlir::bufferization::removeBufferizationAttributesInModule(
ModuleOp moduleOp) {
for (auto op : moduleOp.getOps<func::FuncOp>()) {
for (BlockArgument bbArg : op.getArguments())
removeBufferizationAttributes(bbArg);
void mlir::bufferization::removeBufferizationAttributesInRoot(
Operation *rootOp) {
for (mlir::Region &region : rootOp->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
for (func::FuncOp funcOp : block.getOps<func::FuncOp>()) {
for (BlockArgument bbArg : funcOp.getArguments())
removeBufferizationAttributes(bbArg);
}
}
}
}

LogicalResult mlir::bufferization::bufferizeModuleOp(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
LogicalResult mlir::bufferization::bufferizeRootOp(
Operation *rootOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
IRRewriter rewriter(moduleOp.getContext());
IRRewriter rewriter(rootOp->getContext());

// A list of non-circular functions in the order in which they are analyzed
// and bufferized.
Expand All @@ -542,9 +549,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
// accurate buffer types for function return values. Functions that call
// each other recursively are bufferized in an unspecified order at the end.
// We may use unnecessarily "complex" (in terms of layout map) buffer types.
if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps,
remainingFuncOps, callerMap,
state.getSymbolTables())))
if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps,
callerMap, state.getSymbolTables())))
return failure();
llvm::append_range(orderedFuncOps, remainingFuncOps);

Expand All @@ -571,30 +577,35 @@ LogicalResult mlir::bufferization::bufferizeModuleOp(
}

// Bufferize all other ops.
for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) {
// Functions were already bufferized.
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
continue;
if (failed(bufferizeOp(&op, options, state, statistics)))
return failure();
for (mlir::Region &region : rootOp->getRegions()) {
for (mlir::Block &block : region.getBlocks()) {
for (mlir::Operation &op :
llvm::make_early_inc_range(block.getOperations())) {
// Functions were already bufferized.
if (isa<func::FuncOp>(&op) || op.hasTrait<OpTrait::SymbolTable>())
continue;
if (failed(bufferizeOp(&op, options, state, statistics)))
return failure();
}
}
}

// Post-pass cleanup of function argument attributes.
removeBufferizationAttributesInModule(moduleOp);
removeBufferizationAttributesInRoot(rootOp);

return success();
}

LogicalResult mlir::bufferization::runOneShotModuleBufferize(
ModuleOp moduleOp, const OneShotBufferizationOptions &options,
LogicalResult mlir::bufferization::runOneShotRootBufferize(
Operation *rootOp, const OneShotBufferizationOptions &options,
BufferizationState &state, BufferizationStatistics *statistics) {
assert(options.bufferizeFunctionBoundaries &&
"expected that function boundary bufferization is activated");
assert(!(options.copyBeforeWrite && options.testAnalysisOnly) &&
"invalid combination of bufferization flags");
if (!options.copyBeforeWrite) {
if (options.noAnalysisFuncFilter.empty()) {
if (failed(insertTensorCopies(moduleOp, options, state, statistics)))
if (failed(insertTensorCopies(rootOp, options, state, statistics)))
return failure();
} else {
// FuncOps whose names are specified in options.noAnalysisFuncFilter will
Expand All @@ -610,14 +621,13 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize(
};
OneShotBufferizationOptions updatedOptions(options);
updatedOptions.opFilter.denyOperation(analysisFilterFn);
if (failed(
insertTensorCopies(moduleOp, updatedOptions, state, statistics)))
if (failed(insertTensorCopies(rootOp, updatedOptions, state, statistics)))
return failure();
}
}
if (options.testAnalysisOnly)
return success();
if (failed(bufferizeModuleOp(moduleOp, options, state, statistics)))
if (failed(bufferizeRootOp(moduleOp, options, state, statistics)))
return failure();
return success();
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h"
#include "mlir/Dialect/Bufferization/Transforms/Transforms.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"

Expand All @@ -35,7 +35,7 @@ LogicalResult mlir::bufferization::insertTensorCopies(
// analysis depending on whether function boundary bufferization is enabled or
// not.
if (options.bufferizeFunctionBoundaries) {
if (failed(analyzeModuleOp(cast<ModuleOp>(op), analysisState, statistics)))
if (failed(analyzeRootOp(op, analysisState, statistics)))
return failure();
} else {
if (failed(analyzeOp(op, analysisState, statistics)))
Expand Down
Loading
Loading