Skip to content

Commit 67fc166

Browse files
authored
[MLIR] Add bufferization state class to OneShotBufferization pass (#138143)
This PR is a follow-up to #138125. It adds a bufferization state class that provides information about the IR. Currently, this information consists of a cached list of symbol tables, which is intended to eliminate the quadratic scaling of the bufferization task with respect to the number of symbols. The PR breaks API compatibility: the `bufferize` method of the `BufferizableOpInterface` now takes an additional reference to a `BufferizationState` object. Interface implementations must keep the bufferization state valid. For example, if an operation with the `Symbol` trait is inserted or replaced, its parent `SymbolTable` must be updated accordingly (see, for example, the bufferization of `arith::ConstantOp`, where the new global symbol is inserted into the module's symbol table). Similarly, a symbol table must be invalidated if an operation with the `SymbolTable` trait is removed (this can be done with the `invalidateSymbolTable` method, introduced in #138014).
1 parent 4fdcde5 commit 67fc166

27 files changed

+214
-86
lines changed

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,20 @@ class AnalysisState {
578578
insideMutuallyExclusiveRegionsCache;
579579
};
580580

581+
/// BufferizationState provides information about the state of the IR during the
582+
/// bufferization process.
583+
class BufferizationState {
584+
public:
585+
/// Get a reference to the collection of cached symbol tables.
586+
SymbolTableCollection &getSymbolTables();
587+
588+
private:
589+
/// The cached symbol tables.
590+
/// The user is expected to update / invalidate the cached symbol tables if
591+
/// the bufferized operation has the Symbol or SymbolTable traits.
592+
SymbolTableCollection symbolTables;
593+
};
594+
581595
/// Create an AllocTensorOp for the given shaped value (memref or tensor).
582596
/// If `copy` is set, the shaped value is copied. Otherwise, a tensor with
583597
/// undefined contents is allocated.

mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
426426
/*retType=*/"::llvm::LogicalResult",
427427
/*methodName=*/"bufferize",
428428
/*args=*/(ins "::mlir::RewriterBase &":$rewriter,
429-
"const ::mlir::bufferization::BufferizationOptions &":$options),
429+
"const ::mlir::bufferization::BufferizationOptions &":$options,
430+
"::mlir::bufferization::BufferizationState &":$state),
430431
/*methodBody=*/"",
431432
/*defaultImplementation=*/[{
432433
llvm_unreachable("bufferize not implemented");

mlir/include/mlir/Dialect/Bufferization/IR/BufferizationOps.td

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,8 @@ def Bufferization_AllocTensorOp : Bufferization_Op<"alloc_tensor",
9393

9494
let extraClassDeclaration = [{
9595
LogicalResult bufferize(RewriterBase &rewriter,
96-
const BufferizationOptions &options);
96+
const BufferizationOptions &options,
97+
BufferizationState &state);
9798

9899
bool resultBufferizesToMemoryWrite(OpResult opResult,
99100
const AnalysisState &state);
@@ -282,7 +283,8 @@ def Bufferization_MaterializeInDestinationOp
282283

283284
let extraClassDeclaration = [{
284285
LogicalResult bufferize(RewriterBase &rewriter,
285-
const BufferizationOptions &options);
286+
const BufferizationOptions &options,
287+
BufferizationState &state);
286288

287289
bool bufferizesToMemoryRead(OpOperand &opOperand,
288290
const AnalysisState &state);
@@ -375,7 +377,8 @@ def Bufferization_DeallocTensorOp : Bufferization_Op<"dealloc_tensor",
375377
}
376378

377379
LogicalResult bufferize(RewriterBase &rewriter,
378-
const BufferizationOptions &options);
380+
const BufferizationOptions &options,
381+
BufferizationState &state);
379382
}];
380383
}
381384

@@ -458,7 +461,8 @@ def Bufferization_ToTensorOp : Bufferization_Op<"to_tensor", [
458461
//===------------------------------------------------------------------===//
459462

460463
LogicalResult bufferize(RewriterBase &rewriter,
461-
const BufferizationOptions &options) const {
464+
const BufferizationOptions &options,
465+
BufferizationState &state) const {
462466
// to_tensor/to_buffer pairs fold away after bufferization.
463467
return success();
464468
}
@@ -550,7 +554,8 @@ def Bufferization_ToBufferOp : Bufferization_Op<"to_buffer", [
550554
}
551555

552556
LogicalResult bufferize(RewriterBase &rewriter,
553-
const BufferizationOptions &options);
557+
const BufferizationOptions &options,
558+
BufferizationState &state);
554559
}];
555560

