Skip to content

[mlir][OpenMP] Allow composite SIMD REDUCTION and IF #147568

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ class CanonicalLoopInfo;
struct TargetRegionEntryInfo;
class OffloadEntriesInfoManager;
class OpenMPIRBuilder;
class Loop;
class LoopAnalysis;
class LoopInfo;

/// Move the instruction after an InsertPoint to the beginning of another
/// BasicBlock.
Expand Down Expand Up @@ -1114,6 +1117,7 @@ class OpenMPIRBuilder {
/// \param NamePrefix Optional name prefix for if.then if.else blocks.
void createIfVersion(CanonicalLoopInfo *Loop, Value *IfCond,
ValueMap<const Value *, WeakTrackingVH> &VMap,
LoopAnalysis &LIA, LoopInfo &LI, llvm::Loop *L,
const Twine &NamePrefix = "");

public:
Expand Down
105 changes: 66 additions & 39 deletions llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5370,58 +5370,90 @@ void OpenMPIRBuilder::unrollLoopHeuristic(DebugLoc, CanonicalLoopInfo *Loop) {

void OpenMPIRBuilder::createIfVersion(CanonicalLoopInfo *CanonicalLoop,
Value *IfCond, ValueToValueMapTy &VMap,
LoopAnalysis &LIA, LoopInfo &LI, Loop *L,
const Twine &NamePrefix) {
Function *F = CanonicalLoop->getFunction();

// Define where if branch should be inserted
Instruction *SplitBefore = CanonicalLoop->getPreheader()->getTerminator();

// TODO: We should not rely on pass manager. Currently we use pass manager
// only for getting llvm::Loop which corresponds to given CanonicalLoopInfo
// object. We should have a method which returns all blocks between
// CanonicalLoopInfo::getHeader() and CanonicalLoopInfo::getAfter()
FunctionAnalysisManager FAM;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to clarify: these passes are removed from createIfVersion because there is already a LoopAnalysis run in the calling function.

Further I wanted to use the same Loop instance in both so that I could add the new block to it.

FAM.registerPass([]() { return DominatorTreeAnalysis(); });
FAM.registerPass([]() { return LoopAnalysis(); });
FAM.registerPass([]() { return PassInstrumentationAnalysis(); });
// We can't do
// if (cond) {
// simd_loop;
// } else {
// non_simd_loop;
// }
// because then the CanonicalLoopInfo would only point to one of the loops:
// leading to other constructs operating on the same loop to malfunction.
// Instead generate
// while (...) {
// if (cond) {
// simd_body;
// } else {
// not_simd_body;
// }
// }
// At least for simple loops, LLVM seems able to hoist the if out of the loop
// body at -O3

// Get the loop which needs to be cloned
LoopAnalysis LIA;
LoopInfo &&LI = LIA.run(*F, FAM);
Loop *L = LI.getLoopFor(CanonicalLoop->getHeader());
// Define where if branch should be inserted
auto SplitBeforeIt = CanonicalLoop->getBody()->getFirstNonPHIIt();

// Create additional blocks for the if statement
BasicBlock *Head = SplitBefore->getParent();
Instruction *HeadOldTerm = Head->getTerminator();
llvm::LLVMContext &C = Head->getContext();
BasicBlock *Cond = SplitBeforeIt->getParent();
llvm::LLVMContext &C = Cond->getContext();
llvm::BasicBlock *ThenBlock = llvm::BasicBlock::Create(
C, NamePrefix + ".if.then", Head->getParent(), Head->getNextNode());
C, NamePrefix + ".if.then", Cond->getParent(), Cond->getNextNode());
llvm::BasicBlock *ElseBlock = llvm::BasicBlock::Create(
C, NamePrefix + ".if.else", Head->getParent(), CanonicalLoop->getExit());
C, NamePrefix + ".if.else", Cond->getParent(), CanonicalLoop->getExit());

// Create if condition branch.
Builder.SetInsertPoint(HeadOldTerm);
Builder.SetInsertPoint(SplitBeforeIt);
Instruction *BrInstr =
Builder.CreateCondBr(IfCond, ThenBlock, /*ifFalse*/ ElseBlock);
InsertPointTy IP{BrInstr->getParent(), ++BrInstr->getIterator()};
// Then block contains branch to omp loop which needs to be vectorized
// Then block contains branch to omp loop body which needs to be vectorized
spliceBB(IP, ThenBlock, false, Builder.getCurrentDebugLocation());
ThenBlock->replaceSuccessorsPhiUsesWith(Head, ThenBlock);
ThenBlock->replaceSuccessorsPhiUsesWith(Cond, ThenBlock);

Builder.SetInsertPoint(ElseBlock);

// Clone loop for the else branch
SmallVector<BasicBlock *, 8> NewBlocks;

VMap[CanonicalLoop->getPreheader()] = ElseBlock;
for (BasicBlock *Block : L->getBlocks()) {
SmallVector<BasicBlock *, 8> ExistingBlocks;
ExistingBlocks.reserve(L->getNumBlocks() + 1);
ExistingBlocks.push_back(ThenBlock);
ExistingBlocks.append(L->block_begin(), L->block_end());
// Cond is the block that has the if clause condition
// LoopCond is omp_loop.cond
// LoopHeader is omp_loop.header
BasicBlock *LoopCond = Cond->getUniquePredecessor();
BasicBlock *LoopHeader = LoopCond->getUniquePredecessor();
assert(LoopCond && LoopHeader && "Invalid loop structure");
for (BasicBlock *Block : ExistingBlocks) {
if (Block == L->getLoopPreheader() || Block == L->getLoopLatch() ||
Block == LoopHeader || Block == LoopCond || Block == Cond) {
continue;
}
BasicBlock *NewBB = CloneBasicBlock(Block, VMap, "", F);

// fix name not to be omp.if.then
if (Block == ThenBlock)
NewBB->setName(NamePrefix + ".if.else");

NewBB->moveBefore(CanonicalLoop->getExit());
VMap[Block] = NewBB;
NewBlocks.push_back(NewBB);
}
remapInstructionsInBlocks(NewBlocks, VMap);
Builder.CreateBr(NewBlocks.front());

// The loop latch must have only one predecessor. Currently it is branched to
// from both the 'then' and 'else' branches.
L->getLoopLatch()->splitBasicBlock(
L->getLoopLatch()->begin(), NamePrefix + ".pre_latch", /*Before=*/true);

// Ensure that the then block is added to the loop so we add the attributes in
// the next step
L->addBasicBlockToLoop(ThenBlock, LI);
}

unsigned
Expand Down Expand Up @@ -5477,20 +5509,7 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,

if (IfCond) {
ValueToValueMapTy VMap;
createIfVersion(CanonicalLoop, IfCond, VMap, "simd");
// Add metadata to the cloned loop which disables vectorization
Value *MappedLatch = VMap.lookup(CanonicalLoop->getLatch());
assert(MappedLatch &&
"Cannot find value which corresponds to original loop latch");
assert(isa<BasicBlock>(MappedLatch) &&
"Cannot cast mapped latch block value to BasicBlock");
BasicBlock *NewLatchBlock = dyn_cast<BasicBlock>(MappedLatch);
ConstantAsMetadata *BoolConst =
ConstantAsMetadata::get(ConstantInt::getFalse(Type::getInt1Ty(Ctx)));
addBasicBlockMetadata(
NewLatchBlock,
{MDNode::get(Ctx, {MDString::get(Ctx, "llvm.loop.vectorize.enable"),
BoolConst})});
createIfVersion(CanonicalLoop, IfCond, VMap, LIA, LI, L, "simd");
}

SmallSet<BasicBlock *, 8> Reachable;
Expand Down Expand Up @@ -5524,6 +5543,14 @@ void OpenMPIRBuilder::applySimd(CanonicalLoopInfo *CanonicalLoop,
Ctx, {MDString::get(Ctx, "llvm.loop.parallel_accesses"), AccessGroup}));
}

