Skip to content

Commit 8c9e0c6

Browse files
authored
[flang][OpenMP] Allocate reduction init temps on the stack for GPUs (#146667)
Temps needed for the reduction init regions are currently allocated on the heap all the time. However, this is a performance killer for GPUs since malloc calls are prohibitively expensive. Therefore, we should do these allocations on the stack for GPU reductions.
1 parent ddcccc4 commit 8c9e0c6

File tree

3 files changed

+124
-94
lines changed

3 files changed

+124
-94
lines changed

flang/lib/Lower/Support/PrivateReductionUtils.cpp

Lines changed: 31 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -502,22 +502,37 @@ void PopulateInitAndCleanupRegionsHelper::initAndCleanupBoxedArray(
502502

503503
// Allocating on the heap in case the whole reduction/privatization is nested
504504
// inside of a loop
505-
auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
506-
// if needsDealloc isn't statically false, add cleanup region. Always
507-
// do this for allocatable boxes because they might have been re-allocated
508-
// in the body of the loop/parallel region
509-
510-
std::optional<int64_t> cstNeedsDealloc = fir::getIntIfConstant(needsDealloc);
511-
assert(cstNeedsDealloc.has_value() &&
512-
"createTempFromMold decides this statically");
513-
if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
514-
mlir::OpBuilder::InsertionGuard guard(builder);
515-
createCleanupRegion(converter, loc, argType, cleanupRegion, sym,
516-
isDoConcurrent);
517-
} else {
518-
assert(!isAllocatableOrPointer &&
519-
"Pointer-like arrays must be heap allocated");
520-
}
505+
auto temp = [&]() {
506+
bool shouldAllocateOnStack = false;
507+
508+
// On the GPU, always allocate on the stack since heap allocations are very
509+
// expensive.
510+
if (auto offloadMod = llvm::dyn_cast<mlir::omp::OffloadModuleInterface>(
511+
*builder.getModule()))
512+
shouldAllocateOnStack = offloadMod.getIsGPU();
513+
514+
if (shouldAllocateOnStack)
515+
return createStackTempFromMold(loc, builder, source);
516+
517+
auto [temp, needsDealloc] = createTempFromMold(loc, builder, source);
518+
// if needsDealloc isn't statically false, add cleanup region. Always
519+
// do this for allocatable boxes because they might have been re-allocated
520+
// in the body of the loop/parallel region
521+
522+
std::optional<int64_t> cstNeedsDealloc =
523+
fir::getIntIfConstant(needsDealloc);
524+
assert(cstNeedsDealloc.has_value() &&
525+
"createTempFromMold decides this statically");
526+
if (cstNeedsDealloc.has_value() && *cstNeedsDealloc != false) {
527+
mlir::OpBuilder::InsertionGuard guard(builder);
528+
createCleanupRegion(converter, loc, argType, cleanupRegion, sym,
529+
isDoConcurrent);
530+
} else {
531+
assert(!isAllocatableOrPointer &&
532+
"Pointer-like arrays must be heap allocated");
533+
}
534+
return temp;
535+
}();
521536

522537
// Put the temporary inside of a box:
523538
// hlfir::genVariableBox doesn't handle non-default lower bounds
Lines changed: 88 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
1-
! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
2-
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s
1+
! RUN: bbc -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s --check-prefix=CPU
2+
! RUN: %flang_fc1 -emit-hlfir -fopenmp -o - %s 2>&1 | FileCheck %s --check-prefix=CPU
3+
4+
! RUN: bbc -emit-hlfir -fopenmp -fopenmp-is-target-device -fopenmp-is-gpu -o - %s 2>&1 | FileCheck %s --check-prefix=GPU
5+
! RUN: %flang_fc1 -triple amdgcn-amd-amdhsa -emit-hlfir -fopenmp -fopenmp-is-target-device -o - %s 2>&1 | FileCheck %s --check-prefix=GPU
36

47
program reduce
58
integer, dimension(3) :: i = 0
@@ -13,81 +16,88 @@ program reduce
1316
print *,i
1417
end program
1518

16-
! CHECK-LABEL: omp.declare_reduction @add_reduction_byref_box_3xi32 : !fir.ref<!fir.box<!fir.array<3xi32>>> alloc {
17-
! CHECK: %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
18-
! CHECK: omp.yield(%[[VAL_8]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
19-
! CHECK-LABEL: } init {
20-
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>, %[[ALLOC:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
21-
! CHECK: %[[VAL_2:.*]] = arith.constant 0 : i32
22-
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
23-
! CHECK: %[[VAL_4:.*]] = arith.constant 3 : index
24-
! CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
25-
! CHECK: %[[VAL_1:.*]] = fir.allocmem !fir.array<3xi32> {bindc_name = ".tmp", uniq_name = ""}
26-
! CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<3xi32>>,
27-
! CHECK: %[[TRUE:.*]] = arith.constant true
19+
! CPU-LABEL: omp.declare_reduction @add_reduction_byref_box_3xi32 : !fir.ref<!fir.box<!fir.array<3xi32>>> alloc {
20+
! CPU: %[[VAL_8:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
21+
! CPU: omp.yield(%[[VAL_8]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
22+
! CPU-LABEL: } init {
23+
! CPU: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>, %[[ALLOC:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
24+
! CPU: %[[VAL_2:.*]] = arith.constant 0 : i32
25+
! CPU: %[[VAL_3:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
26+
! CPU: %[[VAL_4:.*]] = arith.constant 3 : index
27+
! CPU: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
28+
! CPU: %[[VAL_1:.*]] = fir.allocmem !fir.array<3xi32> {bindc_name = ".tmp", uniq_name = ""}
29+
! CPU: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_1]](%[[VAL_5]]) {uniq_name = ".tmp"} : (!fir.heap<!fir.array<3xi32>>,
30+
! CPU: %[[TRUE:.*]] = arith.constant true
2831
!fir.shape<1>) -> (!fir.heap<!fir.array<3xi32>>, !fir.heap<!fir.array<3xi32>>)
29-
! CHECK: %[[C0:.*]] = arith.constant 0 : index
30-
! CHECK: %[[DIMS:.*]]:3 = fir.box_dims %[[VAL_3]], %[[C0]] : (!fir.box<!fir.array<3xi32>>, index) -> (index, index, index)
31-
! CHECK: %[[SHIFT:.*]] = fir.shape_shift %[[DIMS]]#0, %[[DIMS]]#1 : (index, index) -> !fir.shapeshift<1>
32-
! CHECK: %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[SHIFT]]) : (!fir.heap<!fir.array<3xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<3xi32>>
33-
! CHECK: hlfir.assign %[[VAL_2]] to %[[VAL_7]] : i32, !fir.box<!fir.array<3xi32>>
34-
! CHECK: fir.store %[[VAL_7]] to %[[ALLOC]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
35-
! CHECK: omp.yield(%[[ALLOC]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
36-
! CHECK: } combiner {
37-
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>, %[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
38-
! CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
39-
! CHECK: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
40-
! CHECK: %[[C1:.*]] = arith.constant 1 : index
41-
! CHECK: %[[C3:.*]] = arith.constant 3 : index
42-
! CHECK: %[[SHAPE_SHIFT:.*]] = fir.shape_shift %[[C1]], %[[C3]] : (index, index) -> !fir.shapeshift<1>
43-
! CHECK: %[[C1_0:.*]] = arith.constant 1 : index
44-
! CHECK: fir.do_loop %[[VAL_8:.*]] = %[[C1_0]] to %[[C3]] step %[[C1_0]] unordered {
45-
! CHECK: %[[VAL_9:.*]] = fir.array_coor %[[VAL_2]](%[[SHAPE_SHIFT]]) %[[VAL_8]] : (!fir.box<!fir.array<3xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
46-
! CHECK: %[[VAL_10:.*]] = fir.array_coor %[[VAL_3]](%[[SHAPE_SHIFT]]) %[[VAL_8]] : (!fir.box<!fir.array<3xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
47-
! CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_9]] : !fir.ref<i32>
48-
! CHECK: %[[VAL_12:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
49-
! CHECK: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
50-
! CHECK: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
51-
! CHECK: }
52-
! CHECK: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
53-
! CHECK: } cleanup {
54-
! CHECK: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
55-
! CHECK: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
56-
! CHECK: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>) -> !fir.ref<!fir.array<3xi32>>
57-
! CHECK: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> i64
58-
! CHECK: %[[VAL_4:.*]] = arith.constant 0 : i64
59-
! CHECK: %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
60-
! CHECK: fir.if %[[VAL_5]] {
61-
! CHECK: %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> !fir.heap<!fir.array<3xi32>>
62-
! CHECK: fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<3xi32>>
63-
! CHECK: }
64-
! CHECK: omp.yield
65-
! CHECK: }
32+
! CPU: %[[C0:.*]] = arith.constant 0 : index
33+
! CPU: %[[DIMS:.*]]:3 = fir.box_dims %[[VAL_3]], %[[C0]] : (!fir.box<!fir.array<3xi32>>, index) -> (index, index, index)
34+
! CPU: %[[SHIFT:.*]] = fir.shape_shift %[[DIMS]]#0, %[[DIMS]]#1 : (index, index) -> !fir.shapeshift<1>
35+
! CPU: %[[VAL_7:.*]] = fir.embox %[[VAL_6]]#0(%[[SHIFT]]) : (!fir.heap<!fir.array<3xi32>>, !fir.shapeshift<1>) -> !fir.box<!fir.array<3xi32>>
36+
! CPU: hlfir.assign %[[VAL_2]] to %[[VAL_7]] : i32, !fir.box<!fir.array<3xi32>>
37+
! CPU: fir.store %[[VAL_7]] to %[[ALLOC]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
38+
! CPU: omp.yield(%[[ALLOC]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
39+
! CPU: } combiner {
40+
! CPU: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>, %[[VAL_1:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
41+
! CPU: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
42+
! CPU: %[[VAL_3:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
43+
! CPU: %[[C1:.*]] = arith.constant 1 : index
44+
! CPU: %[[C3:.*]] = arith.constant 3 : index
45+
! CPU: %[[SHAPE_SHIFT:.*]] = fir.shape_shift %[[C1]], %[[C3]] : (index, index) -> !fir.shapeshift<1>
46+
! CPU: %[[C1_0:.*]] = arith.constant 1 : index
47+
! CPU: fir.do_loop %[[VAL_8:.*]] = %[[C1_0]] to %[[C3]] step %[[C1_0]] unordered {
48+
! CPU: %[[VAL_9:.*]] = fir.array_coor %[[VAL_2]](%[[SHAPE_SHIFT]]) %[[VAL_8]] : (!fir.box<!fir.array<3xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
49+
! CPU: %[[VAL_10:.*]] = fir.array_coor %[[VAL_3]](%[[SHAPE_SHIFT]]) %[[VAL_8]] : (!fir.box<!fir.array<3xi32>>, !fir.shapeshift<1>, index) -> !fir.ref<i32>
50+
! CPU: %[[VAL_11:.*]] = fir.load %[[VAL_9]] : !fir.ref<i32>
51+
! CPU: %[[VAL_12:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
52+
! CPU: %[[VAL_13:.*]] = arith.addi %[[VAL_11]], %[[VAL_12]] : i32
53+
! CPU: fir.store %[[VAL_13]] to %[[VAL_9]] : !fir.ref<i32>
54+
! CPU: }
55+
! CPU: omp.yield(%[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>)
56+
! CPU: } cleanup {
57+
! CPU: ^bb0(%[[VAL_0:.*]]: !fir.ref<!fir.box<!fir.array<3xi32>>>):
58+
! CPU: %[[VAL_1:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
59+
! CPU: %[[VAL_2:.*]] = fir.box_addr %[[VAL_1]] : (!fir.box<!fir.array<3xi32>>) -> !fir.ref<!fir.array<3xi32>>
60+
! CPU: %[[VAL_3:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> i64
61+
! CPU: %[[VAL_4:.*]] = arith.constant 0 : i64
62+
! CPU: %[[VAL_5:.*]] = arith.cmpi ne, %[[VAL_3]], %[[VAL_4]] : i64
63+
! CPU: fir.if %[[VAL_5]] {
64+
! CPU: %[[VAL_6:.*]] = fir.convert %[[VAL_2]] : (!fir.ref<!fir.array<3xi32>>) -> !fir.heap<!fir.array<3xi32>>
65+
! CPU: fir.freemem %[[VAL_6]] : !fir.heap<!fir.array<3xi32>>
66+
! CPU: }
67+
! CPU: omp.yield
68+
! CPU: }
69+
70+
! CPU-LABEL: func.func @_QQmain()
71+
! CPU: %[[VAL_0:.*]] = fir.address_of(@_QFEi) : !fir.ref<!fir.array<3xi32>>
72+
! CPU: %[[VAL_1:.*]] = arith.constant 3 : index
73+
! CPU: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
74+
! CPU: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_2]]) {uniq_name = "_QFEi"} : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<3xi32>>, !fir.ref<!fir.array<3xi32>>)
75+
! CPU: %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
76+
! CPU: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
77+
! CPU: fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
78+
! CPU: omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
79+
! CPU: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
80+
! CPU: %[[VAL_8:.*]] = arith.constant 1 : i32
81+
! CPU: %[[VAL_9:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
82+
! CPU: %[[VAL_10:.*]] = arith.constant 1 : index
83+
! CPU: %[[VAL_11:.*]] = hlfir.designate %[[VAL_9]] (%[[VAL_10]]) : (!fir.box<!fir.array<3xi32>>, index) -> !fir.ref<i32>
84+
! CPU: hlfir.assign %[[VAL_8]] to %[[VAL_11]] : i32, !fir.ref<i32>
85+
! CPU: %[[VAL_12:.*]] = arith.constant 2 : i32
86+
! CPU: %[[VAL_13:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
87+
! CPU: %[[VAL_14:.*]] = arith.constant 2 : index
88+
! CPU: %[[VAL_15:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_14]]) : (!fir.box<!fir.array<3xi32>>, index) -> !fir.ref<i32>
89+
! CPU: hlfir.assign %[[VAL_12]] to %[[VAL_15]] : i32, !fir.ref<i32>
90+
! CPU: %[[VAL_16:.*]] = arith.constant 3 : i32
91+
! CPU: %[[VAL_17:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
92+
! CPU: %[[VAL_18:.*]] = arith.constant 3 : index
93+
! CPU: %[[VAL_19:.*]] = hlfir.designate %[[VAL_17]] (%[[VAL_18]]) : (!fir.box<!fir.array<3xi32>>, index) -> !fir.ref<i32>
94+
! CPU: hlfir.assign %[[VAL_16]] to %[[VAL_19]] : i32, !fir.ref<i32>
95+
! CPU: omp.terminator
96+
! CPU: }
6697

67-
! CHECK-LABEL: func.func @_QQmain()
68-
! CHECK: %[[VAL_0:.*]] = fir.address_of(@_QFEi) : !fir.ref<!fir.array<3xi32>>
69-
! CHECK: %[[VAL_1:.*]] = arith.constant 3 : index
70-
! CHECK: %[[VAL_2:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
71-
! CHECK: %[[VAL_3:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_2]]) {uniq_name = "_QFEi"} : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<3xi32>>, !fir.ref<!fir.array<3xi32>>)
72-
! CHECK: %[[VAL_4:.*]] = fir.embox %[[VAL_3]]#0(%[[VAL_2]]) : (!fir.ref<!fir.array<3xi32>>, !fir.shape<1>) -> !fir.box<!fir.array<3xi32>>
73-
! CHECK: %[[VAL_5:.*]] = fir.alloca !fir.box<!fir.array<3xi32>>
74-
! CHECK: fir.store %[[VAL_4]] to %[[VAL_5]] : !fir.ref<!fir.box<!fir.array<3xi32>>>
75-
! CHECK: omp.parallel reduction(byref @add_reduction_byref_box_3xi32 %[[VAL_5]] -> %[[VAL_6:.*]] : !fir.ref<!fir.box<!fir.array<3xi32>>>) {
76-
! CHECK: %[[VAL_7:.*]]:2 = hlfir.declare %[[VAL_6]] {uniq_name = "_QFEi"} : (!fir.ref<!fir.box<!fir.array<3xi32>>>) -> (!fir.ref<!fir.box<!fir.array<3xi32>>>, !fir.ref<!fir.box<!fir.array<3xi32>>>)
77-
! CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
78-
! CHECK: %[[VAL_9:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
79-
! CHECK: %[[VAL_10:.*]] = arith.constant 1 : index
80-
! CHECK: %[[VAL_11:.*]] = hlfir.designate %[[VAL_9]] (%[[VAL_10]]) : (!fir.box<!fir.array<3xi32>>, index) -> !fir.ref<i32>
81-
! CHECK: hlfir.assign %[[VAL_8]] to %[[VAL_11]] : i32, !fir.ref<i32>
82-
! CHECK: %[[VAL_12:.*]] = arith.constant 2 : i32
83-
! CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
84-
! CHECK: %[[VAL_14:.*]] = arith.constant 2 : index
85-
! CHECK: %[[VAL_15:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_14]]) : (!fir.box<!fir.array<3xi32>>, index) -> !fir.ref<i32>
86-
! CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_15]] : i32, !fir.ref<i32>
87-
! CHECK: %[[VAL_16:.*]] = arith.constant 3 : i32
88-
! CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_7]]#0 : !fir.ref<!fir.box<!fir.array<3xi32>>>
89-
! CHECK: %[[VAL_18:.*]] = arith.constant 3 : index
90-
! CHECK: %[[VAL_19:.*]] = hlfir.designate %[[VAL_17]] (%[[VAL_18]]) : (!fir.box<!fir.array<3xi32>>, index) -> !fir.ref<i32>
91-
! CHECK: hlfir.assign %[[VAL_16]] to %[[VAL_19]] : i32, !fir.ref<i32>
92-
! CHECK: omp.terminator
93-
! CHECK: }
98+
! GPU: omp.declare_reduction {{.*}} alloc {
99+
! GPU: } init {
100+
! GPU-NOT: fir.allocmem {{.*}} {bindc_name = ".tmp", {{.*}}}
101+
! GPU: fir.alloca {{.*}} {bindc_name = ".tmp"}
102+
! GPU: } combiner {
103+
! GPU: }

mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1291,6 +1291,11 @@ initReductionVars(OP op, ArrayRef<BlockArgument> reductionArgs,
12911291
mapInitializationArgs(op, moduleTranslation, reductionDecls,
12921292
reductionVariableMap, i);
12931293

1294+
// TODO In some cases (especially on the GPU), the init regions may
1295+
// contain stack allocations. If the region is inlined in a loop, this is
1296+
// problematic. Instead of just inlining the region, handle allocations by
1297+
// hoisting fixed length allocations to the function entry and using
1298+
// stacksave and restore for variable length ones.
12941299
if (failed(inlineConvertOmpRegions(reductionDecls[i].getInitializerRegion(),
12951300
"omp.reduction.neutral", builder,
12961301
moduleTranslation, &phis)))

0 commit comments

Comments
 (0)