Skip to content

Commit 96bc07d

Browse files
authored
[MLIR][OpenMP] Add canonical loop LLVM-IR lowering (#147069)
Support for translating the operations introduced in #144785 to LLVM-IR. In order to keep the lowering simple, `OpenMPIRBuilder::unrollLoopHeuristic` is applied when encountering the `omp.unroll_heuristic` op. As a result, the operation that unrolling is applied to (`omp.canonical_loop`) must have been emitted before even though logically there is no such requirement. Eventually, all transformations on a loop must be applied directly after emitting `omp.canonical_loop`, i.e. future transformations must be looked up when encountering `omp.canonical_loop` itself. This is because many OpenMPIRBuilder methods (e.g. `createParallel`) expect all the region code to be emitted within a callback. In the case of `createParallel`, the region code is getting outlined into a new function. Therefore, making the operation order a formal requirement would not make the implementation any easier.
1 parent a61ea9f commit 96bc07d

File tree

6 files changed

+455
-0
lines changed

6 files changed

+455
-0
lines changed

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
1616

1717
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
18+
#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
1819
#include "mlir/IR/Operation.h"
1920
#include "mlir/IR/SymbolTable.h"
2021
#include "mlir/IR/Value.h"
@@ -24,6 +25,7 @@
2425
#include "mlir/Target/LLVMIR/TypeToLLVM.h"
2526

2627
#include "llvm/ADT/SetVector.h"
28+
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
2729
#include "llvm/IR/FPEnv.h"
2830

2931
namespace llvm {
@@ -108,6 +110,41 @@ class ModuleTranslation {
108110
return blockMapping.lookup(block);
109111
}
110112

113+
/// Find the LLVM-IR loop that represents an MLIR loop.
114+
llvm::CanonicalLoopInfo *lookupOMPLoop(omp::NewCliOp mlir) const {
115+
llvm::CanonicalLoopInfo *result = loopMapping.lookup(mlir);
116+
assert(result && "attempt to get non-existing loop");
117+
return result;
118+
}
119+
120+
/// Find the LLVM-IR loop that represents an MLIR loop.
121+
llvm::CanonicalLoopInfo *lookupOMPLoop(Value mlir) const {
122+
return lookupOMPLoop(mlir.getDefiningOp<omp::NewCliOp>());
123+
}
124+
125+
/// Mark an OpenMP loop as having been consumed.
126+
void invalidateOmpLoop(omp::NewCliOp mlir) { loopMapping.erase(mlir); }
127+
128+
/// Mark an OpenMP loop as having been consumed.
129+
void invalidateOmpLoop(Value mlir) {
130+
invalidateOmpLoop(mlir.getDefiningOp<omp::NewCliOp>());
131+
}
132+
133+
/// Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR
134+
/// OpenMPIRBuilder CanonicalLoopInfo
135+
void mapOmpLoop(omp::NewCliOp mlir, llvm::CanonicalLoopInfo *llvm) {
136+
assert(llvm && "argument must be non-null");
137+
llvm::CanonicalLoopInfo *&cur = loopMapping[mlir];
138+
assert(cur == nullptr && "attempting to map a loop that is already mapped");
139+
cur = llvm;
140+
}
141+
142+
/// Map an MLIR OpenMP dialect CanonicalLoopInfo to its lowered LLVM-IR
143+
/// OpenMPIRBuilder CanonicalLoopInfo
144+
void mapOmpLoop(Value mlir, llvm::CanonicalLoopInfo *llvm) {
145+
mapOmpLoop(mlir.getDefiningOp<omp::NewCliOp>(), llvm);
146+
}
147+
111148
/// Stores the mapping between an MLIR operation with successors and a
112149
/// corresponding LLVM IR instruction.
113150
void mapBranch(Operation *mlir, llvm::Instruction *llvm) {
@@ -381,6 +418,12 @@ class ModuleTranslation {
381418
DenseMap<Value, llvm::Value *> valueMapping;
382419
DenseMap<Block *, llvm::BasicBlock *> blockMapping;
383420

421+
/// List of not yet consumed MLIR loop handles (represented by an omp.new_cli
422+
/// operation which creates a value of type CanonicalLoopInfoType) and their
423+
/// LLVM-IR representation as CanonicalLoopInfo which is managed by the
424+
/// OpenMPIRBuilder.
425+
DenseMap<omp::NewCliOp, llvm::CanonicalLoopInfo *> loopMapping;
426+
384427
/// A mapping between MLIR LLVM dialect terminators and LLVM IR terminators
385428
/// they are converted to. This allows for connecting PHI nodes to the source
386429
/// values after all operations are converted.

mlir/lib/Conversion/OpenMPToLLVM/OpenMPToLLVM.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,16 @@ template <typename T>
4141
struct OpenMPOpConversion : public ConvertOpToLLVMPattern<T> {
4242
using ConvertOpToLLVMPattern<T>::ConvertOpToLLVMPattern;
4343

44+
/// Construct the pattern and register an identity conversion for
/// omp::CanonicalLoopInfoType on \p typeConverter.
///
/// \param typeConverter Converter passed on to the base pattern; it also
///                      receives the identity conversion described below.
/// \param benefit       Pattern benefit passed on to the base pattern.
///
/// NOTE(review): the conversion is added every time a pattern is
/// constructed, so the same identity conversion is registered once per
/// pattern instance; presumably harmless, confirm.
OpenMPOpConversion(LLVMTypeConverter &typeConverter,
                   PatternBenefit benefit = 1)
    : ConvertOpToLLVMPattern<T>(typeConverter, benefit) {
  // Operations using CanonicalLoopInfoType are lowered only by
  // mlir::translateModuleToLLVMIR() using the OpenMPIRBuilder. Until then,
  // the type and operations using it must be preserved.
  //
  // The callback is handed over to the type converter, so it captures
  // nothing: a by-reference capture default would make it easy to
  // accidentally reference constructor-local state from an escaping callback.
  typeConverter.addConversion(
      [](::mlir::omp::CanonicalLoopInfoType type) { return type; });
}
53+
4454
LogicalResult
4555
matchAndRewrite(T op, typename T::Adaptor adaptor,
4656
ConversionPatternRewriter &rewriter) const override {

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3095,6 +3095,67 @@ convertOmpLoopNest(Operation &opInst, llvm::IRBuilderBase &builder,
30953095
return success();
30963096
}
30973097

3098+
/// Convert an omp.canonical_loop to LLVM-IR
3099+
static LogicalResult
3100+
convertOmpCanonicalLoopOp(omp::CanonicalLoopOp op, llvm::IRBuilderBase &builder,
3101+
LLVM::ModuleTranslation &moduleTranslation) {
3102+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3103+
3104+
llvm::OpenMPIRBuilder::LocationDescription loopLoc(builder);
3105+
Value loopIV = op.getInductionVar();
3106+
Value loopTC = op.getTripCount();
3107+
3108+
llvm::Value *llvmTC = moduleTranslation.lookupValue(loopTC);
3109+
3110+
llvm::Expected<llvm::CanonicalLoopInfo *> llvmOrError =
3111+
ompBuilder->createCanonicalLoop(
3112+
loopLoc,
3113+
[&](llvm::OpenMPIRBuilder::InsertPointTy ip, llvm::Value *llvmIV) {
3114+
// Register the mapping of MLIR induction variable to LLVM-IR
3115+
// induction variable
3116+
moduleTranslation.mapValue(loopIV, llvmIV);
3117+
3118+
builder.restoreIP(ip);
3119+
llvm::Expected<llvm::BasicBlock *> bodyGenStatus =
3120+
convertOmpOpRegions(op.getRegion(), "omp.loop.region", builder,
3121+
moduleTranslation);
3122+
3123+
return bodyGenStatus.takeError();
3124+
},
3125+
llvmTC, "omp.loop");
3126+
if (!llvmOrError)
3127+
return op.emitError(llvm::toString(llvmOrError.takeError()));
3128+
3129+
llvm::CanonicalLoopInfo *llvmCLI = *llvmOrError;
3130+
llvm::IRBuilderBase::InsertPoint afterIP = llvmCLI->getAfterIP();
3131+
builder.restoreIP(afterIP);
3132+
3133+
// Register the mapping of MLIR loop to LLVM-IR OpenMPIRBuilder loop
3134+
if (Value cli = op.getCli())
3135+
moduleTranslation.mapOmpLoop(cli, llvmCLI);
3136+
3137+
return success();
3138+
}
3139+
3140+
/// Apply a `#pragma omp unroll` / "!$omp unroll" transformation using the
3141+
/// OpenMPIRBuilder.
3142+
static LogicalResult
3143+
applyUnrollHeuristic(omp::UnrollHeuristicOp op, llvm::IRBuilderBase &builder,
3144+
LLVM::ModuleTranslation &moduleTranslation) {
3145+
llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
3146+
3147+
Value applyee = op.getApplyee();
3148+
assert(applyee && "Loop to apply unrolling on required");
3149+
3150+
llvm::CanonicalLoopInfo *consBuilderCLI =
3151+
moduleTranslation.lookupOMPLoop(applyee);
3152+
llvm::OpenMPIRBuilder::LocationDescription loc(builder);
3153+
ompBuilder->unrollLoopHeuristic(loc.DL, consBuilderCLI);
3154+
3155+
moduleTranslation.invalidateOmpLoop(applyee);
3156+
return success();
3157+
}
3158+
30983159
/// Convert an Atomic Ordering attribute to llvm::AtomicOrdering.
30993160
static llvm::AtomicOrdering
31003161
convertAtomicOrdering(std::optional<omp::ClauseMemoryOrderKind> ao) {
@@ -5989,6 +6050,23 @@ convertHostOrTargetOperation(Operation *op, llvm::IRBuilderBase &builder,
59896050
// etc. and then discarded
59906051
return success();
59916052
})
6053+
.Case([&](omp::NewCliOp op) {
6054+
// Meta-operation: Doesn't do anything by itself, but used to
6055+
// identify a loop.
6056+
return success();
6057+
})
6058+
.Case([&](omp::CanonicalLoopOp op) {
6059+
return convertOmpCanonicalLoopOp(op, builder, moduleTranslation);
6060+
})
6061+
.Case([&](omp::UnrollHeuristicOp op) {
6062+
// FIXME: Handling omp.unroll_heuristic as an executable requires
6063+
// that the generator (e.g. omp.canonical_loop) has been seen first.
6064+
// For constructs that require all codegen to occur inside a callback
6065+
// (e.g. OpenMPIRBuilder::createParallel), all codegen of that
6066+
// contained region including their transformations must occur at
6067+
// the omp.canonical_loop.
6068+
return applyUnrollHeuristic(op, builder, moduleTranslation);
6069+
})
59926070
.Default([&](Operation *inst) {
59936071
return inst->emitError()
59946072
<< "not yet implemented: " << inst->getName();
Lines changed: 175 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,175 @@
1+
// Test lowering of standalone omp.canonical_loop
2+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
3+
4+
// CHECK-LABEL: define void @anon_loop(
5+
// CHECK-SAME: ptr %[[ptr:.+]],
6+
// CHECK-SAME: i32 %[[tc:.+]]) {
7+
// CHECK-NEXT: br label %omp_omp.loop.preheader
8+
// CHECK-EMPTY:
9+
// CHECK-NEXT: omp_omp.loop.preheader:
10+
// CHECK-NEXT: br label %omp_omp.loop.header
11+
// CHECK-EMPTY:
12+
// CHECK-NEXT: omp_omp.loop.header:
13+
// CHECK-NEXT: %omp_omp.loop.iv = phi i32 [ 0, %omp_omp.loop.preheader ], [ %omp_omp.loop.next, %omp_omp.loop.inc ]
14+
// CHECK-NEXT: br label %omp_omp.loop.cond
15+
// CHECK-EMPTY:
16+
// CHECK-NEXT: omp_omp.loop.cond:
17+
// CHECK-NEXT: %omp_omp.loop.cmp = icmp ult i32 %omp_omp.loop.iv, %[[tc]]
18+
// CHECK-NEXT: br i1 %omp_omp.loop.cmp, label %omp_omp.loop.body, label %omp_omp.loop.exit
19+
// CHECK-EMPTY:
20+
// CHECK-NEXT: omp_omp.loop.body:
21+
// CHECK-NEXT: br label %omp.loop.region
22+
// CHECK-EMPTY:
23+
// CHECK-NEXT: omp.loop.region:
24+
// CHECK-NEXT: store float 4.200000e+01, ptr %[[ptr]], align 4
25+
// CHECK-NEXT: br label %omp.region.cont
26+
// CHECK-EMPTY:
27+
// CHECK-NEXT: omp.region.cont:
28+
// CHECK-NEXT: br label %omp_omp.loop.inc
29+
// CHECK-EMPTY:
30+
// CHECK-NEXT: omp_omp.loop.inc:
31+
// CHECK-NEXT: %omp_omp.loop.next = add nuw i32 %omp_omp.loop.iv, 1
32+
// CHECK-NEXT: br label %omp_omp.loop.header
33+
// CHECK-EMPTY:
34+
// CHECK-NEXT: omp_omp.loop.exit:
35+
// CHECK-NEXT: br label %omp_omp.loop.after
36+
// CHECK-EMPTY:
37+
// CHECK-NEXT: omp_omp.loop.after:
38+
// CHECK-NEXT: ret void
39+
// CHECK-NEXT: }
40+
// Canonical loop without a loop handle: no omp.new_cli is created and
// omp.canonical_loop takes no CLI operand. The body stores 42.0 to %ptr.
llvm.func @anon_loop(%ptr: !llvm.ptr, %tc : i32) -> () {
41+
omp.canonical_loop %iv : i32 in range(%tc) {
42+
%val = llvm.mlir.constant(42.0 : f32) : f32
43+
llvm.store %val, %ptr : f32, !llvm.ptr
44+
omp.terminator
45+
}
46+
llvm.return
47+
}
48+
49+
50+
51+
// CHECK-LABEL: define void @trivial_loop(
52+
// CHECK-SAME: ptr %[[ptr:.+]],
53+
// CHECK-SAME: i32 %[[tc:.+]]) {
54+
// CHECK-NEXT: br label %omp_omp.loop.preheader
55+
// CHECK-EMPTY:
56+
// CHECK-NEXT: omp_omp.loop.preheader:
57+
// CHECK-NEXT: br label %omp_omp.loop.header
58+
// CHECK-EMPTY:
59+
// CHECK-NEXT: omp_omp.loop.header:
60+
// CHECK-NEXT: %omp_omp.loop.iv = phi i32 [ 0, %omp_omp.loop.preheader ], [ %omp_omp.loop.next, %omp_omp.loop.inc ]
61+
// CHECK-NEXT: br label %omp_omp.loop.cond
62+
// CHECK-EMPTY:
63+
// CHECK-NEXT: omp_omp.loop.cond:
64+
// CHECK-NEXT: %omp_omp.loop.cmp = icmp ult i32 %omp_omp.loop.iv, %[[tc]]
65+
// CHECK-NEXT: br i1 %omp_omp.loop.cmp, label %omp_omp.loop.body, label %omp_omp.loop.exit
66+
// CHECK-EMPTY:
67+
// CHECK-NEXT: omp_omp.loop.body:
68+
// CHECK-NEXT: br label %omp.loop.region
69+
// CHECK-EMPTY:
70+
// CHECK-NEXT: omp.loop.region:
71+
// CHECK-NEXT: store float 4.200000e+01, ptr %[[ptr]], align 4
72+
// CHECK-NEXT: br label %omp.region.cont
73+
// CHECK-EMPTY:
74+
// CHECK-NEXT: omp.region.cont:
75+
// CHECK-NEXT: br label %omp_omp.loop.inc
76+
// CHECK-EMPTY:
77+
// CHECK-NEXT: omp_omp.loop.inc:
78+
// CHECK-NEXT: %omp_omp.loop.next = add nuw i32 %omp_omp.loop.iv, 1
79+
// CHECK-NEXT: br label %omp_omp.loop.header
80+
// CHECK-EMPTY:
81+
// CHECK-NEXT: omp_omp.loop.exit:
82+
// CHECK-NEXT: br label %omp_omp.loop.after
83+
// CHECK-EMPTY:
84+
// CHECK-NEXT: omp_omp.loop.after:
85+
// CHECK-NEXT: ret void
86+
// CHECK-NEXT: }
87+
// Canonical loop with a loop handle: %cli is created by omp.new_cli and
// passed to omp.canonical_loop, but no other operation in this function uses
// it. The body stores 42.0 to %ptr.
llvm.func @trivial_loop(%ptr: !llvm.ptr, %tc : i32) -> () {
88+
%cli = omp.new_cli
89+
omp.canonical_loop(%cli) %iv : i32 in range(%tc) {
90+
%val = llvm.mlir.constant(42.0 : f32) : f32
91+
llvm.store %val, %ptr : f32, !llvm.ptr
92+
omp.terminator
93+
}
94+
llvm.return
95+
}
96+
97+
98+
// CHECK-LABEL: define void @nested_loop(
99+
// CHECK-SAME: ptr %[[ptr:.+]], i32 %[[outer_tc:.+]], i32 %[[inner_tc:.+]]) {
100+
// CHECK-NEXT: br label %omp_omp.loop.preheader
101+
// CHECK-EMPTY:
102+
// CHECK-NEXT: omp_omp.loop.preheader:
103+
// CHECK-NEXT: br label %omp_omp.loop.header
104+
// CHECK-EMPTY:
105+
// CHECK-NEXT: omp_omp.loop.header:
106+
// CHECK-NEXT: %omp_omp.loop.iv = phi i32 [ 0, %omp_omp.loop.preheader ], [ %omp_omp.loop.next, %omp_omp.loop.inc ]
107+
// CHECK-NEXT: br label %omp_omp.loop.cond
108+
// CHECK-EMPTY:
109+
// CHECK-NEXT: omp_omp.loop.cond:
110+
// CHECK-NEXT: %omp_omp.loop.cmp = icmp ult i32 %omp_omp.loop.iv, %[[outer_tc]]
111+
// CHECK-NEXT: br i1 %omp_omp.loop.cmp, label %omp_omp.loop.body, label %omp_omp.loop.exit
112+
// CHECK-EMPTY:
113+
// CHECK-NEXT: omp_omp.loop.body:
114+
// CHECK-NEXT: br label %omp.loop.region
115+
// CHECK-EMPTY:
116+
// CHECK-NEXT: omp.loop.region:
117+
// CHECK-NEXT: br label %omp_omp.loop.preheader1
118+
// CHECK-EMPTY:
119+
// CHECK-NEXT: omp_omp.loop.preheader1:
120+
// CHECK-NEXT: br label %omp_omp.loop.header2
121+
// CHECK-EMPTY:
122+
// CHECK-NEXT: omp_omp.loop.header2:
123+
// CHECK-NEXT: %omp_omp.loop.iv8 = phi i32 [ 0, %omp_omp.loop.preheader1 ], [ %omp_omp.loop.next10, %omp_omp.loop.inc5 ]
124+
// CHECK-NEXT: br label %omp_omp.loop.cond3
125+
// CHECK-EMPTY:
126+
// CHECK-NEXT: omp_omp.loop.cond3:
127+
// CHECK-NEXT: %omp_omp.loop.cmp9 = icmp ult i32 %omp_omp.loop.iv8, %[[inner_tc]]
128+
// CHECK-NEXT: br i1 %omp_omp.loop.cmp9, label %omp_omp.loop.body4, label %omp_omp.loop.exit6
129+
// CHECK-EMPTY:
130+
// CHECK-NEXT: omp_omp.loop.body4:
131+
// CHECK-NEXT: br label %omp.loop.region12
132+
// CHECK-EMPTY:
133+
// CHECK-NEXT: omp.loop.region12:
134+
// CHECK-NEXT: store float 4.200000e+01, ptr %[[ptr]], align 4
135+
// CHECK-NEXT: br label %omp.region.cont11
136+
// CHECK-EMPTY:
137+
// CHECK-NEXT: omp.region.cont11:
138+
// CHECK-NEXT: br label %omp_omp.loop.inc5
139+
// CHECK-EMPTY:
140+
// CHECK-NEXT: omp_omp.loop.inc5:
141+
// CHECK-NEXT: %omp_omp.loop.next10 = add nuw i32 %omp_omp.loop.iv8, 1
142+
// CHECK-NEXT: br label %omp_omp.loop.header2
143+
// CHECK-EMPTY:
144+
// CHECK-NEXT: omp_omp.loop.exit6:
145+
// CHECK-NEXT: br label %omp_omp.loop.after7
146+
// CHECK-EMPTY:
147+
// CHECK-NEXT: omp_omp.loop.after7:
148+
// CHECK-NEXT: br label %omp.region.cont
149+
// CHECK-EMPTY:
150+
// CHECK-NEXT: omp.region.cont:
151+
// CHECK-NEXT: br label %omp_omp.loop.inc
152+
// CHECK-EMPTY:
153+
// CHECK-NEXT: omp_omp.loop.inc:
154+
// CHECK-NEXT: %omp_omp.loop.next = add nuw i32 %omp_omp.loop.iv, 1
155+
// CHECK-NEXT: br label %omp_omp.loop.header
156+
// CHECK-EMPTY:
157+
// CHECK-NEXT: omp_omp.loop.exit:
158+
// CHECK-NEXT: br label %omp_omp.loop.after
159+
// CHECK-EMPTY:
160+
// CHECK-NEXT: omp_omp.loop.after:
161+
// CHECK-NEXT: ret void
162+
// CHECK-NEXT: }
163+
// Two nested canonical loops, each with its own loop handle. Both handles are
// created by omp.new_cli before the outer loop; the inner loop sits in the
// outer loop's region and its body stores 42.0 to %ptr.
llvm.func @nested_loop(%ptr: !llvm.ptr, %outer_tc : i32, %inner_tc : i32) -> () {
164+
%outer_cli = omp.new_cli
165+
%inner_cli = omp.new_cli
166+
omp.canonical_loop(%outer_cli) %outer_iv : i32 in range(%outer_tc) {
167+
omp.canonical_loop(%inner_cli) %inner_iv : i32 in range(%inner_tc) {
168+
%val = llvm.mlir.constant(42.0 : f32) : f32
169+
llvm.store %val, %ptr : f32, !llvm.ptr
170+
omp.terminator
171+
}
172+
omp.terminator
173+
}
174+
llvm.return
175+
}

0 commit comments

Comments
 (0)