Skip to content

Commit 085062f

Browse files
committed
[OpenMP] workdistribute trivial lowering
Lowering logic inspired by ivanradanov's coexecute lowering (f56da1a)
1 parent 8077858 commit 085062f

File tree

4 files changed

+158
-0
lines changed

4 files changed

+158
-0
lines changed

flang/include/flang/Optimizer/OpenMP/Passes.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,10 @@ def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
9393
let summary = "Lower workshare construct";
9494
}
9595

96+
// Declares the pass registered under the command-line name
// `lower-workdistribute`; it is anchored on `::mlir::ModuleOp`.
def LowerWorkdistribute : Pass<"lower-workdistribute", "::mlir::ModuleOp"> {
97+
let summary = "Lower workdistribute construct";
98+
}
99+
96100
def GenericLoopConversionPass
97101
: Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
98102
let summary = "Converts OpenMP generic `omp.loop` to semantically "

flang/lib/Optimizer/OpenMP/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ add_flang_library(FlangOpenMPTransforms
77
MapsForPrivatizedSymbols.cpp
88
MapInfoFinalization.cpp
99
MarkDeclareTarget.cpp
10+
LowerWorkdistribute.cpp
1011
LowerWorkshare.cpp
1112
LowerNontemporal.cpp
1213

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//===- LowerWorkdistribute.cpp - Lower omp.workdistribute -----------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file implements the lowering of omp.workdistribute.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include <flang/Optimizer/Builder/FIRBuilder.h>
14+
#include <flang/Optimizer/Dialect/FIROps.h>
15+
#include <flang/Optimizer/Dialect/FIRType.h>
16+
#include <flang/Optimizer/HLFIR/HLFIROps.h>
17+
#include <flang/Optimizer/OpenMP/Passes.h>
18+
#include <llvm/ADT/BreadthFirstIterator.h>
19+
#include <llvm/ADT/STLExtras.h>
20+
#include <llvm/ADT/SmallVectorExtras.h>
21+
#include <llvm/ADT/iterator_range.h>
22+
#include <llvm/Support/ErrorHandling.h>
23+
#include <mlir/Dialect/Arith/IR/Arith.h>
24+
#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
25+
#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
26+
#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
27+
#include <mlir/Dialect/SCF/IR/SCF.h>
28+
#include <mlir/IR/BuiltinOps.h>
29+
#include <mlir/IR/IRMapping.h>
30+
#include <mlir/IR/OpDefinition.h>
31+
#include <mlir/IR/PatternMatch.h>
32+
#include <mlir/IR/Value.h>
33+
#include <mlir/IR/Visitors.h>
34+
#include <mlir/Interfaces/SideEffectInterfaces.h>
35+
#include <mlir/Support/LLVM.h>
36+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
37+
38+
#include <variant>
39+
40+
namespace flangomp {
41+
#define GEN_PASS_DEF_LOWERWORKDISTRIBUTE
42+
#include "flang/Optimizer/OpenMP/Passes.h.inc"
43+
} // namespace flangomp
44+
45+
#define DEBUG_TYPE "lower-workdistribute"
46+
47+
using namespace mlir;
48+
49+
namespace {
50+
51+
/// Trivial lowering of `omp.workdistribute`.
///
/// The pattern only handles one shape: an `omp.workdistribute` with a
/// single-block region that is, together with the terminator, the sole
/// content of a single-block `omp.teams` region. In that case the body of the
/// workdistribute is inlined immediately before the enclosing `omp.teams`
/// and the `omp.teams` operation (which still contains the now-empty
/// workdistribute) is erased.
///
/// Every other shape is diagnosed with an error attached to the
/// workdistribute location and the pattern returns failure without modifying
/// the IR.
struct WorkdistributeToSingle
    : public mlir::OpRewritePattern<mlir::omp::WorkdistributeOp> {
  using OpRewritePattern::OpRewritePattern;

  mlir::LogicalResult
  matchAndRewrite(mlir::omp::WorkdistributeOp workdistribute,
                  mlir::PatternRewriter &rewriter) const override {
    mlir::Location loc = workdistribute->getLoc();

    // `dyn_cast_if_present` also copes with a null parent, where a plain
    // `dyn_cast` would assert.
    auto teams = llvm::dyn_cast_if_present<mlir::omp::TeamsOp>(
        workdistribute->getParentOp());
    // Diagnostics carry no trailing newline: the diagnostic handler already
    // terminates each message. An in-flight diagnostic converts to failure.
    if (!teams)
      return mlir::emitError(loc, "workdistribute not nested in teams");
    if (workdistribute.getRegion().getBlocks().size() != 1)
      return mlir::emitError(loc, "workdistribute with multiple blocks");
    if (teams.getRegion().getBlocks().size() != 1)
      return mlir::emitError(loc, "teams with multiple blocks");
    // Exactly two operations are expected in the teams block: the
    // workdistribute itself and the block terminator.
    if (teams.getRegion().getBlocks().front().getOperations().size() != 2)
      return mlir::emitError(loc, "teams with multiple nested ops");

    // Drop the workdistribute terminator first so that only the payload
    // operations are moved in front of the teams op.
    mlir::Block *workdistributeBlock = &workdistribute.getRegion().front();
    rewriter.eraseOp(workdistributeBlock->getTerminator());
    rewriter.inlineBlockBefore(workdistributeBlock, teams);
    // Erasing teams also erases the (now empty) nested workdistribute.
    rewriter.eraseOp(teams);
    return mlir::success();
  }
};
81+
82+
class LowerWorkdistributePass
83+
: public flangomp::impl::LowerWorkdistributeBase<LowerWorkdistributePass> {
84+
public:
85+
void runOnOperation() override {
86+
mlir::MLIRContext &context = getContext();
87+
mlir::RewritePatternSet patterns(&context);
88+
mlir::GreedyRewriteConfig config;
89+
// prevent the pattern driver form merging blocks
90+
config.setRegionSimplificationLevel(
91+
mlir::GreedySimplifyRegionLevel::Disabled);
92+
93+
patterns.insert<WorkdistributeToSingle>(&context);
94+
mlir::Operation *op = getOperation();
95+
if (mlir::failed(mlir::applyPatternsGreedily(op, std::move(patterns), config))) {
96+
mlir::emitError(op->getLoc(), DEBUG_TYPE " pass failed\n");
97+
signalPassFailure();
98+
}
99+
}
100+
};
101+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
// RUN: fir-opt --lower-workdistribute %s | FileCheck %s
2+
3+
// CHECK-LABEL: func.func @_QPtarget_simple() {
4+
// CHECK: %[[VAL_0:.*]] = arith.constant 2 : i32
5+
// CHECK: %[[VAL_1:.*]] = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
6+
// CHECK: %[[VAL_2:.*]]:2 = hlfir.declare %[[VAL_1]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
7+
// CHECK: %[[VAL_3:.*]] = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
8+
// CHECK: %[[VAL_4:.*]] = fir.zero_bits !fir.heap<i32>
9+
// CHECK: %[[VAL_5:.*]] = fir.embox %[[VAL_4]] : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
10+
// CHECK: fir.store %[[VAL_5]] to %[[VAL_3]] : !fir.ref<!fir.box<!fir.heap<i32>>>
11+
// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_3]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
12+
// CHECK: hlfir.assign %[[VAL_0]] to %[[VAL_2]]#0 : i32, !fir.ref<i32>
13+
// CHECK: %[[VAL_7:.*]] = omp.map.info var_ptr(%[[VAL_2]]#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
14+
// CHECK: omp.target map_entries(%[[VAL_7]] -> %[[VAL_8:.*]] : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %[[VAL_6]]#0 -> %[[VAL_9:.*]] : !fir.ref<!fir.box<!fir.heap<i32>>>) {
15+
// CHECK: %[[VAL_10:.*]] = arith.constant 10 : i32
16+
// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_8]] {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
17+
// CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_9]] {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
18+
// CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_11]]#0 : !fir.ref<i32>
19+
// CHECK: %[[VAL_14:.*]] = arith.addi %[[VAL_13]], %[[VAL_10]] : i32
20+
// CHECK: hlfir.assign %[[VAL_14]] to %[[VAL_12]]#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
21+
// CHECK: omp.terminator
22+
// CHECK: }
23+
// CHECK: return
24+
// CHECK: }
25+
// Test input: an omp.target region whose only non-terminator op is an
// omp.teams, which in turn holds a single omp.workdistribute (plus the
// terminator). The expected output above has the workdistribute body placed
// directly inside the omp.target region.
func.func @_QPtarget_simple() {
26+
%0 = fir.alloca i32 {bindc_name = "a", uniq_name = "_QFtarget_simpleEa"}
27+
%1:2 = hlfir.declare %0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
28+
%2 = fir.alloca !fir.box<!fir.heap<i32>> {bindc_name = "simple_var", uniq_name = "_QFtarget_simpleEsimple_var"}
29+
%3 = fir.zero_bits !fir.heap<i32>
30+
%4 = fir.embox %3 : (!fir.heap<i32>) -> !fir.box<!fir.heap<i32>>
31+
fir.store %4 to %2 : !fir.ref<!fir.box<!fir.heap<i32>>>
32+
%5:2 = hlfir.declare %2 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
33+
%c2_i32 = arith.constant 2 : i32
34+
hlfir.assign %c2_i32 to %1#0 : i32, !fir.ref<i32>
35+
%6 = omp.map.info var_ptr(%1#1 : !fir.ref<i32>, i32) map_clauses(to) capture(ByRef) -> !fir.ref<i32> {name = "a"}
36+
omp.target map_entries(%6 -> %arg0 : !fir.ref<i32>) private(@_QFtarget_simpleEsimple_var_private_ref_box_heap_i32 %5#0 -> %arg1 : !fir.ref<!fir.box<!fir.heap<i32>>>){
37+
omp.teams {
38+
omp.workdistribute {
39+
%11:2 = hlfir.declare %arg0 {uniq_name = "_QFtarget_simpleEa"} : (!fir.ref<i32>) -> (!fir.ref<i32>, !fir.ref<i32>)
40+
%12:2 = hlfir.declare %arg1 {fortran_attrs = #fir.var_attrs<allocatable>, uniq_name = "_QFtarget_simpleEsimple_var"} : (!fir.ref<!fir.box<!fir.heap<i32>>>) -> (!fir.ref<!fir.box<!fir.heap<i32>>>, !fir.ref<!fir.box<!fir.heap<i32>>>)
41+
%c10_i32 = arith.constant 10 : i32
42+
%13 = fir.load %11#0 : !fir.ref<i32>
43+
%14 = arith.addi %c10_i32, %13 : i32
44+
hlfir.assign %14 to %12#0 realloc : i32, !fir.ref<!fir.box<!fir.heap<i32>>>
45+
omp.terminator
46+
}
47+
omp.terminator
48+
}
49+
omp.terminator
50+
}
51+
return
52+
}

0 commit comments

Comments
 (0)