diff --git a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h similarity index 64% rename from mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h rename to mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h index 2cf801dd1d951..32b76269e2c03 100644 --- a/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h +++ b/mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h @@ -1,4 +1,4 @@ -//===- OneShotModuleBufferize.h - Bufferization across Func. Boundaries ---===// +//===- OneShotRootBufferize.h - Bufferization across Func. Boundaries ---===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,15 +6,15 @@ // //===----------------------------------------------------------------------===// -#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H -#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H +#ifndef MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H +#define MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H namespace llvm { struct LogicalResult; } // namespace llvm namespace mlir { -class ModuleOp; +class Operation; namespace bufferization { struct BufferizationStatistics; @@ -22,11 +22,11 @@ class OneShotAnalysisState; struct OneShotBufferizationOptions; class BufferizationState; -/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in +/// Analyze `rootOp` and its nested ops. Bufferization decisions are stored in /// `state`. llvm::LogicalResult -analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, - BufferizationStatistics *statistics = nullptr); +analyzeRootOp(Operation *rootOp, OneShotAnalysisState &state, + BufferizationStatistics *statistics = nullptr); /// Bufferize `op` and its nested ops that implement `BufferizableOpInterface`. /// @@ -38,23 +38,23 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state, /// is not empty. The FuncOps it contains were not analyzed. Buffer copies /// will be inserted only to these FuncOps. llvm::LogicalResult -bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options, - BufferizationState &state, - BufferizationStatistics *statistics = nullptr); +bufferizeRootOp(Operation *rootOp, const OneShotBufferizationOptions &options, + BufferizationState &state, + BufferizationStatistics *statistics = nullptr); -/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp. -void removeBufferizationAttributesInModule(ModuleOp moduleOp); +/// Remove bufferization attributes on every FuncOp arguments in the RootOp. +void removeBufferizationAttributesInRoot(Operation *rootOp); -/// Run One-Shot Module Bufferization on the given module. Performs a simple +/// Run One-Shot Root Bufferization on the given root op. Performs a simple /// function call analysis to determine which function arguments are /// inplaceable. Then analyzes and bufferizes FuncOps one-by-one with One-Shot /// Bufferize. -llvm::LogicalResult runOneShotModuleBufferize( - ModuleOp moduleOp, +llvm::LogicalResult runOneShotRootBufferize( + Operation *rootOp, const bufferization::OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics = nullptr); } // namespace bufferization } // namespace mlir -#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTMODULEBUFFERIZE_H +#endif // MLIR_DIALECT_BUFFERIZATION_TRANSFORMS_ONESHOTROOTBUFFERIZE_H diff --git a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp index db1eb20512033..5a52daf6c7698 100644 --- a/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp +++ b/mlir/lib/Dialect/Bufferization/TransformOps/BufferizationTransformOps.cpp @@ -10,7 +10,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" @@ -92,8 +92,8 @@ transform::OneShotBufferizeOp::apply(transform::TransformRewriter &rewriter, if (options.bufferizeFunctionBoundaries) { if (!moduleOp) return emitSilenceableError() << "expected module target"; - if (failed(bufferization::runOneShotModuleBufferize(moduleOp, options, - bufferizationState))) + if (failed(bufferization::runOneShotRootBufferize(moduleOp, options, + bufferizationState))) return emitSilenceableError() << "bufferization failed"; } else { if (failed(bufferization::runOneShotBufferize(target, options, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp index 246555dc8c699..baef091eeebd1 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp @@ -12,7 +12,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/IR/Diagnostics.h" @@ -163,8 +163,7 @@ struct OneShotBufferizePass BufferizationStatistics statistics; ModuleOp moduleOp = getOperation(); if (opt.bufferizeFunctionBoundaries) { - if (failed( - runOneShotModuleBufferize(moduleOp, opt, state, &statistics))) { + if (failed(runOneShotRootBufferize(moduleOp, opt, state, &statistics))) { signalPassFailure(); return; } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt index 7c38621be1bb5..fa310b95df4bd 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt +++ b/mlir/lib/Dialect/Bufferization/Transforms/CMakeLists.txt @@ -11,7 +11,7 @@ add_mlir_dialect_library(MLIRBufferizationTransforms FuncBufferizableOpInterfaceImpl.cpp LowerDeallocations.cpp OneShotAnalysis.cpp - OneShotModuleBufferize.cpp + OneShotRootBufferize.cpp OwnershipBasedBufferDeallocation.cpp TensorCopyInsertion.cpp OptimizeAllocationLiveness.cpp diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index b7db2e847a335..ee1a9178a9d0d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -11,7 +11,7 @@ #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/IR/Dominance.h" @@ -209,7 +209,7 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(RewriterBase &rewriter, OneShotAnalysisState state(op, options); if (moduleOp) { // Module analysis takes into account function boundaries. - if (failed(analyzeModuleOp(moduleOp, state))) + if (failed(analyzeRootOp(moduleOp, state))) return failure(); } else { // Regular One-Shot Bufferize ignores func.func block arguments, func.call, diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp similarity index 85% rename from mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp rename to mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp index d1d106220a38c..a7865050e6e38 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotRootBufferize.cpp @@ -1,4 +1,5 @@ -//===- ModuleBufferization.cpp - Bufferization across Func. Boundaries ----===// +//===- OneShotRootBufferize.cpp - Bufferization across Func. Boundaries +//----===// // // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. @@ -6,12 +7,12 @@ // //===----------------------------------------------------------------------===// // -// Module Bufferization is an extension of One-Shot Bufferize that +// Root Bufferization is an extension of One-Shot Bufferize that // bufferizes function boundaries. It provides `BufferizableOpInterface` // implementations for FuncOp, CallOp and ReturnOp. // -// Module Bufferization is run via `runOneShotModuleBufferize(ModuleOp, ...)`. -// This function analyzes the given module and determines the order of analysis +// Root Bufferization is run via `runOneShotRootBufferize(RootOp, ...)`. +// This function analyzes the given op and determines the order of analysis // and bufferization: Functions that are called are processed before their // respective callers. // @@ -24,7 +25,7 @@ // * `funcOpBbArgReadWriteAnalysis` determines whether or not a tensor bbArg is // read/written. // -// Module Bufferization implements the following calling convention. +// Root Bufferization implements the following calling convention. // // * In the absence of conflicts within a FuncOp, the FuncOp's bbArgs may always // be written to in-place. @@ -57,7 +58,7 @@ // TODO: Add FuncOp attributes so that bbArgs of external FuncOps can be marked // as "not reading" and/or "not writing". -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h" #include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h" #include "mlir/Dialect/Bufferization/IR/Bufferization.h" @@ -299,7 +300,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) { llvm::IsaPred); } -/// Store all functions of the `moduleOp` in `orderedFuncOps`, sorted by +/// Store all functions of the `rootOp` in `orderedFuncOps`, sorted by /// callee-caller order (i.e., callees without callers first). Store all /// remaining functions (i.e., the ones that call each other recursively) in /// `remainingFuncOps`. Does not traverse nested symbol tables. @@ -309,7 +310,7 @@ static bool hasTensorSignature(func::FuncOp funcOp) { /// Return `failure()` if we are unable to retrieve the called FuncOp from /// any func::CallOp. static LogicalResult getFuncOpsOrderedByCalls( - ModuleOp moduleOp, SmallVectorImpl &orderedFuncOps, + Operation *rootOp, SmallVectorImpl &orderedFuncOps, SmallVectorImpl &remainingFuncOps, FuncCallerMap &callerMap, SymbolTableCollection &symbolTables) { // For each FuncOp, the set of functions called by it (i.e. the union of @@ -317,26 +318,29 @@ static LogicalResult getFuncOpsOrderedByCalls( DenseMap> calledBy; // For each FuncOp, the number of func::CallOp it contains. DenseMap numberCallOpsContainedInFuncOp; - - for (func::FuncOp funcOp : moduleOp.getOps()) { - // Collect function calls and populate the caller map. - numberCallOpsContainedInFuncOp[funcOp] = 0; - WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { - func::FuncOp calledFunction = getCalledFunction(callOp, symbolTables); - assert(calledFunction && "could not retrieved called func::FuncOp"); - // If the called function does not have any tensors in its signature, then - // it is not necessary to bufferize the callee before the caller. - if (!hasTensorSignature(calledFunction)) - return WalkResult::skip(); - - callerMap[calledFunction].insert(callOp); - if (calledBy[calledFunction].insert(funcOp).second) { - numberCallOpsContainedInFuncOp[funcOp]++; + for (mlir::Region ®ion : rootOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps()) { + // Collect function calls and populate the caller map. + numberCallOpsContainedInFuncOp[funcOp] = 0; + WalkResult res = funcOp.walk([&](func::CallOp callOp) -> WalkResult { + func::FuncOp calledFunction = getCalledFunction(callOp); + assert(calledFunction && "could not retrieved called func::FuncOp"); + // If the called function does not have any tensors in its signature, + // then it is not necessary to bufferize the callee before the caller. + if (!hasTensorSignature(calledFunction)) + return WalkResult::skip(); + + callerMap[calledFunction].insert(callOp); + if (calledBy[calledFunction].insert(funcOp).second) { + numberCallOpsContainedInFuncOp[funcOp]++; + } + return WalkResult::advance(); + }); + if (res.wasInterrupted()) + return failure(); } - return WalkResult::advance(); - }); - if (res.wasInterrupted()) - return failure(); + } } // Iteratively remove function operations that do not call any of the @@ -447,9 +451,9 @@ static void foldMemRefCasts(func::FuncOp funcOp) { } LogicalResult -mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, - OneShotAnalysisState &state, - BufferizationStatistics *statistics) { +mlir::bufferization::analyzeRootOp(Operation *rootOp, + OneShotAnalysisState &state, + BufferizationStatistics *statistics) { assert(state.getOptions().bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); FuncAnalysisState &funcState = getOrCreateFuncAnalysisState(state); @@ -465,9 +469,8 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, // A mapping of FuncOps to their callers. FuncCallerMap callerMap; - if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, - remainingFuncOps, callerMap, - funcState.symbolTables))) + if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps, + callerMap, funcState.symbolTables))) return failure(); // Analyze functions in order. Starting with functions that are not calling @@ -511,20 +514,24 @@ mlir::bufferization::analyzeModuleOp(ModuleOp moduleOp, return success(); } -void mlir::bufferization::removeBufferizationAttributesInModule( - ModuleOp moduleOp) { - for (auto op : moduleOp.getOps()) { - for (BlockArgument bbArg : op.getArguments()) - removeBufferizationAttributes(bbArg); +void mlir::bufferization::removeBufferizationAttributesInRoot( + Operation *rootOp) { + for (mlir::Region ®ion : rootOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (func::FuncOp funcOp : block.getOps()) { + for (BlockArgument bbArg : funcOp.getArguments()) + removeBufferizationAttributes(bbArg); + } + } } } -LogicalResult mlir::bufferization::bufferizeModuleOp( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, +LogicalResult mlir::bufferization::bufferizeRootOp( + Operation *rootOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); - IRRewriter rewriter(moduleOp.getContext()); + IRRewriter rewriter(rootOp->getContext()); // A list of non-circular functions in the order in which they are analyzed // and bufferized. @@ -542,9 +549,8 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( // accurate buffer types for function return values. Functions that call // each other recursively are bufferized in an unspecified order at the end. // We may use unnecessarily "complex" (in terms of layout map) buffer types. - if (failed(getFuncOpsOrderedByCalls(moduleOp, orderedFuncOps, - remainingFuncOps, callerMap, - state.getSymbolTables()))) + if (failed(getFuncOpsOrderedByCalls(rootOp, orderedFuncOps, remainingFuncOps, + callerMap, state.getSymbolTables()))) return failure(); llvm::append_range(orderedFuncOps, remainingFuncOps); @@ -571,22 +577,27 @@ LogicalResult mlir::bufferization::bufferizeModuleOp( } // Bufferize all other ops. - for (Operation &op : llvm::make_early_inc_range(moduleOp.getOps())) { - // Functions were already bufferized. - if (isa(&op) || op.hasTrait()) - continue; - if (failed(bufferizeOp(&op, options, state, statistics))) - return failure(); + for (mlir::Region ®ion : rootOp->getRegions()) { + for (mlir::Block &block : region.getBlocks()) { + for (mlir::Operation &op : + llvm::make_early_inc_range(block.getOperations())) { + // Functions were already bufferized. + if (isa(&op) || op.hasTrait()) + continue; + if (failed(bufferizeOp(&op, options, state, statistics))) + return failure(); + } + } } // Post-pass cleanup of function argument attributes. - removeBufferizationAttributesInModule(moduleOp); + removeBufferizationAttributesInRoot(rootOp); return success(); } -LogicalResult mlir::bufferization::runOneShotModuleBufferize( - ModuleOp moduleOp, const OneShotBufferizationOptions &options, +LogicalResult mlir::bufferization::runOneShotRootBufferize( + Operation *rootOp, const OneShotBufferizationOptions &options, BufferizationState &state, BufferizationStatistics *statistics) { assert(options.bufferizeFunctionBoundaries && "expected that function boundary bufferization is activated"); @@ -594,7 +605,7 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( "invalid combination of bufferization flags"); if (!options.copyBeforeWrite) { if (options.noAnalysisFuncFilter.empty()) { - if (failed(insertTensorCopies(moduleOp, options, state, statistics))) + if (failed(insertTensorCopies(rootOp, options, state, statistics))) return failure(); } else { // FuncOps whose names are specified in options.noAnalysisFuncFilter will @@ -610,14 +621,13 @@ LogicalResult mlir::bufferization::runOneShotModuleBufferize( }; OneShotBufferizationOptions updatedOptions(options); updatedOptions.opFilter.denyOperation(analysisFilterFn); - if (failed( - insertTensorCopies(moduleOp, updatedOptions, state, statistics))) + if (failed(insertTensorCopies(rootOp, updatedOptions, state, statistics))) return failure(); } } if (options.testAnalysisOnly) return success(); - if (failed(bufferizeModuleOp(moduleOp, options, state, statistics))) + if (failed(bufferizeRootOp(moduleOp, options, state, statistics))) return failure(); return success(); } diff --git a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp index 784d95a5dd22a..84dc48fb3237d 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/TensorCopyInsertion.cpp @@ -12,7 +12,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -35,7 +35,7 @@ LogicalResult mlir::bufferization::insertTensorCopies( // analysis depending on whether function boundary bufferization is enabled or // not. if (options.bufferizeFunctionBoundaries) { - if (failed(analyzeModuleOp(cast(op), analysisState, statistics))) + if (failed(analyzeRootOp(op, analysisState, statistics))) return failure(); } else { if (failed(analyzeOp(op, analysisState, statistics))) diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp index 15e5102462ad7..5fff7da99e0a0 100644 --- a/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp +++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparsificationAndBufferizationPass.cpp @@ -13,7 +13,7 @@ #include "mlir/Dialect/Bufferization/IR/Bufferization.h" #include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" #include "mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h" -#include "mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h" #include "mlir/Dialect/Bufferization/Transforms/Passes.h" #include "mlir/Dialect/Bufferization/Transforms/Transforms.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -116,12 +116,12 @@ class SparsificationAndBufferizationPass bufferization::BufferizationState bufferizationState; - if (failed(bufferization::bufferizeModuleOp(cast(getOperation()), - updatedOptions, - bufferizationState))) + if (failed(bufferization::bufferizeRootOp(getOperation(), + updatedOptions, + bufferizationState))) return failure(); - bufferization::removeBufferizationAttributesInModule(getOperation()); + bufferization::removeBufferizationAttributesInRoot(getOperation()); return success(); } diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-allow-return-allocs.mlir similarity index 100% rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-allow-return-allocs.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-allow-return-allocs.mlir diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-analysis.mlir similarity index 100% rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-analysis.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-analysis.mlir diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-force-copy-before-write.mlir similarity index 100% rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-force-copy-before-write.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-force-copy-before-write.mlir diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-invalid.mlir similarity index 100% rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-invalid.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-invalid.mlir diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-out-params.mlir similarity index 100% rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize-out-params.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize-out-params.mlir diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize.mlir similarity index 100% rename from mlir/test/Dialect/Bufferization/Transforms/one-shot-module-bufferize.mlir rename to mlir/test/Dialect/Bufferization/Transforms/one-shot-root-bufferize.mlir diff --git a/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-non-module-bufferize.mlir b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-non-module-bufferize.mlir new file mode 100644 index 0000000000000..25ff512a885d7 --- /dev/null +++ b/mlir/test/Dialect/Bufferization/Transforms/one-shot-root-non-module-bufferize.mlir @@ -0,0 +1,33 @@ +// RUN: mlir-opt %s -allow-unregistered-dialect -pass-pipeline='builtin.module(test.symbol_scope_isolated(test-one-shot-root-bufferize))' -split-input-file | FileCheck %s + +"test.symbol_scope_isolated"() ({ + // CHECK-LABEL: func @inner_func( + // CHECK-SAME: %[[arg0:.*]]: memref) -> (tensor, f32) { + // CHECK-NOT: copy + %f = arith.constant 1.0 : f32 + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + // CHECK: memref.store %{{.*}}, %[[arg0]] + %0 = tensor.insert %f into %t[%c0] : tensor + // CHECK: %[[load:.*]] = memref.load %[[arg0]] + %1 = tensor.extract %0[%c1] : tensor + // CHECK: return %[[arg0]], %[[load]] : memref, f32 + return %0, %1 : tensor, f32 + } + + // CHECK-LABEL: func @call_func_with_non_tensor_return( + // CHECK-SAME: %[[arg0:.*]]: memref {bufferization.writable = true}) -> (f32, tensor) { + // CHECK-NOT: alloc + // CHECK-NOT: copy + // CHECK: %[[call:.*]]:2 = call @inner_func(%[[arg0]]) + %0, %1 = call @inner_func(%t0) : (tensor) -> (tensor, f32) + // CHECK: return %[[call]]#1, %[[call]]#0 : f32, memref + return %1, %0 : f32, tensor + } + "test.finish" () : () -> () +}) : () -> () + + diff --git a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt index 226e0bb97732d..50e1e9f8e06c7 100644 --- a/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt +++ b/mlir/test/lib/Dialect/Bufferization/CMakeLists.txt @@ -1,5 +1,6 @@ # Exclude tests from libMLIR.so add_mlir_library(MLIRBufferizationTestPasses + TestOneShotRootBufferize.cpp TestTensorCopyInsertion.cpp TestTensorLikeAndBufferLike.cpp diff --git a/mlir/test/lib/Dialect/Bufferization/TestOneShotRootBufferize.cpp b/mlir/test/lib/Dialect/Bufferization/TestOneShotRootBufferize.cpp new file mode 100644 index 0000000000000..30f0b312eb1af --- /dev/null +++ b/mlir/test/lib/Dialect/Bufferization/TestOneShotRootBufferize.cpp @@ -0,0 +1,54 @@ +//===- TestOneShotRootBufferzation.cpp - Bufferization Test -----*- c++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/Bufferization/IR/Bufferization.h" +#include "mlir/Dialect/Bufferization/Transforms/Bufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/OneShotRootBufferize.h" +#include "mlir/Dialect/Bufferization/Transforms/Transforms.h" +#include "mlir/Pass/Pass.h" + +using namespace mlir; + +namespace { +struct TestOneShotRootBufferizePass + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestOneShotRootBufferizePass) + + TestOneShotRootBufferizePass() = default; + TestOneShotRootBufferizePass(const TestOneShotRootBufferizePass &pass) + : PassWrapper(pass) {} + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + StringRef getArgument() const final { return "test-one-shot-root-bufferize"; } + StringRef getDescription() const final { + return "Module pass to test One Shot Root Bufferization"; + } + + void runOnOperation() override { + + llvm::errs() << "Running TestOneShotRootBufferize on: " + << getOperation()->getName() << "\n"; + bufferization::OneShotBufferizationOptions opt; + + opt.bufferizeFunctionBoundaries = true; + bufferization::BufferizationState bufferizationState; + + if (failed(bufferization::runOneShotRootBufferize(getOperation(), opt, + bufferizationState))) + signalPassFailure(); + } +}; +} // namespace + +namespace mlir::test { +void registerTestOneShotRootBufferizePass() { + PassRegistration(); +} +} // namespace mlir::test diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index ab3f847ca2acf..2727ad34b23cb 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -125,6 +125,15 @@ def SymbolScopeOp : TEST_Op<"symbol_scope", let regions = (region SizedRegion<1>:$region); } +def SymbolScopeIsolatedOp + : TEST_Op<"symbol_scope_isolated", [IsolatedFromAbove, SymbolTable, + SingleBlockImplicitTerminator< + "TerminatorOp">]> { + let summary = + "operation which defines a new symbol table that is IsolatedFromAbove"; + let regions = (region SizedRegion<1>:$region); +} + def SymbolTableRegionOp : TEST_Op<"symbol_table_region", [SymbolTable]> { let summary = "operation which defines a new symbol table without a " "restriction on a terminator"; diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp index 143a5e8e8f8dd..b0aa5c86615e4 100644 --- a/mlir/tools/mlir-opt/mlir-opt.cpp +++ b/mlir/tools/mlir-opt/mlir-opt.cpp @@ -135,6 +135,7 @@ void registerTestMeshSimplificationsPass(); void registerTestMultiBuffering(); void registerTestNextAccessPass(); void registerTestNVGPULowerings(); +void registerTestOneShotRootBufferizePass(); void registerTestOpaqueLoc(); void registerTestOpLoweringPasses(); void registerTestPadFusion(); @@ -281,6 +282,7 @@ void registerTestPasses() { mlir::test::registerTestMultiBuffering(); mlir::test::registerTestNextAccessPass(); mlir::test::registerTestNVGPULowerings(); + mlir::test::registerTestOneShotRootBufferizePass(); mlir::test::registerTestOpaqueLoc(); mlir::test::registerTestOpLoweringPasses(); mlir::test::registerTestPadFusion();