// FIXME: the IF clause shares a loop backedge for the SIMD and non-SIMD
// versions so we can't add the loop attributes in that case.
if (IfCond) {
// we can still add llvm.loop.parallel_access
addLoopMetadata(CanonicalLoop, LoopMDList);
return;
}

// Use the above access group metadata to create loop level
// metadata, which should be distinct for each loop.
ConstantAsMetadata *BoolConst =
Expand Down
33 changes: 22 additions & 11 deletions llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2242,23 +2242,34 @@ TEST_F(OpenMPIRBuilderTest, ApplySimdIf) {
PB.registerFunctionAnalyses(FAM);
LoopInfo &LI = FAM.getResult<LoopAnalysis>(*F);

// Check if there are two loops (one with enabled vectorization)
// Check if there is one loop containing branches with and without
// vectorization
const std::vector<Loop *> &TopLvl = LI.getTopLevelLoops();
EXPECT_EQ(TopLvl.size(), 2u);
EXPECT_EQ(TopLvl.size(), 1u);

Loop *L = TopLvl[0];
EXPECT_TRUE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 3);

// The second loop should have disabled vectorization
L = TopLvl[1];
EXPECT_FALSE(findStringMetadataForLoop(L, "llvm.loop.parallel_accesses"));
// These attributes cannot not be set because the loop is shared between simd
// and non-simd versions
EXPECT_FALSE(getBooleanLoopAttribute(L, "llvm.loop.vectorize.enable"));
// Check for llvm.access.group metadata attached to the printf
// function in the loop body.
EXPECT_EQ(getIntLoopAttribute(L, "llvm.loop.vectorize.width"), 0);

