Skip to content

Commit 18b4095

Browse files
authored
[mlir] [scf-to-cf] attach the loop annotation to latch block (#147462)
As [required by LLVM](https://llvm.org/docs/LangRef.html#llvm-loop), the loop annotation (loop metadata) should be attached on the ["latch" block](https://llvm.org/docs/LoopTerminology.html). Otherwise, the annotation might be ignored by LLVM. This PR fixes this issue.
1 parent cc95e40 commit 18b4095

File tree

2 files changed

+35
-18
lines changed

2 files changed

+35
-18
lines changed

mlir/lib/Conversion/SCFToControlFlow/SCFToControlFlow.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -347,7 +347,20 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
347347
SmallVector<Value, 8> loopCarried;
348348
loopCarried.push_back(stepped);
349349
loopCarried.append(terminator->operand_begin(), terminator->operand_end());
350-
rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
350+
auto branchOp =
351+
rewriter.create<cf::BranchOp>(loc, conditionBlock, loopCarried);
352+
353+
// Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
354+
// llvm.loop_annotation attribute.
355+
// LLVM requires the loop metadata to be attached on the "latch" block. Which
356+
// is the back-edge to the header block (conditionBlock)
357+
SmallVector<NamedAttribute> llvmAttrs;
358+
llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
359+
[](auto attr) {
360+
return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
361+
});
362+
branchOp->setDiscardableAttrs(llvmAttrs);
363+
351364
rewriter.eraseOp(terminator);
352365

353366
// Compute loop bounds before branching to the condition.
@@ -369,18 +382,10 @@ LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
369382
auto comparison = rewriter.create<arith::CmpIOp>(
370383
loc, arith::CmpIPredicate::slt, iv, upperBound);
371384

372-
auto condBranchOp = rewriter.create<cf::CondBranchOp>(
373-
loc, comparison, firstBodyBlock, ArrayRef<Value>(), endBlock,
374-
ArrayRef<Value>());
385+
rewriter.create<cf::CondBranchOp>(loc, comparison, firstBodyBlock,
386+
ArrayRef<Value>(), endBlock,
387+
ArrayRef<Value>());
375388

376-
// Let the CondBranchOp carry the LLVM attributes from the ForOp, such as the
377-
// llvm.loop_annotation attribute.
378-
SmallVector<NamedAttribute> llvmAttrs;
379-
llvm::copy_if(forOp->getAttrs(), std::back_inserter(llvmAttrs),
380-
[](auto attr) {
381-
return isa<LLVM::LLVMDialect>(attr.getValue().getDialect());
382-
});
383-
condBranchOp->setDiscardableAttrs(llvmAttrs);
384389
// The result of the loop operation is the values of the condition block
385390
// arguments except the induction variable on the last iteration.
386391
rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());

mlir/test/Conversion/SCFToControlFlow/convert-to-cfg.mlir

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -678,12 +678,24 @@ func.func @forall(%num_threads: index) {
678678

679679
// -----
680680

681-
// CHECK: #loop_unroll = #llvm.loop_unroll<disable = true>
682-
// CHECK-NEXT: #loop_unroll1 = #llvm.loop_unroll<full = true>
683-
// CHECK-NEXT: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #loop_unroll>
684-
// CHECK-NEXT: #[[FULL_UNROLL:.*]] = #llvm.loop_annotation<unroll = #loop_unroll1>
685-
// CHECK: cf.cond_br %{{.*}}, ^bb2, ^bb6 {llvm.loop_annotation = #[[NO_UNROLL]]}
686-
// CHECK: cf.cond_br %{{.*}}, ^bb4, ^bb5 {llvm.loop_annotation = #[[FULL_UNROLL]]}
681+
// CHECK: #[[LOOP_UNROLL:.*]] = #llvm.loop_unroll<full = true>
682+
// CHECK-DAG: #[[LOOP_UNROLL_DISABLE:.*]] = #llvm.loop_unroll<disable = true>
683+
684+
// CHECK-DAG: #[[FULL_UNROLL:.*]] = #llvm.loop_annotation<unroll = #[[LOOP_UNROLL]]>
685+
// CHECK-DAG: #[[NO_UNROLL:.*]] = #llvm.loop_annotation<unroll = #[[LOOP_UNROLL_DISABLE]]>
686+
// CHECK: func @simple_std_for_loops_annotation
687+
// CHECK: ^[[bb1:.*]](%{{.*}}: index):
688+
// CHECK: cf.cond_br %{{.*}}, ^[[bb2:.*]], ^[[bb6:.*]]
689+
// CHECK: ^[[bb2]]:
690+
// CHECK: cf.br ^[[bb3:.*]]({{.*}})
691+
// CHECK: ^[[bb3]](%{{.*}}: index):
692+
// CHECK: cf.cond_br %{{.*}}, ^[[bb4:.*]], ^[[bb5:.*]]
693+
// CHECK: ^[[bb4]]:
694+
// CHECK: cf.br ^[[bb3]]({{.*}}) {llvm.loop_annotation = #[[FULL_UNROLL]]}
695+
// CHECK: ^[[bb5]]:
696+
// CHECK: cf.br ^[[bb1]]({{.*}}) {llvm.loop_annotation = #[[NO_UNROLL]]}
697+
// CHECK: ^[[bb6]]:
698+
// CHECK: return
687699
#no_unroll = #llvm.loop_annotation<unroll = <disable = true>>
688700
#full_unroll = #llvm.loop_annotation<unroll = <full = true>>
689701
func.func @simple_std_for_loops_annotation(%arg0 : index, %arg1 : index, %arg2 : index) {

0 commit comments

Comments
 (0)