-
Notifications
You must be signed in to change notification settings - Fork 14.5k
[mlir][OpenMP] Allow composite SIMD REDUCTION and IF #147568
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5370,58 +5370,90 @@ void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) { | |
|
||
void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop, | ||
Value *IfCond, ValueToValueMapTy &VMap, | ||
LoopAnalysis &LIA, LoopInfo &LI, Loop *L, | ||
const Twine &NamePrefix) { | ||
Function *F = CanonicalLoop->getFunction(); | ||
|
||
// Define where if branch should be inserted | ||
Instruction *SplitBefore = CanonicalLoop->getPreheader()->getTerminator(); | ||
|
||
// TODO: We should not rely on pass manager. Currently we use pass manager | ||
// only for getting llvm::Loop which corresponds to given CanonicalLoopInfo | ||
// object. We should have a method which returns all blocks between | ||
// CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter() | ||
FunctionAnalysisManager FAM; | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Just to clarify: these passes are removed from createIfVersion because there is already a LoopAnalysis run in the calling function. Further, I wanted to use the same […] There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Thanks for the clarification. |
||
FAM.registerPass([]() { return DominatorTreeAnalysis(); }); | ||
FAM.registerPass([]() { return LoopAnalysis(); }); | ||
FAM.registerPass([]() { return PassInstrumentationAnalysis(); }); | ||
// We can't do | ||
// if (cond) { | ||
// simd_loop; | ||
// } else { | ||
// non_simd_loop; | ||
// } | ||
// because then the CanonicalLoopInfo would only point to one of the loops: | ||
// leading to other constructs operating on the same loop to malfunction. | ||
// Instead generate | ||
// while (...) { | ||
// if (cond) { | ||
// simd_body; | ||
// } else { | ||
// not_simd_body; | ||
// } | ||
// } | ||
// At least for simple loops, LLVM seems able to hoist the if out of the loop | ||
// body at -O3 | ||
|
||
// Get the loop which needs to be cloned | ||
LoopAnalysis LIA; | ||
LoopInfo &&LI = LIA.run(*F, FAM); | ||
Loop *L = LI.getLoopFor(CanonicalLoop->getHeader()); | ||
// Define where if branch should be inserted | ||
auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt(); | ||
|
||
// Create additional blocks for the if statement | ||
BasicBlock *Head = SplitBefore->getParent(); | ||
Instruction *HeadOldTerm = Head->getTerminator(); | ||
llvm::LLVMContext &C = Head->getContext(); | ||
BasicBlock *Cond = SplitBeforeIt->getParent(); | ||
llvm::LLVMContext &C = Cond->getContext(); | ||
llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create( | ||
C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode()); | ||
C, NamePrefix + ".if.then", Cond->getParent(), Cond->getNextNode()); | ||
llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create( | ||
C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit()); | ||
C, NamePrefix + ".if.else", Cond->getParent(), CanonicalLoop->getExit()); | ||
|
||
// Create if condition branch. | ||
Builder.SetInsertPoint(HeadOldTerm); | ||
Builder.SetInsertPoint(SplitBeforeIt); | ||
Instruction *BrInstr = | ||
Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock); | ||
InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()}; | ||
// Then block contains branch to omp loop which needs to be vectorized | ||
// Then block contains branch to omp loop body which needs to be vectorized | ||
spliceBB(IP, ThenBlock, false, Builder.getCurrentDebugLocation()); | ||
ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock); | ||
ThenBlock->replaceSuccessorsPhiUsesWith(Cond, ThenBlock); | ||
|
||
Builder.SetInsertPoint(ElseBlock); | ||
|
||
// Clone loop for the else branch | ||
SmallVector<BasicBlock *, 8> NewBlocks; | ||
|
||
VMap[CanonicalLoop->getPreheader()] = ElseBlock; | ||
for (BasicBlock *Block : L->getBlocks()) { | ||
SmallVector<BasicBlock *, 8> ExistingBlocks; | ||
ExistingBlocks.reserve(L->getNumBlocks() + 1); | ||
ExistingBlocks.push_back(ThenBlock); | ||
ExistingBlocks.append(L->block_begin(), L->block_end()); | ||
// Cond is the block that has the if clause condition | ||
// LoopCond is omp_loop.cond | ||
// LoopHeader is omp_loop.header | ||
BasicBlock *LoopCond = Cond->getUniquePredecessor(); | ||
BasicBlock *LoopHeader = LoopCond->getUniquePredecessor(); | ||
assert(LoopCond && LoopHeader && "Invalid loop structure"); | ||
for (BasicBlock *Block : ExistingBlocks) { | ||
if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() || | ||
Block == LoopHeader || Block == LoopCond || Block == Cond) { | ||
continue; | ||
} | ||
BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F); | ||
|
||
// fix name not to be omp.if.then | ||
if (Block == ThenBlock) | ||
NewBB->setName(NamePrefix + ".if.else"); | ||
|
||
NewBB->moveBefore(CanonicalLoop->getExit()); | ||
VMap[Block] = NewBB; | ||
NewBlocks.push_back(NewBB); | ||
} | ||
remapInstructionsInBlocks(NewBlocks, VMap); | ||
Builder.CreateBr(NewBlocks.front()); | ||
|
||
// The loop latch must have only one predecessor. Currently it is branched to | ||
// from both the 'then' and 'else' branches. | ||
L->getLoopLatch()->splitBasicBlock( | ||
L->getLoopLatch()->begin(), NamePrefix + ".pre_latch", /*Before=*/true); | ||
|
||
// Ensure that the then block is added to the loop so we add the attributes in | ||
// the next step | ||
L->addBasicBlockToLoop(ThenBlock, LI); | ||
} | ||
|
||
unsigned | ||
|
@@ -5477,20 +5509,7 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, | |
|
||
if (IfCond) { | ||
ValueToValueMapTy VMap; | ||
createIfVersion(CanonicalLoop, IfCond, VMap, "simd"); | ||
// Add metadata to the cloned loop which disables vectorization | ||
Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch()); | ||
assert(MappedLatch && | ||
"Cannot find value which corresponds to original loop latch"); | ||
assert(isa<BasicBlock>(MappedLatch) && | ||
"Cannot cast mapped latch block value to BasicBlock"); | ||
BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch); | ||
ConstantAsMetadata *BoolConst = | ||
ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx))); | ||
addBasicBlockMetadata( | ||
NewLatchBlock, | ||
{MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"), | ||
BoolConst})}); | ||
createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, "simd"); | ||
} | ||
|
||
SmallSet<BasicBlock *, 8> Reachable; | ||
|
@@ -5524,6 +5543,14 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop, | |
Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup})); | ||
} | ||
|
||
// FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD | ||
// versions so we can't add the loop attributes in that case. | ||
if (IfCond) { | ||
// we can still add llvm.loop.parallel_access | ||
addLoopMetadata(CanonicalLoop, LoopMDList); | ||
return; | ||
} | ||
|
||
// Use the above access group metadata to create loop level | ||
// metadata, which should be distinct for each loop. | ||
ConstantAsMetadata *BoolConst = | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,97 @@ | ||
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s | ||
|
||
llvm.func @_QPfoo(%arg0: !llvm.ptr {fir.bindc_name = "array", llvm.nocapture}, %arg1: !llvm.ptr {fir.bindc_name = "t", llvm.nocapture}) { | ||
%0 = llvm.mlir.constant(0 : i64) : i32 | ||
%1 = llvm.mlir.constant(1 : i32) : i32 | ||
%2 = llvm.mlir.constant(10 : i64) : i64 | ||
%3 = llvm.mlir.constant(1 : i64) : i64 | ||
%4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr | ||
%5 = llvm.load %arg1 : !llvm.ptr -> i32 | ||
%6 = llvm.icmp "ne" %5, %0 : i32 | ||
%7 = llvm.trunc %2 : i64 to i32 | ||
omp.wsloop { | ||
omp.simd if(%6) { | ||
omp.loop_nest (%arg2) : i32 = (%1) to (%7) inclusive step (%1) { | ||
llvm.store %arg2, %4 : i32, !llvm.ptr | ||
%8 = llvm.load %4 : !llvm.ptr -> i32 | ||
%9 = llvm.sext %8 : i32 to i64 | ||
%10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr, i64) -> !llvm.ptr, i32 | ||
llvm.store %8, %10 : i32, !llvm.ptr | ||
omp.yield | ||
} | ||
} {omp.composite} | ||
} {omp.composite} | ||
llvm.return | ||
} | ||
|
||
// CHECK-LABEL: @_QPfoo | ||
// ... | ||
// CHECK: omp_loop.preheader: ; preds = | ||
// CHECK: store i32 0, ptr %[[LB_ADDR:.*]], align 4 | ||
// CHECK: store i32 9, ptr %[[UB_ADDR:.*]], align 4 | ||
// CHECK: store i32 1, ptr %[[STEP_ADDR:.*]], align 4 | ||
// CHECK: %[[VAL_15:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_15]], i32 34, ptr %{{.*}}, ptr %[[LB_ADDR]], ptr %[[UB_ADDR]], ptr %[[STEP_ADDR]], i32 1, i32 0) | ||
// CHECK: %[[LB:.*]] = load i32, ptr %[[LB_ADDR]], align 4 | ||
// CHECK: %[[UB:.*]] = load i32, ptr %[[UB_ADDR]], align 4 | ||
// CHECK: %[[VAL_18:.*]] = sub i32 %[[UB]], %[[LB]] | ||
// CHECK: %[[COUNT:.*]] = add i32 %[[VAL_18]], 1 | ||
// CHECK: br label %[[OMP_LOOP_HEADER:.*]] | ||
// CHECK: omp_loop.header: ; preds = %[[OMP_LOOP_INC:.*]], %[[OMP_LOOP_PREHEADER:.*]] | ||
// CHECK: %[[IV:.*]] = phi i32 [ 0, %[[OMP_LOOP_PREHEADER]] ], [ %[[NEW_IV:.*]], %[[OMP_LOOP_INC]] ] | ||
// CHECK: br label %[[OMP_LOOP_COND:.*]] | ||
// CHECK: omp_loop.cond: ; preds = %[[OMP_LOOP_HEADER]] | ||
// CHECK: %[[VAL_25:.*]] = icmp ult i32 %[[IV]], %[[COUNT]] | ||
// CHECK: br i1 %[[VAL_25]], label %[[OMP_LOOP_BODY:.*]], label %[[OMP_LOOP_EXIT:.*]] | ||
// CHECK: omp_loop.body: ; preds = %[[OMP_LOOP_COND]] | ||
// CHECK: %[[VAL_28:.*]] = add i32 %[[IV]], %[[LB]] | ||
// This is the IF clause: | ||
// CHECK: br i1 %{{.*}}, label %[[SIMD_IF_THEN:.*]], label %[[SIMD_IF_ELSE:.*]] | ||
|
||
// CHECK: simd.if.then: ; preds = %[[OMP_LOOP_BODY]] | ||
// CHECK: %[[VAL_29:.*]] = mul i32 %[[VAL_28]], 1 | ||
// CHECK: %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1 | ||
// CHECK: br label %[[VAL_33:.*]] | ||
// CHECK: omp.loop_nest.region: ; preds = %[[SIMD_IF_THEN]] | ||
// This version contains !llvm.access.group metadata for SIMD | ||
// CHECK: store i32 %[[VAL_30]], ptr %{{.*}}, align 4, !llvm.access.group !1 | ||
// CHECK: %[[VAL_34:.*]] = load i32, ptr %{{.*}}, align 4, !llvm.access.group !1 | ||
// CHECK: %[[VAL_35:.*]] = sext i32 %[[VAL_34]] to i64 | ||
// CHECK: %[[VAL_36:.*]] = getelementptr i32, ptr %[[VAL_37:.*]], i64 %[[VAL_35]] | ||
// CHECK: store i32 %[[VAL_34]], ptr %[[VAL_36]], align 4, !llvm.access.group !1 | ||
// CHECK: br label %[[OMP_REGION_CONT3:.*]] | ||
// CHECK: omp.region.cont3: ; preds = %[[VAL_33]] | ||
// CHECK: br label %[[SIMD_PRE_LATCH:.*]] | ||
|
||
// CHECK: simd.pre_latch: ; preds = %[[OMP_REGION_CONT3]], %[[OMP_REGION_CONT35:.*]] | ||
// CHECK: br label %[[OMP_LOOP_INC]] | ||
// CHECK: omp_loop.inc: ; preds = %[[SIMD_PRE_LATCH]] | ||
// CHECK: %[[NEW_IV]] = add nuw i32 %[[IV]], 1 | ||
// CHECK: br label %[[OMP_LOOP_HEADER]], !llvm.loop !2 | ||
|
||
// CHECK: simd.if.else: ; preds = %[[OMP_LOOP_BODY]] | ||
// CHECK: br label %[[SIMD_IF_ELSE2:.*]] | ||
// CHECK: simd.if.else5: | ||
// CHECK: %[[MUL:.*]] = mul i32 %[[VAL_28]], 1 | ||
// CHECK: %[[ADD:.*]] = add i32 %[[MUL]], 1 | ||
// CHECK: br label %[[LOOP_NEST_REGION:.*]] | ||
// CHECK: omp.loop_nest.region6: ; preds = %[[SIMD_IF_ELSE2]] | ||
// No llvm.access.group metadata for else clause | ||
// CHECK: store i32 %[[ADD]], ptr %{{.*}}, align 4 | ||
// CHECK: %[[VAL_42:.*]] = load i32, ptr %{{.*}}, align 4 | ||
// CHECK: %[[VAL_43:.*]] = sext i32 %[[VAL_42]] to i64 | ||
// CHECK: %[[VAL_44:.*]] = getelementptr i32, ptr %[[VAL_37]], i64 %[[VAL_43]] | ||
// CHECK: store i32 %[[VAL_42]], ptr %[[VAL_44]], align 4 | ||
// CHECK: br label %[[OMP_REGION_CONT35]] | ||
// CHECK: omp.region.cont37: ; preds = %[[LOOP_NEST_REGION]] | ||
// CHECK: br label %[[SIMD_PRE_LATCH]] | ||
|
||
// CHECK: omp_loop.exit: ; preds = %[[OMP_LOOP_COND]] | ||
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_15]]) | ||
// CHECK: %[[VAL_45:.*]] = call i32 @__kmpc_global_thread_num(ptr @1) | ||
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_45]]) | ||
|
||
// CHECK: !1 = distinct !{} | ||
// CHECK: !2 = distinct !{!2, !3} | ||
// CHECK: !3 = !{!"llvm.loop.parallel_accesses", !1} | ||
// CHECK-NOT: llvm.loop.vectorize |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: `llvm::` in `llvm::Loop` is probably redundant here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It doesn't like me using just `Loop` because of the first parameter, which is called `Loop`.