// Check for if condition
BasicBlock *LoopBody = CLI->getBody();
EXPECT_TRUE(any_of(*LoopBody, [](Instruction &I) {
BranchInst *IfCond = cast<BranchInst>(LoopBody->getTerminator());
EXPECT_EQ(IfCond->getCondition(), IfCmp);
BasicBlock *TrueBranch = IfCond->getSuccessor(0);
BasicBlock *FalseBranch = IfCond->getSuccessor(1)->getUniqueSuccessor();

// Check for llvm.access.group metadata attached to the printf
// function in the true body.
EXPECT_TRUE(any_of(*TrueBranch, [](Instruction &I) {
return I.getMetadata("llvm.access.group") != nullptr;
}));

// Check for llvm.access.group metadata attached to the printf
// function in the false body.
EXPECT_FALSE(any_of(*FalseBranch, [](Instruction &I) {
return I.getMetadata("llvm.access.group") != nullptr;
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -702,30 +702,6 @@ static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
}

/// Helper function to map block arguments defined by ignored loop wrappers to
/// LLVM values and prevent any uses of those from triggering null pointer
/// dereferences.
///
/// This must be called after block arguments of parent wrappers have already
/// been mapped to LLVM IR values.
static LogicalResult
convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
LLVM::ModuleTranslation &moduleTranslation) {
// Map block arguments directly to the LLVM value associated to the
// corresponding operand. This is semantically equivalent to this wrapper not
// being present.
return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
.Case([&](omp::SimdOp op) {
forwardArgs(moduleTranslation,
cast<omp::BlockArgOpenMPOpInterface>(*op));
op.emitWarning() << "simd information on composite construct discarded";
return success();
})
.Default([&](Operation *op) {
return op->emitError() << "cannot ignore wrapper";
});
}

/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
Expand Down Expand Up @@ -2852,17 +2828,6 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto simdOp = cast<omp::SimdOp>(opInst);

// Ignore simd in composite constructs with unsupported clauses
// TODO: Replace this once simd + clause combinations are properly supported
if (simdOp.isComposite() &&
(simdOp.getReductionByref().has_value() || simdOp.getIfExpr())) {
if (failed(convertIgnoredWrapper(simdOp, moduleTranslation)))
return failure();

return inlineConvertOmpRegions(simdOp.getRegion(), "omp.simd.region",
builder, moduleTranslation);
}

if (failed(checkImplementationStatus(opInst)))
return failure();

Expand Down
97 changes: 97 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-composite-simd-if.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

llvm.func @_QPfoo(%arg0: !llvm.ptr {fir.bindc_name = "array", llvm.nocapture}, %arg1: !llvm.ptr {fir.bindc_name = "t", llvm.nocapture}) {
%0 = llvm.mlir.constant(0 : i64) : i32
%1 = llvm.mlir.constant(1 : i32) : i32
%2 = llvm.mlir.constant(10 : i64) : i64
%3 = llvm.mlir.constant(1 : i64) : i64
%4 = llvm.alloca %3 x i32 {bindc_name = "i", pinned} : (i64) -> !llvm.ptr
%5 = llvm.load %arg1 : !llvm.ptr -> i32
%6 = llvm.icmp "ne" %5, %0 : i32
%7 = llvm.trunc %2 : i64 to i32
omp.wsloop {
omp.simd if(%6) {
omp.loop_nest (%arg2) : i32 = (%1) to (%7) inclusive step (%1) {
llvm.store %arg2, %4 : i32, !llvm.ptr
%8 = llvm.load %4 : !llvm.ptr -> i32
%9 = llvm.sext %8 : i32 to i64
%10 = llvm.getelementptr %arg0[%9] : (!llvm.ptr, i64) -> !llvm.ptr, i32
llvm.store %8, %10 : i32, !llvm.ptr
omp.yield
}
} {omp.composite}
} {omp.composite}
llvm.return
}

// CHECK-LABEL: @_QPfoo
// ...
// CHECK: omp_loop.preheader: ; preds =
// CHECK: store i32 0, ptr %[[LB_ADDR:.*]], align 4
// CHECK: store i32 9, ptr %[[UB_ADDR:.*]], align 4
// CHECK: store i32 1, ptr %[[STEP_ADDR:.*]], align 4
// CHECK: %[[VAL_15:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_for_static_init_4u(ptr @1, i32 %[[VAL_15]], i32 34, ptr %{{.*}}, ptr %[[LB_ADDR]], ptr %[[UB_ADDR]], ptr %[[STEP_ADDR]], i32 1, i32 0)
// CHECK: %[[LB:.*]] = load i32, ptr %[[LB_ADDR]], align 4
// CHECK: %[[UB:.*]] = load i32, ptr %[[UB_ADDR]], align 4
// CHECK: %[[VAL_18:.*]] = sub i32 %[[UB]], %[[LB]]
// CHECK: %[[COUNT:.*]] = add i32 %[[VAL_18]], 1
// CHECK: br label %[[OMP_LOOP_HEADER:.*]]
// CHECK: omp_loop.header: ; preds = %[[OMP_LOOP_INC:.*]], %[[OMP_LOOP_PREHEADER:.*]]
// CHECK: %[[IV:.*]] = phi i32 [ 0, %[[OMP_LOOP_PREHEADER]] ], [ %[[NEW_IV:.*]], %[[OMP_LOOP_INC]] ]
// CHECK: br label %[[OMP_LOOP_COND:.*]]
// CHECK: omp_loop.cond: ; preds = %[[OMP_LOOP_HEADER]]
// CHECK: %[[VAL_25:.*]] = icmp ult i32 %[[IV]], %[[COUNT]]
// CHECK: br i1 %[[VAL_25]], label %[[OMP_LOOP_BODY:.*]], label %[[OMP_LOOP_EXIT:.*]]
// CHECK: omp_loop.body: ; preds = %[[OMP_LOOP_COND]]
// CHECK: %[[VAL_28:.*]] = add i32 %[[IV]], %[[LB]]
// This is the IF clause:
// CHECK: br i1 %{{.*}}, label %[[SIMD_IF_THEN:.*]], label %[[SIMD_IF_ELSE:.*]]

// CHECK: simd.if.then: ; preds = %[[OMP_LOOP_BODY]]
// CHECK: %[[VAL_29:.*]] = mul i32 %[[VAL_28]], 1
// CHECK: %[[VAL_30:.*]] = add i32 %[[VAL_29]], 1
// CHECK: br label %[[VAL_33:.*]]
// CHECK: omp.loop_nest.region: ; preds = %[[SIMD_IF_THEN]]
// This version contains !llvm.access.group metadata for SIMD
// CHECK: store i32 %[[VAL_30]], ptr %{{.*}}, align 4, !llvm.access.group !1
// CHECK: %[[VAL_34:.*]] = load i32, ptr %{{.*}}, align 4, !llvm.access.group !1
// CHECK: %[[VAL_35:.*]] = sext i32 %[[VAL_34]] to i64
// CHECK: %[[VAL_36:.*]] = getelementptr i32, ptr %[[VAL_37:.*]], i64 %[[VAL_35]]
// CHECK: store i32 %[[VAL_34]], ptr %[[VAL_36]], align 4, !llvm.access.group !1
// CHECK: br label %[[OMP_REGION_CONT3:.*]]
// CHECK: omp.region.cont3: ; preds = %[[VAL_33]]
// CHECK: br label %[[SIMD_PRE_LATCH:.*]]

// CHECK: simd.pre_latch: ; preds = %[[OMP_REGION_CONT3]], %[[OMP_REGION_CONT35:.*]]
// CHECK: br label %[[OMP_LOOP_INC]]
// CHECK: omp_loop.inc: ; preds = %[[SIMD_PRE_LATCH]]
// CHECK: %[[NEW_IV]] = add nuw i32 %[[IV]], 1
// CHECK: br label %[[OMP_LOOP_HEADER]], !llvm.loop !2

// CHECK: simd.if.else: ; preds = %[[OMP_LOOP_BODY]]
// CHECK: br label %[[SIMD_IF_ELSE2:.*]]
// CHECK: simd.if.else5:
// CHECK: %[[MUL:.*]] = mul i32 %[[VAL_28]], 1
// CHECK: %[[ADD:.*]] = add i32 %[[MUL]], 1
// CHECK: br label %[[LOOP_NEST_REGION:.*]]
// CHECK: omp.loop_nest.region6: ; preds = %[[SIMD_IF_ELSE2]]
// No llvm.access.group metadata for else clause
// CHECK: store i32 %[[ADD]], ptr %{{.*}}, align 4
// CHECK: %[[VAL_42:.*]] = load i32, ptr %{{.*}}, align 4
// CHECK: %[[VAL_43:.*]] = sext i32 %[[VAL_42]] to i64
// CHECK: %[[VAL_44:.*]] = getelementptr i32, ptr %[[VAL_37]], i64 %[[VAL_43]]
// CHECK: store i32 %[[VAL_42]], ptr %[[VAL_44]], align 4
// CHECK: br label %[[OMP_REGION_CONT35]]
// CHECK: omp.region.cont37: ; preds = %[[LOOP_NEST_REGION]]
// CHECK: br label %[[SIMD_PRE_LATCH]]

// CHECK: omp_loop.exit: ; preds = %[[OMP_LOOP_COND]]
// CHECK: call void @__kmpc_for_static_fini(ptr @1, i32 %[[VAL_15]])
// CHECK: %[[VAL_45:.*]] = call i32 @__kmpc_global_thread_num(ptr @1)
// CHECK: call void @__kmpc_barrier(ptr @2, i32 %[[VAL_45]])

// CHECK: !1 = distinct !{}
// CHECK: !2 = distinct !{!2, !3}
// CHECK: !3 = !{!"llvm.loop.parallel_accesses", !1}
// CHECK-NOT: llvm.loop.vectorize
2 changes: 0 additions & 2 deletions mlir/test/Target/LLVMIR/openmp-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -820,8 +820,6 @@ llvm.func @simd_if(%arg0: !llvm.ptr {fir.bindc_name = "n"}, %arg1: !llvm.ptr {fi
}
// Be sure that llvm.loop.vectorize.enable metadata appears twice
// CHECK: llvm.loop.parallel_accesses
// CHECK-NEXT: llvm.loop.vectorize.enable
// CHECK: llvm.loop.vectorize.enable

// -----

Expand Down
9 changes: 6 additions & 3 deletions mlir/test/Target/LLVMIR/openmp-reduction.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -637,9 +637,12 @@ llvm.func @wsloop_simd_reduction(%lb : i64, %ub : i64, %step : i64) {
// Outlined function.
// CHECK: define internal void @[[OUTLINED]]

// Private reduction variable and its initialization.
// reduction variable in wsloop
// CHECK: %[[PRIVATE:.+]] = alloca float
// reduction variable in simd
// CHECK: %[[PRIVATE2:.+]] = alloca float
// CHECK: store float 0.000000e+00, ptr %[[PRIVATE]]
// CHECK: store float 0.000000e+00, ptr %[[PRIVATE2]]

// Call to the reduction function.
// CHECK: call i32 @__kmpc_reduce
Expand All @@ -659,9 +662,9 @@ llvm.func @wsloop_simd_reduction(%lb : i64, %ub : i64, %step : i64) {

// Update of the private variable using the reduction region
// (the body block currently comes after all the other blocks).
// CHECK: %[[PARTIAL:.+]] = load float, ptr %[[PRIVATE]]
// CHECK: %[[PARTIAL:.+]] = load float, ptr %[[PRIVATE2]]
// CHECK: %[[UPDATED:.+]] = fadd float 2.000000e+00, %[[PARTIAL]]
// CHECK: store float %[[UPDATED]], ptr %[[PRIVATE]]
// CHECK: store float %[[UPDATED]], ptr %[[PRIVATE2]]

// Reduction function.
// CHECK: define internal void @[[REDFUNC]]
Expand Down
Loading