556561
let assemblyFormat = [{

mlir/include/mlir/Dialect/Bufferization/Transforms/BufferUtils.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class GlobalOp;
2929
} // namespace memref
3030

3131
namespace bufferization {
32+
class BufferizationState;
3233

3334
/// A simple analysis that detects allocation operations.
3435
class BufferPlacementAllocs {
@@ -122,9 +123,14 @@ class BufferPlacementTransformationBase {
122123
// Globals are created lazily at the top of the enclosing ModuleOp with pretty
123124
// names. Duplicates are avoided.
124125
FailureOr<memref::GlobalOp> getGlobalFor(arith::ConstantOp constantOp,
126+
SymbolTableCollection &symbolTables,
125127
uint64_t alignment,
126128
Attribute memorySpace = {});
127129

130+
void removeSymbol(Operation *op, BufferizationState &state);
131+
132+
void insertSymbol(Operation *op, BufferizationState &state);
133+
128134
} // namespace bufferization
129135
} // namespace mlir
130136

mlir/include/mlir/Dialect/Bufferization/Transforms/Bufferize.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ struct BufferizationStatistics {
4545
/// additional buffer copies or set "options.copyBeforeWrite = true". The
4646
/// general bufferization entry point is `runOneShotBufferize`.
4747
LogicalResult bufferizeOp(Operation *op, const BufferizationOptions &options,
48+
BufferizationState &bufferizationState,
4849
BufferizationStatistics *statistics = nullptr);
4950

5051
/// Bufferize the signature of `block` and its callers (i.e., ops that have the

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotAnalysis.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,6 +270,7 @@ LogicalResult analyzeOp(Operation *op, OneShotAnalysisState &state,
270270
/// Run One-Shot Bufferize on the given op: Analysis + Bufferization
271271
LogicalResult
272272
runOneShotBufferize(Operation *op, const OneShotBufferizationOptions &options,
273+
BufferizationState &state,
273274
BufferizationStatistics *statistics = nullptr);
274275

275276
} // namespace bufferization

mlir/include/mlir/Dialect/Bufferization/Transforms/OneShotModuleBufferize.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ namespace bufferization {
2020
struct BufferizationStatistics;
2121
class OneShotAnalysisState;
2222
struct OneShotBufferizationOptions;
23+
class BufferizationState;
2324

2425
/// Analyze `moduleOp` and its nested ops. Bufferization decisions are stored in
2526
/// `state`.
@@ -38,6 +39,7 @@ analyzeModuleOp(ModuleOp moduleOp, OneShotAnalysisState &state,
3839
/// will be inserted only to these FuncOps.
3940
llvm::LogicalResult
4041
bufferizeModuleOp(ModuleOp moduleOp, const OneShotBufferizationOptions &options,
42+
BufferizationState &state,
4143
BufferizationStatistics *statistics = nullptr);
4244

4345
/// Remove bufferization attributes on every FuncOp arguments in the ModuleOp.
@@ -50,7 +52,7 @@ void removeBufferizationAttributesInModule(ModuleOp moduleOp);
5052
llvm::LogicalResult runOneShotModuleBufferize(
5153
ModuleOp moduleOp,
5254
const bufferization::OneShotBufferizationOptions &options,
53-
BufferizationStatistics *statistics = nullptr);
55+
BufferizationState &state, BufferizationStatistics *statistics = nullptr);
5456

5557
} // namespace bufferization
5658
} // namespace mlir

mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ namespace mlir {
3030
namespace bufferization {
3131
class AllocTensorOp;
3232
class OneShotAnalysisState;
33+
class BufferizationState;
3334
} // namespace bufferization
3435

3536
namespace linalg {

mlir/lib/Dialect/Arith/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ struct ConstantOpInterface
2424
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
2525
arith::ConstantOp> {
2626
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
27-
const BufferizationOptions &options) const {
27+
const BufferizationOptions &options,
28+
BufferizationState &state) const {
2829
auto constantOp = cast<arith::ConstantOp>(op);
2930
auto type = dyn_cast<RankedTensorType>(constantOp.getType());
3031

@@ -46,7 +47,8 @@ struct ConstantOpInterface
4647
// Create global memory segment and replace tensor with memref pointing to
4748
// that memory segment.
4849
FailureOr<memref::GlobalOp> globalOp =
49-
getGlobalFor(constantOp, options.bufferAlignment, memorySpace);
50+
getGlobalFor(constantOp, state.getSymbolTables(),
51+
options.bufferAlignment, memorySpace);
5052
if (failed(globalOp))
5153
return failure();
5254
memref::GlobalOp globalMemref = *globalOp;
@@ -83,7 +85,8 @@ struct IndexCastOpInterface
8385
}
8486

8587
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
86-
const BufferizationOptions &options) const {
88+
const BufferizationOptions &options,
89+
BufferizationState &state) const {
8790
auto castOp = cast<arith::IndexCastOp>(op);
8891
auto resultTensorType = cast<TensorType>(castOp.getType());
8992

@@ -131,7 +134,8 @@ struct SelectOpInterface
131134
}
132135

133136
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
134-
const BufferizationOptions &options) const {
137+
const BufferizationOptions &options,
138+
BufferizationState &state) const {
135139
auto selectOp = cast<arith::SelectOp>(op);
136140
Location loc = selectOp.getLoc();
137141

mlir/lib/Dialect/Bufferization/IR/BufferizableOpInterface.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,10 @@ void AnalysisState::resetCache() {
125125
insideMutuallyExclusiveRegionsCache.clear();
126126
}
127127

128+
SymbolTableCollection &BufferizationState::getSymbolTables() {
129+
return symbolTables;
130+
}
131+
128132
Region *bufferization::getNextEnclosingRepetitiveRegion(
129133
Region *region, const BufferizationOptions &options) {
130134
assert(isRepetitiveRegion(region, options) && "expected repetitive region");

0 commit comments

Comments
 (0)