Skip to content

Commit a8d8af3

Browse files
authored
[OpenMP][OMPIRBuilder] Collect users of a value before replacing them in target outlined function (#139064)
This PR fixes a crash that curently happens given the following input: ```fortran subroutine caller() real :: x integer :: i !$omp target x = i call callee(x,x) !$omp end target endsubroutine caller subroutine callee(x1,x2) real :: x1, x2 endsubroutine callee ``` The crash happens because the following sequence of events is taken by the `OMPIRBuilder`: 1. .... 2. An outlined function for the target region is created. At first the outlined function still refers to the SSA values from the original function of the target region. 3. The builder then iterates over the users of SSA values used in the target region to replace them with the corresponding function arguments of outlined function. 4. If the same instruction references the SSA value more than once (say m), all uses of that SSA value are replaced in the instruction. Deleting all m uses of the value. 5. The next m-1 iterations will still iterate over the same instruction dropping the last m-1 actual users of the value. Hence, we collect all users first before modifying them.
1 parent 4000113 commit a8d8af3

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7089,8 +7089,12 @@ static Expected<Function *> createOutlinedFunction(
70897089
if (auto *Const = dyn_cast<Constant>(Input))
70907090
convertUsersOfConstantsToInstructions(Const, Func, false);
70917091

7092+
// Collect users before iterating over them to avoid invalidating the
7093+
// iteration in case a user uses Input more than once (e.g. a call
7094+
// instruction).
7095+
SetVector<User *> Users(Input->users().begin(), Input->users().end());
70927096
// Collect all the instructions
7093-
for (User *User : make_early_inc_range(Input->users()))
7097+
for (User *User : make_early_inc_range(Users))
70947098
if (auto *Instr = dyn_cast<Instruction>(User))
70957099
if (Instr->getFunction() == Func)
70967100
Instr->replaceUsesOfWith(Input, InputCopy);
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s
2+
3+
llvm.func @caller_() {
4+
%c1 = llvm.mlir.constant(1 : i64) : i64
5+
%x_host = llvm.alloca %c1 x f32 {bindc_name = "x"} : (i64) -> !llvm.ptr
6+
%i_host = llvm.alloca %c1 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
7+
%x_map = omp.map.info var_ptr(%x_host : !llvm.ptr, f32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "x"}
8+
%i_map = omp.map.info var_ptr(%i_host : !llvm.ptr, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !llvm.ptr {name = "i"}
9+
omp.target map_entries(%x_map -> %x_arg, %i_map -> %i_arg : !llvm.ptr, !llvm.ptr) {
10+
%1 = llvm.load %i_arg : !llvm.ptr -> i32
11+
%2 = llvm.sitofp %1 : i32 to f32
12+
llvm.store %2, %x_arg : f32, !llvm.ptr
13+
// The call instruction uses %x_arg more than once. Hence modifying users
14+
// while iterating them invalidates the iteration. Which is what is tested
15+
// by this test.
16+
llvm.call @callee_(%x_arg, %x_arg) : (!llvm.ptr, !llvm.ptr) -> ()
17+
omp.terminator
18+
}
19+
llvm.return
20+
}
21+
22+
llvm.func @callee_(%arg0: !llvm.ptr, %arg1: !llvm.ptr) {
23+
llvm.return
24+
}
25+
26+
27+
// CHECK: define internal void @__omp_offloading_{{.*}}_caller__{{.*}}(ptr %[[X_PARAM:.*]], ptr %[[I_PARAM:.*]]) {
28+
29+
// CHECK: %[[I_VAL:.*]] = load i32, ptr %[[I_PARAM]], align 4
30+
// CHECK: %[[I_VAL_FL:.*]] = sitofp i32 %[[I_VAL]] to float
31+
// CHECK: store float %[[I_VAL_FL]], ptr %[[X_PARAM]], align 4
32+
// CHECK: call void @callee_(ptr %[[X_PARAM]], ptr %[[X_PARAM]])
33+
// CHECK: br label %[[REGION_CONT:.*]]
34+
35+
// CHECK: [[REGION_CONT]]:
36+
// CHECK: ret void
37+
// CHECK: }

0 commit comments

Comments
 (0)