Skip to content

Commit 588845d

Browse files
authored
[mlir][NFC] update mlir/Dialect create APIs (20/n) (llvm#149927)
See llvm#147168 for more info.
1 parent d5d8eaf commit 588845d

20 files changed

+386
-373
lines changed

mlir/lib/Dialect/SCF/IR/SCF.cpp

Lines changed: 69 additions & 66 deletions
Large diffs are not rendered by default.

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
160160
OpBuilder::InsertionGuard g(b);
161161
b.setInsertionPoint(op);
162162
scf::ExecuteRegionOp executeRegionOp =
163-
b.create<scf::ExecuteRegionOp>(op->getLoc(), op->getResultTypes());
163+
scf::ExecuteRegionOp::create(b, op->getLoc(), op->getResultTypes());
164164
{
165165
OpBuilder::InsertionGuard g(b);
166166
b.setInsertionPointToStart(&executeRegionOp.getRegion().emplaceBlock());
@@ -169,7 +169,7 @@ static scf::ExecuteRegionOp wrapInExecuteRegion(RewriterBase &b,
169169
assert(clonedRegion.empty() && "expected empty region");
170170
b.inlineRegionBefore(op->getRegions().front(), clonedRegion,
171171
clonedRegion.end());
172-
b.create<scf::YieldOp>(op->getLoc(), clonedOp->getResults());
172+
scf::YieldOp::create(b, op->getLoc(), clonedOp->getResults());
173173
}
174174
b.replaceOp(op, executeRegionOp.getResults());
175175
return executeRegionOp;

mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ static Value castBuffer(OpBuilder &b, Value buffer, Type type) {
4141
// iter_arg's layout map must be changed (see uses of `castBuffer`).
4242
assert(memref::CastOp::areCastCompatible(buffer.getType(), type) &&
4343
"scf.while op bufferization: cast incompatible");
44-
return b.create<memref::CastOp>(buffer.getLoc(), type, buffer).getResult();
44+
return memref::CastOp::create(b, buffer.getLoc(), type, buffer).getResult();
4545
}
4646

4747
/// Helper function for loop bufferization. Return "true" if the given value
@@ -189,7 +189,7 @@ struct ExecuteRegionOpInterface
189189

190190
// Create new op and move over region.
191191
auto newOp =
192-
rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
192+
scf::ExecuteRegionOp::create(rewriter, op->getLoc(), newResultTypes);
193193
newOp.getRegion().takeBody(executeRegionOp.getRegion());
194194

195195
// Bufferize every block.
@@ -203,8 +203,8 @@ struct ExecuteRegionOpInterface
203203
SmallVector<Value> newResults;
204204
for (const auto &it : llvm::enumerate(executeRegionOp->getResultTypes())) {
205205
if (isa<TensorType>(it.value())) {
206-
newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
207-
executeRegionOp.getLoc(), it.value(),
206+
newResults.push_back(bufferization::ToTensorOp::create(
207+
rewriter, executeRegionOp.getLoc(), it.value(),
208208
newOp->getResult(it.index())));
209209
} else {
210210
newResults.push_back(newOp->getResult(it.index()));
@@ -258,9 +258,9 @@ struct IfOpInterface
258258

259259
// Create new op.
260260
rewriter.setInsertionPoint(ifOp);
261-
auto newIfOp =
262-
rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
263-
/*withElseRegion=*/true);
261+
auto newIfOp = scf::IfOp::create(rewriter, ifOp.getLoc(), newTypes,
262+
ifOp.getCondition(),
263+
/*withElseRegion=*/true);
264264

265265
// Move over then/else blocks.
266266
rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
@@ -372,9 +372,9 @@ struct IndexSwitchOpInterface
372372

373373
// Create new op.
374374
rewriter.setInsertionPoint(switchOp);
375-
auto newSwitchOp = rewriter.create<scf::IndexSwitchOp>(
376-
switchOp.getLoc(), newTypes, switchOp.getArg(), switchOp.getCases(),
377-
switchOp.getCases().size());
375+
auto newSwitchOp = scf::IndexSwitchOp::create(
376+
rewriter, switchOp.getLoc(), newTypes, switchOp.getArg(),
377+
switchOp.getCases(), switchOp.getCases().size());
378378

379379
// Move over blocks.
380380
for (auto [src, dest] :
@@ -767,8 +767,8 @@ struct ForOpInterface
767767
}
768768

769769
// Construct a new scf.for op with memref instead of tensor values.
770-
auto newForOp = rewriter.create<scf::ForOp>(
771-
forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
770+
auto newForOp = scf::ForOp::create(
771+
rewriter, forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
772772
forOp.getStep(), castedInitArgs);
773773
newForOp->setAttrs(forOp->getAttrs());
774774
Block *loopBody = newForOp.getBody();
@@ -1003,8 +1003,8 @@ struct WhileOpInterface
10031003
// Construct a new scf.while op with memref instead of tensor values.
10041004
ValueRange argsRangeBefore(castedInitArgs);
10051005
TypeRange argsTypesBefore(argsRangeBefore);
1006-
auto newWhileOp = rewriter.create<scf::WhileOp>(
1007-
whileOp.getLoc(), argsTypesAfter, castedInitArgs);
1006+
auto newWhileOp = scf::WhileOp::create(rewriter, whileOp.getLoc(),
1007+
argsTypesAfter, castedInitArgs);
10081008

10091009
// Add before/after regions to the new op.
10101010
SmallVector<Location> bbArgLocsBefore(castedInitArgs.size(),
@@ -1263,17 +1263,17 @@ struct ForallOpInterface
12631263
forallOp.getBody()->getArguments().drop_front(rank), buffers)) {
12641264
BlockArgument bbArg = std::get<0>(it);
12651265
Value buffer = std::get<1>(it);
1266-
Value bufferAsTensor = rewriter.create<ToTensorOp>(
1267-
forallOp.getLoc(), bbArg.getType(), buffer);
1266+
Value bufferAsTensor = ToTensorOp::create(rewriter, forallOp.getLoc(),
1267+
bbArg.getType(), buffer);
12681268
bbArg.replaceAllUsesWith(bufferAsTensor);
12691269
}
12701270

12711271
// Create new ForallOp without any results and drop the automatically
12721272
// introduced terminator.
12731273
rewriter.setInsertionPoint(forallOp);
12741274
ForallOp newForallOp;
1275-
newForallOp = rewriter.create<ForallOp>(
1276-
forallOp.getLoc(), forallOp.getMixedLowerBound(),
1275+
newForallOp = ForallOp::create(
1276+
rewriter, forallOp.getLoc(), forallOp.getMixedLowerBound(),
12771277
forallOp.getMixedUpperBound(), forallOp.getMixedStep(),
12781278
/*outputs=*/ValueRange(), forallOp.getMapping());
12791279

mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -50,19 +50,19 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
5050
SmallVector<Value> initArgs;
5151
initArgs.push_back(forOp.getLowerBound());
5252
llvm::append_range(initArgs, forOp.getInitArgs());
53-
auto whileOp = rewriter.create<WhileOp>(forOp.getLoc(), lcvTypes, initArgs,
54-
forOp->getAttrs());
53+
auto whileOp = WhileOp::create(rewriter, forOp.getLoc(), lcvTypes, initArgs,
54+
forOp->getAttrs());
5555

5656
// 'before' region contains the loop condition and forwarding of iteration
5757
// arguments to the 'after' region.
5858
auto *beforeBlock = rewriter.createBlock(
5959
&whileOp.getBefore(), whileOp.getBefore().begin(), lcvTypes, lcvLocs);
6060
rewriter.setInsertionPointToStart(whileOp.getBeforeBody());
61-
auto cmpOp = rewriter.create<arith::CmpIOp>(
62-
whileOp.getLoc(), arith::CmpIPredicate::slt,
61+
auto cmpOp = arith::CmpIOp::create(
62+
rewriter, whileOp.getLoc(), arith::CmpIPredicate::slt,
6363
beforeBlock->getArgument(0), forOp.getUpperBound());
64-
rewriter.create<scf::ConditionOp>(whileOp.getLoc(), cmpOp.getResult(),
65-
beforeBlock->getArguments());
64+
scf::ConditionOp::create(rewriter, whileOp.getLoc(), cmpOp.getResult(),
65+
beforeBlock->getArguments());
6666

6767
// Inline for-loop body into an executeRegion operation in the "after"
6868
// region. The return type of the execRegionOp does not contain the
@@ -72,8 +72,9 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
7272

7373
// Add induction variable incrementation
7474
rewriter.setInsertionPointToEnd(afterBlock);
75-
auto ivIncOp = rewriter.create<arith::AddIOp>(
76-
whileOp.getLoc(), afterBlock->getArgument(0), forOp.getStep());
75+
auto ivIncOp =
76+
arith::AddIOp::create(rewriter, whileOp.getLoc(),
77+
afterBlock->getArgument(0), forOp.getStep());
7778

7879
// Rewrite uses of the for-loop block arguments to the new while-loop
7980
// "after" arguments

mlir/lib/Dialect/SCF/Transforms/ForallToParallel.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ LogicalResult mlir::scf::forallToParallelLoop(RewriterBase &rewriter,
4040
SmallVector<Value> steps = forallOp.getStep(rewriter);
4141

4242
// Create empty scf.parallel op.
43-
auto parallelOp = rewriter.create<scf::ParallelOp>(loc, lbs, ubs, steps);
43+
auto parallelOp = scf::ParallelOp::create(rewriter, loc, lbs, ubs, steps);
4444
rewriter.eraseBlock(&parallelOp.getRegion().front());
4545
rewriter.inlineRegionBefore(forallOp.getRegion(), parallelOp.getRegion(),
4646
parallelOp.getRegion().begin());

mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp

Lines changed: 64 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -279,25 +279,25 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
279279
if (dynamicLoop) {
280280
Type t = ub.getType();
281281
// pred = ub > lb + (i * step)
282-
Value iv = rewriter.create<arith::AddIOp>(
283-
loc, lb,
284-
rewriter.create<arith::MulIOp>(
285-
loc, step,
286-
rewriter.create<arith::ConstantOp>(
287-
loc, rewriter.getIntegerAttr(t, i))));
288-
predicates[i] = rewriter.create<arith::CmpIOp>(
289-
loc, arith::CmpIPredicate::slt, iv, ub);
282+
Value iv = arith::AddIOp::create(
283+
rewriter, loc, lb,
284+
arith::MulIOp::create(
285+
rewriter, loc, step,
286+
arith::ConstantOp::create(rewriter, loc,
287+
rewriter.getIntegerAttr(t, i))));
288+
predicates[i] = arith::CmpIOp::create(rewriter, loc,
289+
arith::CmpIPredicate::slt, iv, ub);
290290
}
291291

292292
// special handling for induction variable as the increment is implicit.
293293
// iv = lb + i * step
294294
Type t = lb.getType();
295-
Value iv = rewriter.create<arith::AddIOp>(
296-
loc, lb,
297-
rewriter.create<arith::MulIOp>(
298-
loc, step,
299-
rewriter.create<arith::ConstantOp>(loc,
300-
rewriter.getIntegerAttr(t, i))));
295+
Value iv = arith::AddIOp::create(
296+
rewriter, loc, lb,
297+
arith::MulIOp::create(
298+
rewriter, loc, step,
299+
arith::ConstantOp::create(rewriter, loc,
300+
rewriter.getIntegerAttr(t, i))));
301301
setValueMapping(forOp.getInductionVar(), iv, i);
302302
for (Operation *op : opOrder) {
303303
if (stages[op] > i)
@@ -332,8 +332,8 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
332332
Value prevValue = valueMapping
333333
[forOp.getRegionIterArgs()[operand.getOperandNumber()]]
334334
[i - stages[op]];
335-
source = rewriter.create<arith::SelectOp>(
336-
loc, predicates[predicateIdx], source, prevValue);
335+
source = arith::SelectOp::create(
336+
rewriter, loc, predicates[predicateIdx], source, prevValue);
337337
}
338338
setValueMapping(forOp.getRegionIterArgs()[operand.getOperandNumber()],
339339
source, i - stages[op] + 1);
@@ -444,15 +444,15 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
444444
Type t = ub.getType();
445445
Location loc = forOp.getLoc();
446446
// newUb = ub - maxStage * step
447-
Value maxStageValue = rewriter.create<arith::ConstantOp>(
448-
loc, rewriter.getIntegerAttr(t, maxStage));
447+
Value maxStageValue = arith::ConstantOp::create(
448+
rewriter, loc, rewriter.getIntegerAttr(t, maxStage));
449449
Value maxStageByStep =
450-
rewriter.create<arith::MulIOp>(loc, step, maxStageValue);
451-
newUb = rewriter.create<arith::SubIOp>(loc, ub, maxStageByStep);
450+
arith::MulIOp::create(rewriter, loc, step, maxStageValue);
451+
newUb = arith::SubIOp::create(rewriter, loc, ub, maxStageByStep);
452452
}
453453
auto newForOp =
454-
rewriter.create<scf::ForOp>(forOp.getLoc(), forOp.getLowerBound(), newUb,
455-
forOp.getStep(), newLoopArg);
454+
scf::ForOp::create(rewriter, forOp.getLoc(), forOp.getLowerBound(), newUb,
455+
forOp.getStep(), newLoopArg);
456456
// When there are no iter args, the loop body terminator will be created.
457457
// Since we always create it below, remove the terminator if it was created.
458458
if (!newForOp.getBody()->empty())
@@ -483,16 +483,17 @@ LogicalResult LoopPipelinerInternal::createKernel(
483483
Type t = ub.getType();
484484
for (unsigned i = 0; i < maxStage; i++) {
485485
// c = ub - (maxStage - i) * step
486-
Value c = rewriter.create<arith::SubIOp>(
487-
loc, ub,
488-
rewriter.create<arith::MulIOp>(
489-
loc, step,
490-
rewriter.create<arith::ConstantOp>(
491-
loc, rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
492-
493-
Value pred = rewriter.create<arith::CmpIOp>(
494-
newForOp.getLoc(), arith::CmpIPredicate::slt,
495-
newForOp.getInductionVar(), c);
486+
Value c = arith::SubIOp::create(
487+
rewriter, loc, ub,
488+
arith::MulIOp::create(
489+
rewriter, loc, step,
490+
arith::ConstantOp::create(
491+
rewriter, loc,
492+
rewriter.getIntegerAttr(t, int64_t(maxStage - i)))));
493+
494+
Value pred = arith::CmpIOp::create(rewriter, newForOp.getLoc(),
495+
arith::CmpIPredicate::slt,
496+
newForOp.getInductionVar(), c);
496497
predicates[i] = pred;
497498
}
498499
}
@@ -515,13 +516,13 @@ LogicalResult LoopPipelinerInternal::createKernel(
515516

516517
// offset = (maxStage - stages[op]) * step
517518
Type t = step.getType();
518-
Value offset = rewriter.create<arith::MulIOp>(
519-
forOp.getLoc(), step,
520-
rewriter.create<arith::ConstantOp>(
521-
forOp.getLoc(),
519+
Value offset = arith::MulIOp::create(
520+
rewriter, forOp.getLoc(), step,
521+
arith::ConstantOp::create(
522+
rewriter, forOp.getLoc(),
522523
rewriter.getIntegerAttr(t, maxStage - stages[op])));
523-
Value iv = rewriter.create<arith::AddIOp>(
524-
forOp.getLoc(), newForOp.getInductionVar(), offset);
524+
Value iv = arith::AddIOp::create(rewriter, forOp.getLoc(),
525+
newForOp.getInductionVar(), offset);
525526
nestedNewOp->setOperand(operand->getOperandNumber(), iv);
526527
rewriter.setInsertionPointAfter(newOp);
527528
continue;
@@ -594,8 +595,8 @@ LogicalResult LoopPipelinerInternal::createKernel(
594595
auto defStage = stages.find(def);
595596
if (defStage != stages.end() && defStage->second < maxStage) {
596597
Value pred = predicates[defStage->second];
597-
source = rewriter.create<arith::SelectOp>(
598-
pred.getLoc(), pred, source,
598+
source = arith::SelectOp::create(
599+
rewriter, pred.getLoc(), pred, source,
599600
newForOp.getBody()
600601
->getArguments()[yieldOperand.getOperandNumber() + 1]);
601602
}
@@ -638,7 +639,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
638639
maxStage - defStage->second + 1);
639640
}
640641
}
641-
rewriter.create<scf::YieldOp>(forOp.getLoc(), yieldOperands);
642+
scf::YieldOp::create(rewriter, forOp.getLoc(), yieldOperands);
642643
return success();
643644
}
644645

@@ -652,51 +653,53 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
652653
// removed by dead code if not used.
653654

654655
auto createConst = [&](int v) {
655-
return rewriter.create<arith::ConstantOp>(loc,
656-
rewriter.getIntegerAttr(t, v));
656+
return arith::ConstantOp::create(rewriter, loc,
657+
rewriter.getIntegerAttr(t, v));
657658
};
658659

659660
// total_iterations = cdiv(range_diff, step);
660661
// - range_diff = ub - lb
661662
// - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
662663
Value zero = createConst(0);
663664
Value one = createConst(1);
664-
Value stepLessZero = rewriter.create<arith::CmpIOp>(
665-
loc, arith::CmpIPredicate::slt, step, zero);
666-
Value stepDecr =
667-
rewriter.create<arith::SelectOp>(loc, stepLessZero, one, createConst(-1));
665+
Value stepLessZero = arith::CmpIOp::create(
666+
rewriter, loc, arith::CmpIPredicate::slt, step, zero);
667+
Value stepDecr = arith::SelectOp::create(rewriter, loc, stepLessZero, one,
668+
createConst(-1));
668669

669-
Value rangeDiff = rewriter.create<arith::SubIOp>(loc, ub, lb);
670-
Value rangeIncrStep = rewriter.create<arith::AddIOp>(loc, rangeDiff, step);
670+
Value rangeDiff = arith::SubIOp::create(rewriter, loc, ub, lb);
671+
Value rangeIncrStep = arith::AddIOp::create(rewriter, loc, rangeDiff, step);
671672
Value rangeDecr =
672-
rewriter.create<arith::AddIOp>(loc, rangeIncrStep, stepDecr);
673-
Value totalIterations = rewriter.create<arith::DivSIOp>(loc, rangeDecr, step);
673+
arith::AddIOp::create(rewriter, loc, rangeIncrStep, stepDecr);
674+
Value totalIterations =
675+
arith::DivSIOp::create(rewriter, loc, rangeDecr, step);
674676

675677
// If total_iters < max_stage, start the epilogue at zero to match the
676678
// ramp-up in the prologue.
677679
// start_iter = max(0, total_iters - max_stage)
678-
Value iterI = rewriter.create<arith::SubIOp>(loc, totalIterations,
679-
createConst(maxStage));
680-
iterI = rewriter.create<arith::MaxSIOp>(loc, zero, iterI);
680+
Value iterI = arith::SubIOp::create(rewriter, loc, totalIterations,
681+
createConst(maxStage));
682+
iterI = arith::MaxSIOp::create(rewriter, loc, zero, iterI);
681683

682684
// Capture predicates for dynamic loops.
683685
SmallVector<Value> predicates(maxStage + 1);
684686

685687
for (int64_t i = 1; i <= maxStage; i++) {
686688
// newLastIter = lb + step * iterI
687-
Value newlastIter = rewriter.create<arith::AddIOp>(
688-
loc, lb, rewriter.create<arith::MulIOp>(loc, step, iterI));
689+
Value newlastIter = arith::AddIOp::create(
690+
rewriter, loc, lb, arith::MulIOp::create(rewriter, loc, step, iterI));
689691

690692
setValueMapping(forOp.getInductionVar(), newlastIter, i);
691693

692694
// increment to next iterI
693-
iterI = rewriter.create<arith::AddIOp>(loc, iterI, one);
695+
iterI = arith::AddIOp::create(rewriter, loc, iterI, one);
694696

695697
if (dynamicLoop) {
696698
// Disable stages when `i` is greater than total_iters.
697699
// pred = total_iters >= i
698-
predicates[i] = rewriter.create<arith::CmpIOp>(
699-
loc, arith::CmpIPredicate::sge, totalIterations, createConst(i));
700+
predicates[i] =
701+
arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::sge,
702+
totalIterations, createConst(i));
700703
}
701704
}
702705

@@ -758,8 +761,8 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
758761
unsigned nextVersion = currentVersion + 1;
759762
Value pred = predicates[currentVersion];
760763
Value prevValue = valueMapping[mapVal][currentVersion];
761-
auto selOp = rewriter.create<arith::SelectOp>(loc, pred, pair.value(),
762-
prevValue);
764+
auto selOp = arith::SelectOp::create(rewriter, loc, pred,
765+
pair.value(), prevValue);
763766
returnValues[ri] = selOp;
764767
if (nextVersion <= maxStage)
765768
setValueMapping(mapVal, selOp, nextVersion);

0 commit comments

Comments
 (0)