@@ -279,25 +279,25 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
279
279
if (dynamicLoop) {
280
280
Type t = ub.getType ();
281
281
// pred = ub > lb + (i * step)
282
- Value iv = rewriter. create < arith::AddIOp> (
283
- loc, lb,
284
- rewriter. create < arith::MulIOp> (
285
- loc, step,
286
- rewriter. create < arith::ConstantOp>(
287
- loc, rewriter.getIntegerAttr (t, i))));
288
- predicates[i] = rewriter. create < arith::CmpIOp>(
289
- loc, arith::CmpIPredicate::slt, iv, ub);
282
+ Value iv = arith::AddIOp::create (
283
+ rewriter, loc, lb,
284
+ arith::MulIOp::create (
285
+ rewriter, loc, step,
286
+ arith::ConstantOp::create (rewriter, loc,
287
+ rewriter.getIntegerAttr (t, i))));
288
+ predicates[i] = arith::CmpIOp::create (rewriter, loc,
289
+ arith::CmpIPredicate::slt, iv, ub);
290
290
}
291
291
292
292
// special handling for induction variable as the increment is implicit.
293
293
// iv = lb + i * step
294
294
Type t = lb.getType ();
295
- Value iv = rewriter. create < arith::AddIOp> (
296
- loc, lb,
297
- rewriter. create < arith::MulIOp> (
298
- loc, step,
299
- rewriter. create < arith::ConstantOp>( loc,
300
- rewriter.getIntegerAttr (t, i))));
295
+ Value iv = arith::AddIOp::create (
296
+ rewriter, loc, lb,
297
+ arith::MulIOp::create (
298
+ rewriter, loc, step,
299
+ arith::ConstantOp::create (rewriter, loc,
300
+ rewriter.getIntegerAttr (t, i))));
301
301
setValueMapping (forOp.getInductionVar (), iv, i);
302
302
for (Operation *op : opOrder) {
303
303
if (stages[op] > i)
@@ -332,8 +332,8 @@ LogicalResult LoopPipelinerInternal::emitPrologue(RewriterBase &rewriter) {
332
332
Value prevValue = valueMapping
333
333
[forOp.getRegionIterArgs ()[operand.getOperandNumber ()]]
334
334
[i - stages[op]];
335
- source = rewriter. create < arith::SelectOp> (
336
- loc, predicates[predicateIdx], source, prevValue);
335
+ source = arith::SelectOp::create (
336
+ rewriter, loc, predicates[predicateIdx], source, prevValue);
337
337
}
338
338
setValueMapping (forOp.getRegionIterArgs ()[operand.getOperandNumber ()],
339
339
source, i - stages[op] + 1 );
@@ -444,15 +444,15 @@ scf::ForOp LoopPipelinerInternal::createKernelLoop(
444
444
Type t = ub.getType ();
445
445
Location loc = forOp.getLoc ();
446
446
// newUb = ub - maxStage * step
447
- Value maxStageValue = rewriter. create < arith::ConstantOp> (
448
- loc, rewriter.getIntegerAttr (t, maxStage));
447
+ Value maxStageValue = arith::ConstantOp::create (
448
+ rewriter, loc, rewriter.getIntegerAttr (t, maxStage));
449
449
Value maxStageByStep =
450
- rewriter. create < arith::MulIOp>( loc, step, maxStageValue);
451
- newUb = rewriter. create < arith::SubIOp>( loc, ub, maxStageByStep);
450
+ arith::MulIOp::create (rewriter, loc, step, maxStageValue);
451
+ newUb = arith::SubIOp::create (rewriter, loc, ub, maxStageByStep);
452
452
}
453
453
auto newForOp =
454
- rewriter. create < scf::ForOp>( forOp.getLoc (), forOp.getLowerBound (), newUb,
455
- forOp.getStep (), newLoopArg);
454
+ scf::ForOp::create (rewriter, forOp.getLoc (), forOp.getLowerBound (), newUb,
455
+ forOp.getStep (), newLoopArg);
456
456
// When there are no iter args, the loop body terminator will be created.
457
457
// Since we always create it below, remove the terminator if it was created.
458
458
if (!newForOp.getBody ()->empty ())
@@ -483,16 +483,17 @@ LogicalResult LoopPipelinerInternal::createKernel(
483
483
Type t = ub.getType ();
484
484
for (unsigned i = 0 ; i < maxStage; i++) {
485
485
// c = ub - (maxStage - i) * step
486
- Value c = rewriter.create <arith::SubIOp>(
487
- loc, ub,
488
- rewriter.create <arith::MulIOp>(
489
- loc, step,
490
- rewriter.create <arith::ConstantOp>(
491
- loc, rewriter.getIntegerAttr (t, int64_t (maxStage - i)))));
492
-
493
- Value pred = rewriter.create <arith::CmpIOp>(
494
- newForOp.getLoc (), arith::CmpIPredicate::slt,
495
- newForOp.getInductionVar (), c);
486
+ Value c = arith::SubIOp::create (
487
+ rewriter, loc, ub,
488
+ arith::MulIOp::create (
489
+ rewriter, loc, step,
490
+ arith::ConstantOp::create (
491
+ rewriter, loc,
492
+ rewriter.getIntegerAttr (t, int64_t (maxStage - i)))));
493
+
494
+ Value pred = arith::CmpIOp::create (rewriter, newForOp.getLoc (),
495
+ arith::CmpIPredicate::slt,
496
+ newForOp.getInductionVar (), c);
496
497
predicates[i] = pred;
497
498
}
498
499
}
@@ -515,13 +516,13 @@ LogicalResult LoopPipelinerInternal::createKernel(
515
516
516
517
// offset = (maxStage - stages[op]) * step
517
518
Type t = step.getType ();
518
- Value offset = rewriter. create < arith::MulIOp> (
519
- forOp.getLoc (), step,
520
- rewriter. create < arith::ConstantOp> (
521
- forOp.getLoc (),
519
+ Value offset = arith::MulIOp::create (
520
+ rewriter, forOp.getLoc (), step,
521
+ arith::ConstantOp::create (
522
+ rewriter, forOp.getLoc (),
522
523
rewriter.getIntegerAttr (t, maxStage - stages[op])));
523
- Value iv = rewriter. create < arith::AddIOp>(
524
- forOp. getLoc (), newForOp.getInductionVar (), offset);
524
+ Value iv = arith::AddIOp::create (rewriter, forOp. getLoc (),
525
+ newForOp.getInductionVar (), offset);
525
526
nestedNewOp->setOperand (operand->getOperandNumber (), iv);
526
527
rewriter.setInsertionPointAfter (newOp);
527
528
continue ;
@@ -594,8 +595,8 @@ LogicalResult LoopPipelinerInternal::createKernel(
594
595
auto defStage = stages.find (def);
595
596
if (defStage != stages.end () && defStage->second < maxStage) {
596
597
Value pred = predicates[defStage->second ];
597
- source = rewriter. create < arith::SelectOp> (
598
- pred.getLoc (), pred, source,
598
+ source = arith::SelectOp::create (
599
+ rewriter, pred.getLoc (), pred, source,
599
600
newForOp.getBody ()
600
601
->getArguments ()[yieldOperand.getOperandNumber () + 1 ]);
601
602
}
@@ -638,7 +639,7 @@ LogicalResult LoopPipelinerInternal::createKernel(
638
639
maxStage - defStage->second + 1 );
639
640
}
640
641
}
641
- rewriter. create < scf::YieldOp>( forOp.getLoc (), yieldOperands);
642
+ scf::YieldOp::create (rewriter, forOp.getLoc (), yieldOperands);
642
643
return success ();
643
644
}
644
645
@@ -652,51 +653,53 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
652
653
// removed by dead code if not used.
653
654
654
655
auto createConst = [&](int v) {
655
- return rewriter. create < arith::ConstantOp>( loc,
656
- rewriter.getIntegerAttr (t, v));
656
+ return arith::ConstantOp::create (rewriter, loc,
657
+ rewriter.getIntegerAttr (t, v));
657
658
};
658
659
659
660
// total_iterations = cdiv(range_diff, step);
660
661
// - range_diff = ub - lb
661
662
// - total_iterations = (range_diff + step + (step < 0 ? 1 : -1)) / step
662
663
Value zero = createConst (0 );
663
664
Value one = createConst (1 );
664
- Value stepLessZero = rewriter. create < arith::CmpIOp> (
665
- loc, arith::CmpIPredicate::slt, step, zero);
666
- Value stepDecr =
667
- rewriter. create <arith::SelectOp>(loc, stepLessZero, one, createConst (-1 ));
665
+ Value stepLessZero = arith::CmpIOp::create (
666
+ rewriter, loc, arith::CmpIPredicate::slt, step, zero);
667
+ Value stepDecr = arith::SelectOp::create (rewriter, loc, stepLessZero, one,
668
+ createConst (-1 ));
668
669
669
- Value rangeDiff = rewriter. create < arith::SubIOp>( loc, ub, lb);
670
- Value rangeIncrStep = rewriter. create < arith::AddIOp>( loc, rangeDiff, step);
670
+ Value rangeDiff = arith::SubIOp::create (rewriter, loc, ub, lb);
671
+ Value rangeIncrStep = arith::AddIOp::create (rewriter, loc, rangeDiff, step);
671
672
Value rangeDecr =
672
- rewriter.create <arith::AddIOp>(loc, rangeIncrStep, stepDecr);
673
- Value totalIterations = rewriter.create <arith::DivSIOp>(loc, rangeDecr, step);
673
+ arith::AddIOp::create (rewriter, loc, rangeIncrStep, stepDecr);
674
+ Value totalIterations =
675
+ arith::DivSIOp::create (rewriter, loc, rangeDecr, step);
674
676
675
677
// If total_iters < max_stage, start the epilogue at zero to match the
676
678
// ramp-up in the prologue.
677
679
// start_iter = max(0, total_iters - max_stage)
678
- Value iterI = rewriter. create < arith::SubIOp>( loc, totalIterations,
679
- createConst (maxStage));
680
- iterI = rewriter. create < arith::MaxSIOp>( loc, zero, iterI);
680
+ Value iterI = arith::SubIOp::create (rewriter, loc, totalIterations,
681
+ createConst (maxStage));
682
+ iterI = arith::MaxSIOp::create (rewriter, loc, zero, iterI);
681
683
682
684
// Capture predicates for dynamic loops.
683
685
SmallVector<Value> predicates (maxStage + 1 );
684
686
685
687
for (int64_t i = 1 ; i <= maxStage; i++) {
686
688
// newLastIter = lb + step * iterI
687
- Value newlastIter = rewriter. create < arith::AddIOp> (
688
- loc, lb, rewriter. create < arith::MulIOp>( loc, step, iterI));
689
+ Value newlastIter = arith::AddIOp::create (
690
+ rewriter, loc, lb, arith::MulIOp::create (rewriter, loc, step, iterI));
689
691
690
692
setValueMapping (forOp.getInductionVar (), newlastIter, i);
691
693
692
694
// increment to next iterI
693
- iterI = rewriter. create < arith::AddIOp>( loc, iterI, one);
695
+ iterI = arith::AddIOp::create (rewriter, loc, iterI, one);
694
696
695
697
if (dynamicLoop) {
696
698
// Disable stages when `i` is greater than total_iters.
697
699
// pred = total_iters >= i
698
- predicates[i] = rewriter.create <arith::CmpIOp>(
699
- loc, arith::CmpIPredicate::sge, totalIterations, createConst (i));
700
+ predicates[i] =
701
+ arith::CmpIOp::create (rewriter, loc, arith::CmpIPredicate::sge,
702
+ totalIterations, createConst (i));
700
703
}
701
704
}
702
705
@@ -758,8 +761,8 @@ LoopPipelinerInternal::emitEpilogue(RewriterBase &rewriter,
758
761
unsigned nextVersion = currentVersion + 1 ;
759
762
Value pred = predicates[currentVersion];
760
763
Value prevValue = valueMapping[mapVal][currentVersion];
761
- auto selOp = rewriter. create < arith::SelectOp>(loc, pred, pair. value () ,
762
- prevValue);
764
+ auto selOp = arith::SelectOp::create (rewriter, loc, pred ,
765
+ pair. value (), prevValue);
763
766
returnValues[ri] = selOp;
764
767
if (nextVersion <= maxStage)
765
768
setValueMapping (mapVal, selOp, nextVersion);
0 commit comments