Skip to content

Commit 9a0e03f

Browse files
authored
[flang][cuda] Update implicit data transfer for device component (#147882)
Update the detection of implicit data transfer when a device resident allocatable derived-type component is involved and remove the TODOs.
1 parent 0f3bdc3 commit 9a0e03f

File tree

4 files changed

+43
-20
lines changed

4 files changed

+43
-20
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 1 addition & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1359,24 +1359,7 @@ inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) {
13591359

13601360
/// Check if the expression is a mix of host and device variables that require
13611361
/// implicit data transfer.
1362-
inline bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
1363-
unsigned hostSymbols{0};
1364-
unsigned deviceSymbols{0};
1365-
for (const Symbol &sym : CollectCudaSymbols(expr)) {
1366-
if (IsCUDADeviceSymbol(sym)) {
1367-
++deviceSymbols;
1368-
} else {
1369-
if (sym.owner().IsDerivedType()) {
1370-
if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) {
1371-
++deviceSymbols;
1372-
}
1373-
}
1374-
++hostSymbols;
1375-
}
1376-
}
1377-
bool hasConstant{HasConstant(expr)};
1378-
return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0;
1379-
}
1362+
bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr);
13801363

13811364
// Checks whether the symbol on the LHS is present in the RHS expression.
13821365
bool CheckForSymbolMatch(const Expr<SomeType> *lhs, const Expr<SomeType> *rhs);

flang/lib/Evaluate/tools.cpp

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1154,6 +1154,31 @@ template semantics::UnorderedSymbolSet CollectCudaSymbols(
11541154
template semantics::UnorderedSymbolSet CollectCudaSymbols(
11551155
const Expr<SubscriptInteger> &);
11561156

1157+
bool HasCUDAImplicitTransfer(const Expr<SomeType> &expr) {
1158+
semantics::UnorderedSymbolSet hostSymbols;
1159+
semantics::UnorderedSymbolSet deviceSymbols;
1160+
1161+
SymbolVector symbols{GetSymbolVector(expr)};
1162+
std::reverse(symbols.begin(), symbols.end());
1163+
bool skipNext{false};
1164+
for (const Symbol &sym : symbols) {
1165+
bool isComponent{sym.owner().IsDerivedType()};
1166+
bool skipComponent{false};
1167+
if (!skipNext) {
1168+
if (IsCUDADeviceSymbol(sym)) {
1169+
deviceSymbols.insert(sym);
1170+
} else if (isComponent) {
1171+
skipComponent = true; // Component is not device. Look on the base.
1172+
} else {
1173+
hostSymbols.insert(sym);
1174+
}
1175+
}
1176+
skipNext = isComponent && !skipComponent;
1177+
}
1178+
bool hasConstant{HasConstant(expr)};
1179+
return (hasConstant || (hostSymbols.size() > 0)) && deviceSymbols.size() > 0;
1180+
}
1181+
11571182
// HasVectorSubscript()
11581183
struct HasVectorSubscriptHelper
11591184
: public AnyTraverse<HasVectorSubscriptHelper, bool,

flang/lib/Lower/Bridge.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4842,8 +4842,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
48424842
.detailsIf<Fortran::semantics::ObjectEntityDetails>()) {
48434843
if (details->cudaDataAttr() &&
48444844
*details->cudaDataAttr() != Fortran::common::CUDADataAttr::Pinned) {
4845-
if (sym.owner().IsDerivedType() && IsAllocatable(sym.GetUltimate()))
4846-
TODO(loc, "Device resident allocatable derived-type component");
48474845
// TODO: This should probably being checked in semantic and give a
48484846
// proper error.
48494847
assert(

flang/test/Lower/CUDA/cuda-data-transfer.cuf

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ module mod1
77
integer :: i
88
end type
99

10+
type :: t2
11+
integer, device, allocatable, dimension(:) :: x
12+
end type
13+
1014
integer, device, dimension(11:20) :: cdev
1115

1216
contains
@@ -419,3 +423,16 @@ end subroutine
419423
! CHECK: fir.do_concurrent.loop
420424
! CHECK-NOT: cuf.data_transfer
421425
! CHECK: hlfir.assign
426+
427+
428+
subroutine sub22()
429+
use mod1
430+
type(t2) :: a
431+
integer :: b(100)
432+
allocate(a%x(100))
433+
434+
b = a%x
435+
end subroutine
436+
437+
! CHECK-LABEL: func.func @_QPsub22()
438+
! CHECK: cuf.data_transfer

0 commit comments

Comments
 (0)