Skip to content

Commit 6ecc39f

Browse files
committed
Wrap omp.target with omp.target_data
1 parent fdc6938 commit 6ecc39f

File tree

3 files changed

+124
-52
lines changed

3 files changed

+124
-52
lines changed

flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,85 @@ struct TeamsWorkdistributeToSingle : public OpRewritePattern<omp::TeamsOp> {
346346
}
347347
};
348348

349+
static std::optional<std::tuple<Operation *, bool, bool>>
350+
getNestedOpToIsolate(omp::TargetOp targetOp) {
351+
auto *targetBlock = &targetOp.getRegion().front();
352+
for (auto &op : *targetBlock) {
353+
bool first = &op == &*targetBlock->begin();
354+
bool last = op.getNextNode() == targetBlock->getTerminator();
355+
if (first && last)
356+
return std::nullopt;
357+
358+
if (isa<omp::TeamsOp, omp::ParallelOp>(&op))
359+
return {{&op, first, last}};
360+
}
361+
return std::nullopt;
362+
}
363+
364+
struct SplitTargetResult {
365+
omp::TargetOp targetOp;
366+
omp::TargetDataOp dataOp;
367+
};
368+
369+
/// If multiple coexecutes are nested in a target regions, we will need to split
370+
/// the target region, but we want to preserve the data semantics of the
371+
/// original data region and avoid unnecessary data movement at each of the
372+
/// subkernels - we split the target region into a target_data{target}
373+
/// nest where only the outer one moves the data
374+
std::optional<SplitTargetResult> splitTargetData(omp::TargetOp targetOp,
375+
RewriterBase &rewriter) {
376+
377+
auto loc = targetOp->getLoc();
378+
if (targetOp.getMapVars().empty()) {
379+
LLVM_DEBUG(llvm::dbgs() << DEBUG_TYPE << " target region has no data maps\n");
380+
return std::nullopt;
381+
}
382+
383+
// Collect all map_entries with capture(ByRef)
384+
SmallVector<mlir::Value> byRefMapInfos;
385+
SmallVector<omp::MapInfoOp> MapInfos;
386+
for (auto opr : targetOp.getMapVars()) {
387+
auto mapInfo = cast<omp::MapInfoOp>(opr.getDefiningOp());
388+
MapInfos.push_back(mapInfo);
389+
if (mapInfo.getMapCaptureType() == omp::VariableCaptureKind::ByRef)
390+
byRefMapInfos.push_back(opr);
391+
}
392+
393+
// Create the new omp.target_data op with these collected map_entries
394+
auto targetLoc = targetOp.getLoc();
395+
rewriter.setInsertionPoint(targetOp);
396+
auto device = targetOp.getDevice();
397+
auto ifExpr = targetOp.getIfExpr();
398+
auto deviceAddrVars = targetOp.getHasDeviceAddrVars();
399+
auto devicePtrVars = targetOp.getIsDevicePtrVars();
400+
auto targetDataOp = rewriter.create<omp::TargetDataOp>(loc, device, ifExpr,
401+
mlir::ValueRange{byRefMapInfos},
402+
deviceAddrVars,
403+
devicePtrVars);
404+
405+
auto taregtDataBlock = rewriter.createBlock(&targetDataOp.getRegion());
406+
rewriter.create<mlir::omp::TerminatorOp>(loc);
407+
rewriter.setInsertionPointToStart(taregtDataBlock);
408+
409+
// Clone mapInfo ops inside omp.target_data region
410+
IRMapping mapping;
411+
for (auto mapInfo : MapInfos) {
412+
rewriter.clone(*mapInfo, mapping);
413+
}
414+
// Clone omp.target from exisiting targetOp inside target_data region.
415+
auto newTargetOp = rewriter.clone(*targetOp, mapping);
416+
417+
// Erase TargetOp and its MapInfoOps
418+
rewriter.eraseOp(targetOp);
419+
420+
for (auto mapInfo : MapInfos) {
421+
auto mapInfoRes = mapInfo.getResult();
422+
if (mapInfoRes.getUsers().empty())
423+
rewriter.eraseOp(mapInfo);
424+
}
425+
return SplitTargetResult{targetOp, targetDataOp};
426+
}
427+
349428
class LowerWorkdistributePass
350429
: public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
351430
public:
@@ -372,6 +451,15 @@ class LowerWorkdistributePass
372451
signalPassFailure();
373452
}
374453
}
454+
{
455+
SmallVector<omp::TargetOp> targetOps;
456+
op->walk([&](omp::TargetOp targetOp) { targetOps.push_back(targetOp); });
457+
IRRewriter rewriter(&context);
458+
for (auto targetOp : targetOps) {
459+
auto res = splitTargetData(targetOp, rewriter);
460+
}
461+
}
462+
375463
}
376464
};
377465
} // namespace
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @test_nested_derived_type_map_operand_and_block_addition(
4+
// CHECK-SAME: %[[ARG0:.*]]: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
5+
// CHECK: %[[VAL_0:.*]] = fir.declare %[[ARG0]] {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
6+
// CHECK: %[[VAL_1:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
7+
// CHECK: %[[VAL_2:.*]] = fir.coordinate_of %[[VAL_1]], i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
8+
// CHECK: %[[VAL_3:.*]] = omp.map.info var_ptr(%[[VAL_2]] : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%[[VAL_4:.*]]%[[VAL_5:.*]]"}
9+
// CHECK: %[[VAL_6:.*]] = fir.coordinate_of %[[VAL_0]], n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
10+
// CHECK: %[[VAL_7:.*]] = fir.coordinate_of %[[VAL_6]], r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
11+
// CHECK: %[[VAL_8:.*]] = omp.map.info var_ptr(%[[VAL_7]] : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%[[VAL_4]]%[[VAL_9:.*]]"}
12+
// CHECK: %[[VAL_10:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
13+
// CHECK: omp.target_data map_entries(%[[VAL_10]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
14+
// CHECK: %[[VAL_11:.*]] = omp.map.info var_ptr(%[[VAL_0]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%[[VAL_3]], %[[VAL_8]] : [1, 0], [1, 1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
15+
// CHECK: omp.target map_entries(%[[VAL_11]] -> %[[VAL_12:.*]] : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
16+
// CHECK: omp.terminator
17+
// CHECK: }
18+
// CHECK: omp.terminator
19+
// CHECK: }
20+
// CHECK: return
21+
// CHECK: }
22+
23+
// Test input: a derived-type argument is mapped through a parent omp.map.info
// (partial_map = true) whose two members are the nested components "sa%n%i"
// and "sa%n%r"; the parent is the only map entry of an omp.target whose body
// holds nothing but the terminator.
func.func @test_nested_derived_type_map_operand_and_block_addition(%arg0: !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
24+
%0 = fir.declare %arg0 {uniq_name = "_QFmaptype_derived_nested_explicit_multiple_membersEsa"} : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>
25+
%2 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
26+
%4 = fir.coordinate_of %2, i : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<i32>
27+
%5 = omp.map.info var_ptr(%4 : !fir.ref<i32>, i32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<i32> {name = "sa%n%i"}
28+
%7 = fir.coordinate_of %0, n : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>
29+
%9 = fir.coordinate_of %7, r : (!fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>>) -> !fir.ref<f32>
30+
%10 = omp.map.info var_ptr(%9 : !fir.ref<f32>, f32) map_clauses(tofrom) capture(ByRef) -> !fir.ref<f32> {name = "sa%n%r"}
31+
%11 = omp.map.info var_ptr(%0 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>, !fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>) map_clauses(tofrom) capture(ByRef) members(%5, %10 : [1,0], [1,1] : !fir.ref<i32>, !fir.ref<f32>) -> !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>> {name = "sa", partial_map = true}
32+
omp.target map_entries(%11 -> %arg1 : !fir.ref<!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTscalar_and_array{r:f32,n:!fir.type<_QFmaptype_derived_nested_explicit_multiple_membersTnested{i:i32,r:f32}>}>>) {
33+
omp.terminator
34+
}
35+
return
36+
}

flang/test/Transforms/OpenMP/lower-workdistribute-to-single.mlir

Lines changed: 0 additions & 52 deletions
This file was deleted.

0 commit comments

Comments
 (0)