diff --git a/flang/include/flang/Evaluate/tools.h b/flang/include/flang/Evaluate/tools.h index 18c244f6f450f..96ed86f468350 100644 --- a/flang/include/flang/Evaluate/tools.h +++ b/flang/include/flang/Evaluate/tools.h @@ -1359,24 +1359,7 @@ inline bool IsCUDADataTransfer(const A &lhs, const B &rhs) { /// Check if the expression is a mix of host and device variables that require /// implicit data transfer. -inline bool HasCUDAImplicitTransfer(const Expr &expr) { - unsigned hostSymbols{0}; - unsigned deviceSymbols{0}; - for (const Symbol &sym : CollectCudaSymbols(expr)) { - if (IsCUDADeviceSymbol(sym)) { - ++deviceSymbols; - } else { - if (sym.owner().IsDerivedType()) { - if (IsCUDADeviceSymbol(sym.owner().GetSymbol()->GetUltimate())) { - ++deviceSymbols; - } - } - ++hostSymbols; - } - } - bool hasConstant{HasConstant(expr)}; - return (hasConstant || (hostSymbols > 0)) && deviceSymbols > 0; -} +bool HasCUDAImplicitTransfer(const Expr &expr); // Checks whether the symbol on the LHS is present in the RHS expression. bool CheckForSymbolMatch(const Expr *lhs, const Expr *rhs); diff --git a/flang/lib/Evaluate/tools.cpp b/flang/lib/Evaluate/tools.cpp index 6a57d87a30e93..3d9f06308d8c1 100644 --- a/flang/lib/Evaluate/tools.cpp +++ b/flang/lib/Evaluate/tools.cpp @@ -1154,6 +1154,31 @@ template semantics::UnorderedSymbolSet CollectCudaSymbols( template semantics::UnorderedSymbolSet CollectCudaSymbols( const Expr &); +bool HasCUDAImplicitTransfer(const Expr &expr) { + semantics::UnorderedSymbolSet hostSymbols; + semantics::UnorderedSymbolSet deviceSymbols; + + SymbolVector symbols{GetSymbolVector(expr)}; + std::reverse(symbols.begin(), symbols.end()); + bool skipNext{false}; + for (const Symbol &sym : symbols) { + bool isComponent{sym.owner().IsDerivedType()}; + bool skipComponent{false}; + if (!skipNext) { + if (IsCUDADeviceSymbol(sym)) { + deviceSymbols.insert(sym); + } else if (isComponent) { + skipComponent = true; // Component is not device. Look on the base. + } else { + hostSymbols.insert(sym); + } + } + skipNext = isComponent && !skipComponent; + } + bool hasConstant{HasConstant(expr)}; + return (hasConstant || (hostSymbols.size() > 0)) && deviceSymbols.size() > 0; +} + // HasVectorSubscript() struct HasVectorSubscriptHelper : public AnyTraverse()) { if (details->cudaDataAttr() && *details->cudaDataAttr() != Fortran::common::CUDADataAttr::Pinned) { - if (sym.owner().IsDerivedType() && IsAllocatable(sym.GetUltimate())) - TODO(loc, "Device resident allocatable derived-type component"); // TODO: This should probably being checked in semantic and give a // proper error. assert( diff --git a/flang/test/Lower/CUDA/cuda-data-transfer.cuf b/flang/test/Lower/CUDA/cuda-data-transfer.cuf index 68a0202f951fe..3a9b55996d9b1 100644 --- a/flang/test/Lower/CUDA/cuda-data-transfer.cuf +++ b/flang/test/Lower/CUDA/cuda-data-transfer.cuf @@ -7,6 +7,10 @@ module mod1 integer :: i end type + type :: t2 + integer, device, allocatable, dimension(:) :: x + end type + integer, device, dimension(11:20) :: cdev contains @@ -419,3 +423,16 @@ end subroutine ! CHECK: fir.do_concurrent.loop ! CHECK-NOT: cuf.data_transfer ! CHECK: hlfir.assign + + +subroutine sub22() + use mod1 + type(t2) :: a + integer :: b(100) + allocate(a%x(100)) + + b = a%x +end subroutine + +! CHECK-LABEL: func.func @_QPsub22() +! CHECK: cuf.data_transfer