Skip to content

Commit 09b012f

Browse files
authored
[flang][openacc] Fix wait clause printer (#137263)
wait clause printer is failing with case like: ``` !$acc serial device_type(nvidia) wait !$acc end serial ```
1 parent 9cbbb74 commit 09b012f

File tree

2 files changed

+21
-14
lines changed

2 files changed

+21
-14
lines changed

flang/test/Lower/OpenACC/acc-serial.f90

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,11 @@ subroutine acc_serial
8787
! CHECK: acc.yield
8888
! CHECK-NEXT: }
8989

90+
!$acc serial device_type(nvidia) wait
91+
!$acc end serial
92+
93+
! CHECK: acc.serial wait([#acc.device_type<nvidia>])
94+
9095
!$acc serial wait(1)
9196
!$acc end serial
9297

mlir/lib/Dialect/OpenACC/IR/OpenACC.cpp

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1411,20 +1411,22 @@ static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op,
14111411
if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
14121412
p << ", ";
14131413

1414-
unsigned opIdx = 0;
1415-
llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1416-
p << "{";
1417-
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1418-
if (boolAttr && boolAttr.getValue())
1419-
p << "devnum: ";
1420-
llvm::interleaveComma(
1421-
llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1422-
p << operands[opIdx] << " : " << operands[opIdx].getType();
1423-
++opIdx;
1424-
});
1425-
p << "}";
1426-
printSingleDeviceType(p, it.value());
1427-
});
1414+
if (hasDeviceTypeValues(deviceTypes)) {
1415+
unsigned opIdx = 0;
1416+
llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
1417+
p << "{";
1418+
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1419+
if (boolAttr && boolAttr.getValue())
1420+
p << "devnum: ";
1421+
llvm::interleaveComma(
1422+
llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
1423+
p << operands[opIdx] << " : " << operands[opIdx].getType();
1424+
++opIdx;
1425+
});
1426+
p << "}";
1427+
printSingleDeviceType(p, it.value());
1428+
});
1429+
}
14281430

14291431
p << ")";
14301432
}

0 commit comments

Comments
 (0)