Skip to content

Commit 6c6bba7

Browse files
[mlir][linalg][bufferize][NFC] Use RewriterBase instead of OpBuilder
This is in preparation of unifying core bufferization and Comprehensive Bufferize. Differential Revision: https://reviews.llvm.org/D116102
1 parent 3728a7d commit 6c6bba7

File tree

11 files changed

+77
-88
lines changed

11 files changed

+77
-88
lines changed

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.h

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/IR/BuiltinOps.h"
1515
#include "mlir/IR/BuiltinTypes.h"
1616
#include "mlir/IR/Operation.h"
17+
#include "mlir/IR/PatternMatch.h"
1718
#include "mlir/Support/LLVM.h"
1819
#include "llvm/ADT/EquivalenceClasses.h"
1920
#include "llvm/ADT/SetVector.h"
@@ -296,7 +297,8 @@ struct DialectBufferizationState {
296297
/// * `replaceOp` replaces an op with new values.
297298
class BufferizationState {
298299
public:
299-
BufferizationState(Operation *op, const BufferizationOptions &options);
300+
BufferizationState(Operation *op, const BufferizationOptions &options,
301+
RewriterBase &rewriter);
300302

301303
// BufferizationState should be passed as a reference.
302304
BufferizationState(const BufferizationState &) = delete;
@@ -387,9 +389,10 @@ class BufferizationState {
387389
/// Replace an op with a new op. Tensor OpResults must be replaced with memref
388390
/// values.
389391
template <typename OpTy, typename... Args>
390-
OpTy replaceOpWithNewOp(OpBuilder &b, Operation *op, Args &&...args) {
392+
OpTy replaceOpWithNewOp(RewriterBase &rewriter, Operation *op,
393+
Args &&...args) {
391394
Operation *newOp =
392-
b.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
395+
rewriter.create<OpTy>(op->getLoc(), std::forward<Args>(args)...);
393396
replaceOp(op, newOp->getResults());
394397
return cast<OpTy>(newOp);
395398
}
@@ -417,8 +420,8 @@ class BufferizationState {
417420
/// Return a reference to the BufferizationOptions.
418421
const BufferizationOptions &getOptions() const { return options; }
419422

420-
/// Return a reference to the OpBuilder.
421-
OpBuilder &getBuilder() { return builder; }
423+
/// Return a reference to the rewriter.
424+
RewriterBase &getRewriter() { return rewriter; }
422425

423426
private:
424427
friend LogicalResult
@@ -440,7 +443,7 @@ class BufferizationState {
440443
const BufferizationOptions &options;
441444

442445
/// The OpBuilder used during bufferization.
443-
OpBuilder builder;
446+
RewriterBase &rewriter;
444447
};
445448

446449
/// Bufferize all ops in the given region.
@@ -523,7 +526,7 @@ struct AllocationHoistingBarrierOnly
523526
return false;
524527
}
525528

526-
LogicalResult bufferize(Operation *op, OpBuilder &b,
529+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
527530
BufferizationState &state) const {
528531
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
529532
if (any_of(op->getOperandTypes(), isaTensor) ||

mlir/include/mlir/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
209209
}],
210210
/*retType=*/"LogicalResult",
211211
/*methodName=*/"bufferize",
212-
/*args=*/(ins "OpBuilder &":$b,
212+
/*args=*/(ins "RewriterBase &":$rewriter,
213213
"BufferizationState &":$state),
214214
/*methodBody=*/"",
215215
/*defaultImplementation=*/[{

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ArithInterfaceImpl.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ namespace arith_ext {
2323
struct ConstantOpInterface
2424
: public BufferizableOpInterface::ExternalModel<ConstantOpInterface,
2525
arith::ConstantOp> {
26-
LogicalResult bufferize(Operation *op, OpBuilder &b,
26+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
2727
BufferizationState &state) const {
2828
auto constantOp = cast<arith::ConstantOp>(op);
2929
assert(constantOp.getType().dyn_cast<RankedTensorType>() &&
@@ -35,8 +35,8 @@ struct ConstantOpInterface
3535

3636
GlobalCreator globalCreator(moduleOp);
3737
auto globalMemref = globalCreator.getGlobalFor(constantOp);
38-
state.replaceOpWithNewOp<memref::GetGlobalOp>(b, op, globalMemref.type(),
39-
globalMemref.getName());
38+
state.replaceOpWithNewOp<memref::GetGlobalOp>(
39+
rewriter, op, globalMemref.type(), globalMemref.getName());
4040
return success();
4141
}
4242

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizableOpInterface.cpp

Lines changed: 18 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -333,8 +333,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
333333
}
334334

335335
mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
336-
Operation *op, const BufferizationOptions &options)
337-
: aliasInfo(op), options(options), builder(op->getContext()) {
336+
Operation *op, const BufferizationOptions &options, RewriterBase &rewriter)
337+
: aliasInfo(op), options(options), rewriter(rewriter) {
338338
// Set up alias sets for OpResults that must bufferize in-place. This should
339339
// be done before making any other bufferization decisions.
340340
op->walk([&](BufferizableOpInterface bufferizableOp) {
@@ -361,7 +361,7 @@ mlir::linalg::comprehensive_bufferize::BufferizationState::BufferizationState(
361361
/// bufferization is necessary.
362362
Value mlir::linalg::comprehensive_bufferize::BufferizationState::
363363
getResultBuffer(OpResult result) {
364-
OpBuilder::InsertionGuard guard(builder);
364+
OpBuilder::InsertionGuard guard(rewriter);
365365
Operation *op = result.getOwner();
366366
SmallVector<OpOperand *> aliasingOperands = getAliasingOpOperand(result);
367367
assert(!aliasingOperands.empty() && "could not get aliasing OpOperand");
@@ -391,9 +391,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
391391
Location loc = op->getLoc();
392392
// Move insertion point right after `operandBuffer`. That is where the
393393
// allocation should be inserted (in the absence of allocation hoisting).
394-
setInsertionPointAfter(builder, operandBuffer);
394+
setInsertionPointAfter(rewriter, operandBuffer);
395395
// Allocate the result buffer.
396-
Value resultBuffer = createAllocDeallocPair(builder, loc, operandBuffer);
396+
Value resultBuffer = createAllocDeallocPair(rewriter, loc, operandBuffer);
397397
bool skipCopy = false;
398398
// Do not copy if the last preceding write of `operand` is an op that does
399399
// not write (skipping ops that merely create aliases). E.g., InitTensorOp.
@@ -413,8 +413,8 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
413413
skipCopy = true;
414414
if (!skipCopy) {
415415
// The copy happens right before the op that is bufferized.
416-
builder.setInsertionPoint(op);
417-
createMemCpy(builder, loc, operandBuffer, resultBuffer);
416+
rewriter.setInsertionPoint(op);
417+
createMemCpy(rewriter, loc, operandBuffer, resultBuffer);
418418
}
419419
return resultBuffer;
420420
}
@@ -425,8 +425,7 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::
425425

426426
void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
427427
Operation *op, ValueRange values) {
428-
OpBuilder &b = getBuilder();
429-
OpBuilder::InsertionGuard g(b);
428+
OpBuilder::InsertionGuard g(rewriter);
430429

431430
// Replace all OpResults with the given values.
432431
for (OpResult opResult : op->getOpResults()) {
@@ -444,14 +443,14 @@ void mlir::linalg::comprehensive_bufferize::BufferizationState::replaceOp(
444443
// The existing uses of the OpResult still expect a tensor. Insert a
445444
// ToTensorOp. Throughout bufferization, this ToTensorOp will gradually
446445
// loose all of its users and eventually DCE away.
447-
setInsertionPointAfter(b, replacement);
448-
replacement = b.create<bufferization::ToTensorOp>(replacement.getLoc(),
449-
replacement);
446+
setInsertionPointAfter(rewriter, replacement);
447+
replacement = rewriter.create<bufferization::ToTensorOp>(
448+
replacement.getLoc(), replacement);
450449
}
451450
opResult.replaceAllUsesWith(replacement);
452451
}
453452

454-
op->erase();
453+
rewriter.eraseOp(op);
455454
}
456455

457456
LogicalResult
@@ -481,7 +480,7 @@ mlir::linalg::comprehensive_bufferize::bufferize(Block *block,
481480
LogicalResult
482481
mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
483482
BufferizationState &state) {
484-
OpBuilder &b = state.getBuilder();
483+
RewriterBase &rewriter = state.getRewriter();
485484

486485
// Check if op has tensor results or operands.
487486
auto isaTensor = [](Type t) { return t.isa<TensorType>(); };
@@ -496,8 +495,8 @@ mlir::linalg::comprehensive_bufferize::bufferize(Operation *op,
496495
// Bufferize using `BufferizableOpInterface`. Interface implementations are
497496
// responsible for bufferizing nested ops.
498497
if (auto bufferizableOp = state.getOptions().dynCastBufferizableOp(op)) {
499-
b.setInsertionPoint(op);
500-
return bufferizableOp.bufferize(b, state);
498+
rewriter.setInsertionPoint(op);
499+
return bufferizableOp.bufferize(rewriter, state);
501500
}
502501

503502
// `op` is an unbufferizable tensor op.
@@ -679,10 +678,9 @@ Value mlir::linalg::comprehensive_bufferize::BufferizationState::lookupBuffer(
679678
}
680679

681680
// Insert to_memref op.
682-
OpBuilder &b = getBuilder();
683-
OpBuilder::InsertionGuard g(b);
684-
setInsertionPointAfter(b, tensor);
685-
return b.create<bufferization::ToMemrefOp>(
681+
OpBuilder::InsertionGuard g(rewriter);
682+
setInsertionPointAfter(rewriter, tensor);
683+
return rewriter.create<bufferization::ToMemrefOp>(
686684
tensor.getLoc(),
687685
getDynamicMemRefType(tensor.getType().cast<RankedTensorType>()), tensor);
688686
}

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/BufferizationInterfaceImpl.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,14 @@ struct ToMemrefOpInterface
5050
return OpResult();
5151
}
5252

53-
LogicalResult bufferize(Operation *op, OpBuilder &b,
53+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
5454
BufferizationState &state) const {
5555
auto toMemrefOp = cast<bufferization::ToMemrefOp>(op);
5656

5757
// Fold to_memref(to_tensor(x)) to x.
5858
if (auto toTensorOp =
5959
toMemrefOp.tensor().getDefiningOp<bufferization::ToTensorOp>()) {
60-
toMemrefOp.replaceAllUsesWith(toTensorOp.memref());
61-
toMemrefOp.erase();
60+
rewriter.replaceOp(toMemrefOp, toTensorOp.memref());
6261
return success();
6362
}
6463

@@ -86,7 +85,7 @@ struct ToMemrefOpInterface
8685
struct ToTensorOpInterface
8786
: public BufferizableOpInterface::ExternalModel<ToTensorOpInterface,
8887
bufferization::ToTensorOp> {
89-
LogicalResult bufferize(Operation *op, OpBuilder &b,
88+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
9089
BufferizationState &state) const {
9190
return success();
9291
}

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ComprehensiveBufferize.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -651,7 +651,8 @@ annotateOpsWithBufferizationMarkers(Operation *op,
651651

652652
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
653653
Operation *op, std::unique_ptr<BufferizationOptions> options) {
654-
BufferizationState state(op, *options);
654+
IRRewriter rewriter(op->getContext());
655+
BufferizationState state(op, *options, rewriter);
655656
return runComprehensiveBufferize(op, *options, state);
656657
}
657658

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/LinalgInterfaceImpl.cpp

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ namespace {
2323
// TODO: Ops in the linalg dialect can directly implement this interface.
2424

2525
/// Generic conversion for any LinalgOp on tensors.
26-
static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
26+
static LogicalResult bufferizeLinalgOp(RewriterBase &rewriter, LinalgOp op,
2727
BufferizationState &state) {
2828
// Take a guard before anything else.
29-
OpBuilder::InsertionGuard g(b);
30-
b.setInsertionPoint(op);
29+
OpBuilder::InsertionGuard g(rewriter);
30+
rewriter.setInsertionPoint(op);
3131

3232
// Nothing to do. This op is already bufferized.
3333
if (op.hasBufferSemantics())
@@ -63,9 +63,9 @@ static LogicalResult bufferizeLinalgOp(OpBuilder &b, LinalgOp op,
6363
newOperands.append(newOutputBuffers.begin(), newOutputBuffers.end());
6464

6565
// Set insertion point now that potential alloc/dealloc are introduced.
66-
b.setInsertionPoint(op);
67-
auto bufferizedOp = cast<LinalgOp>(
68-
op.clone(b, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
66+
rewriter.setInsertionPoint(op);
67+
auto bufferizedOp = cast<LinalgOp>(op.clone(
68+
rewriter, op.getLoc(), /*resultTypes=*/TypeRange{}, newOperands));
6969

7070
// Replace the results of the old op with the new output buffers.
7171
state.replaceOp(op, newOutputBuffers);
@@ -177,9 +177,9 @@ struct LinalgOpInterface
177177
return BufferRelation::Equivalent;
178178
}
179179

180-
LogicalResult bufferize(Operation *op, OpBuilder &b,
180+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
181181
BufferizationState &state) const {
182-
return bufferizeLinalgOp(b, cast<LinalgOp>(op), state);
182+
return bufferizeLinalgOp(rewriter, cast<LinalgOp>(op), state);
183183
}
184184
};
185185

@@ -192,15 +192,15 @@ struct InitTensorOpInterface
192192
return false;
193193
}
194194

195-
LogicalResult bufferize(Operation *op, OpBuilder &b,
195+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
196196
BufferizationState &state) const {
197197
auto initTensorOp = cast<linalg::InitTensorOp>(op);
198198

199199
// The InitTensorOp may have been eliminated.
200200
if (initTensorOp->getUses().empty())
201201
return success();
202202

203-
Value alloc = state.createAllocDeallocPair(b, initTensorOp->getLoc(),
203+
Value alloc = state.createAllocDeallocPair(rewriter, initTensorOp->getLoc(),
204204
initTensorOp.result());
205205
state.replaceOp(op, alloc);
206206
return success();
@@ -251,15 +251,10 @@ struct TiledLoopOpInterface
251251

252252
bool isAllocationHoistingBarrier(Operation *op) const { return true; }
253253

254-
LogicalResult bufferize(Operation *op, OpBuilder &b,
254+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
255255
BufferizationState &state) const {
256256
auto tiledLoopOp = cast<linalg::TiledLoopOp>(op);
257257

258-
// Use IRRewriter instead of OpBuilder because it has additional helper
259-
// functions.
260-
IRRewriter rewriter(op->getContext());
261-
rewriter.setInsertionPoint(tiledLoopOp);
262-
263258
// Compute new inputs, outputs and results.
264259
SmallVector<Value> newInputs, newOutputs, newResults;
265260
for (Value value : tiledLoopOp.inputs()) {
@@ -358,7 +353,7 @@ struct YieldOpInterface
358353
return OpResult();
359354
}
360355

361-
LogicalResult bufferize(Operation *op, OpBuilder &b,
356+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
362357
BufferizationState &state) const {
363358
auto yieldOp = cast<linalg::YieldOp>(op);
364359

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/ModuleBufferization.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -725,7 +725,8 @@ static void annotateOpsWithBufferizationMarkers(FuncOp funcOp,
725725

726726
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
727727
ModuleOp moduleOp, std::unique_ptr<BufferizationOptions> options) {
728-
BufferizationState state(moduleOp, *options);
728+
IRRewriter rewriter(moduleOp.getContext());
729+
BufferizationState state(moduleOp, *options, rewriter);
729730
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
730731
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
731732

mlir/lib/Dialect/Linalg/ComprehensiveBufferize/SCFInterfaceImpl.cpp

Lines changed: 4 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ struct ExecuteRegionOpInterface
6060
return true;
6161
}
6262

63-
LogicalResult bufferize(Operation *op, OpBuilder &b,
63+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
6464
BufferizationState &state) const {
6565
// TODO: Add bufferization support when needed. scf.execute_region should be
6666
// bufferized similar to scf.if.
@@ -135,15 +135,10 @@ struct IfOpInterface
135135
return true;
136136
}
137137

138-
LogicalResult bufferize(Operation *op, OpBuilder &b,
138+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
139139
BufferizationState &state) const {
140140
auto ifOp = cast<scf::IfOp>(op);
141141

142-
// Use IRRewriter instead of OpBuilder because it has additional helper
143-
// functions.
144-
IRRewriter rewriter(op->getContext());
145-
rewriter.setInsertionPoint(ifOp);
146-
147142
// Compute new types of the bufferized scf.if op.
148143
SmallVector<Type> newTypes;
149144
for (Type returnType : ifOp->getResultTypes()) {
@@ -276,16 +271,11 @@ struct ForOpInterface
276271
return true;
277272
}
278273

279-
LogicalResult bufferize(Operation *op, OpBuilder & /*b*/,
274+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
280275
BufferizationState &state) const {
281276
auto forOp = cast<scf::ForOp>(op);
282277
Block *oldLoopBody = &forOp.getLoopBody().front();
283278

284-
// Use IRRewriter instead of OpBuilder because it has additional helper
285-
// functions.
286-
IRRewriter rewriter(op->getContext());
287-
rewriter.setInsertionPoint(forOp);
288-
289279
// Indices of all iter_args that have tensor type. These are the ones that
290280
// are bufferized.
291281
DenseSet<int64_t> indices;
@@ -438,7 +428,7 @@ struct YieldOpInterface
438428
return OpResult();
439429
}
440430

441-
LogicalResult bufferize(Operation *op, OpBuilder &b,
431+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
442432
BufferizationState &state) const {
443433
auto yieldOp = cast<scf::YieldOp>(op);
444434
if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp>(

0 commit comments

Comments
 (0)