Skip to content

Commit aa7bf60

Browse files
authored
[flang][cuda] Fix false positive on unsupported CUDA data transfer (#148295)
The switch to `GetSymbolVector` introduced a regression in detecting implicit data transfers when the right-hand side is a function call. Make sure the symbols we are looking at are the ones of interest for detecting a data transfer.
1 parent f090554 commit aa7bf60

File tree

2 files changed

+23
-10
lines changed

2 files changed

+23
-10
lines changed

flang/lib/Evaluate/tools.cpp

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1157,23 +1157,28 @@ template semantics::UnorderedSymbolSet CollectCudaSymbols(
11571157
bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
11581158
semantics::UnorderedSymbolSet hostSymbols;
11591159
semantics::UnorderedSymbolSet deviceSymbols;
1160+
semantics::UnorderedSymbolSet cudaSymbols{CollectCudaSymbols(expr)};
11601161

11611162
SymbolVector symbols{GetSymbolVector(expr)};
11621163
std::reverse(symbols.begin(), symbols.end());
11631164
bool skipNext{false};
11641165
for (const Symbol &sym : symbols) {
1165-
bool isComponent{sym.owner().IsDerivedType()};
1166-
bool skipComponent{false};
1167-
if (!skipNext) {
1168-
if (IsCUDADeviceSymbol(sym)) {
1169-
deviceSymbols.insert(sym);
1170-
} else if (isComponent) {
1171-
skipComponent = true; // Component is not device. Look on the base.
1172-
} else {
1173-
hostSymbols.insert(sym);
1166+
if (cudaSymbols.find(sym) != cudaSymbols.end()) {
1167+
bool isComponent{sym.owner().IsDerivedType()};
1168+
bool skipComponent{false};
1169+
if (!skipNext) {
1170+
if (IsCUDADeviceSymbol(sym)) {
1171+
deviceSymbols.insert(sym);
1172+
} else if (isComponent) {
1173+
skipComponent = true; // Component is not device. Look on the base.
1174+
} else {
1175+
hostSymbols.insert(sym);
1176+
}
11741177
}
1178+
skipNext = isComponent && !skipComponent;
1179+
} else {
1180+
skipNext = false;
11751181
}
1176-
skipNext = isComponent && !skipComponent;
11771182
}
11781183
bool hasConstant{HasConstant(expr)};
11791184
return (hasConstant || (hostSymbols.size() > 0)) && deviceSymbols.size() > 0;

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -436,3 +436,11 @@ end subroutine
436436

437437
! CHECK-LABEL: func.func @_QPsub22()
438438
! CHECK: cuf.data_transfer
439+
440+
subroutine sub23(n)
441+
integer :: n
442+
real(8), device :: d(n,n), x(n)
443+
x = sum(d,dim=2) ! Was triggering Unsupported CUDA data transfer
444+
end subroutine
445+
446+
! CHECK-LABEL: func.func @_QPsub23

0 commit comments

Comments
 (0)