Skip to content

Commit e873ce3

Browse files
authored
[flang][cuda] Do not create global for derived-type with allocatable device components (#146780)
derived type with CUDA device allocatable components will be handle via CUDA allocation. Do not create global for them.
1 parent 0d7e64f commit e873ce3

File tree

5 files changed

+67
-16
lines changed

5 files changed

+67
-16
lines changed

flang/include/flang/Evaluate/tools.h

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1286,16 +1286,7 @@ bool CheckForCoindexedObject(parser::ContextualMessages &,
12861286
const std::optional<ActualArgument> &, const std::string &procName,
12871287
const std::string &argName);
12881288

1289-
inline bool CanCUDASymbolHaveSaveAttr(const Symbol &sym) {
1290-
if (const auto *details =
1291-
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()) {
1292-
if (details->cudaDataAttr() &&
1293-
*details->cudaDataAttr() != common::CUDADataAttr::Unified) {
1294-
return false;
1295-
}
1296-
}
1297-
return true;
1298-
}
1289+
bool CanCUDASymbolHaveSaveAttr(const Symbol &sym);
12991290

13001291
inline bool IsCUDADeviceSymbol(const Symbol &sym) {
13011292
if (const auto *details =

flang/include/flang/Semantics/tools.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -654,6 +654,8 @@ DirectComponentIterator::const_iterator FindAllocatableOrPointerDirectComponent(
654654
const DerivedTypeSpec &);
655655
PotentialComponentIterator::const_iterator
656656
FindPolymorphicAllocatablePotentialComponent(const DerivedTypeSpec &);
657+
UltimateComponentIterator::const_iterator
658+
FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &);
657659

658660
// The LabelEnforce class (given a set of labels) provides an error message if
659661
// there is a branch to a label which is not in the given set.

flang/lib/Evaluate/tools.cpp

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2173,6 +2173,25 @@ bool IsAutomatic(const Symbol &original) {
21732173
return false;
21742174
}
21752175

2176+
bool CanCUDASymbolHaveSaveAttr(const Symbol &sym) {
2177+
if (const auto *details{
2178+
sym.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
2179+
const Fortran::semantics::DeclTypeSpec *type{details->type()};
2180+
const Fortran::semantics::DerivedTypeSpec *derived{
2181+
type ? type->AsDerived() : nullptr};
2182+
if (derived) {
2183+
if (FindCUDADeviceAllocatableUltimateComponent(*derived)) {
2184+
return false;
2185+
}
2186+
}
2187+
if (details->cudaDataAttr() &&
2188+
*details->cudaDataAttr() != common::CUDADataAttr::Unified) {
2189+
return false;
2190+
}
2191+
}
2192+
return true;
2193+
}
2194+
21762195
bool IsSaved(const Symbol &original) {
21772196
const Symbol &symbol{GetAssociationRoot(original)};
21782197
const Scope &scope{symbol.owner()};
@@ -2195,7 +2214,7 @@ bool IsSaved(const Symbol &original) {
21952214
} else if (scopeKind == Scope::Kind::Module ||
21962215
(scopeKind == Scope::Kind::MainProgram &&
21972216
(symbol.attrs().test(Attr::TARGET) || evaluate::IsCoarray(symbol)) &&
2198-
Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol))) {
2217+
CanCUDASymbolHaveSaveAttr(symbol))) {
21992218
// 8.5.16p4
22002219
// In main programs, implied SAVE matters only for pointer
22012220
// initialization targets and coarrays.
@@ -2205,7 +2224,7 @@ bool IsSaved(const Symbol &original) {
22052224
(features.IsEnabled(
22062225
common::LanguageFeature::SaveBigMainProgramVariables) &&
22072226
symbol.size() > 32)) &&
2208-
Fortran::evaluate::CanCUDASymbolHaveSaveAttr(symbol)) {
2227+
CanCUDASymbolHaveSaveAttr(symbol)) {
22092228
// With SaveBigMainProgramVariables, keeping all unsaved main program
22102229
// variables of 32 bytes or less on the stack allows keeping numerical and
22112230
// logical scalars, small scalar characters or derived, small arrays, and
@@ -2223,15 +2242,15 @@ bool IsSaved(const Symbol &original) {
22232242
} else if (symbol.test(Symbol::Flag::InDataStmt)) {
22242243
return true;
22252244
} else if (const auto *object{symbol.detailsIf<ObjectEntityDetails>()};
2226-
object && object->init()) {
2245+
object && object->init()) {
22272246
return true;
22282247
} else if (IsProcedurePointer(symbol) && symbol.has<ProcEntityDetails>() &&
22292248
symbol.get<ProcEntityDetails>().init()) {
22302249
return true;
22312250
} else if (scope.hasSAVE()) {
22322251
return true; // bare SAVE statement
2233-
} else if (const Symbol * block{FindCommonBlockContaining(symbol)};
2234-
block && block->attrs().test(Attr::SAVE)) {
2252+
} else if (const Symbol *block{FindCommonBlockContaining(symbol)};
2253+
block && block->attrs().test(Attr::SAVE)) {
22352254
return true; // in COMMON with SAVE
22362255
} else {
22372256
return false;

flang/lib/Semantics/tools.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,6 +1081,19 @@ const Scope *FindCUDADeviceContext(const Scope *scope) {
10811081
});
10821082
}
10831083

1084+
bool IsDeviceAllocatable(const Symbol &symbol) {
1085+
if (IsAllocatable(symbol)) {
1086+
if (const auto *details{
1087+
symbol.GetUltimate().detailsIf<semantics::ObjectEntityDetails>()}) {
1088+
if (details->cudaDataAttr() &&
1089+
*details->cudaDataAttr() != common::CUDADataAttr::Pinned) {
1090+
return true;
1091+
}
1092+
}
1093+
}
1094+
return false;
1095+
}
1096+
10841097
std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *symbol) {
10851098
const auto *object{
10861099
symbol ? symbol->detailsIf<ObjectEntityDetails>() : nullptr};
@@ -1426,6 +1439,12 @@ FindPolymorphicAllocatablePotentialComponent(const DerivedTypeSpec &derived) {
14261439
potentials.begin(), potentials.end(), IsPolymorphicAllocatable);
14271440
}
14281441

1442+
UltimateComponentIterator::const_iterator
1443+
FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &derived) {
1444+
UltimateComponentIterator ultimates{derived};
1445+
return std::find_if(ultimates.begin(), ultimates.end(), IsDeviceAllocatable);
1446+
}
1447+
14291448
const Symbol *FindUltimateComponent(const DerivedTypeSpec &derived,
14301449
const std::function<bool(const Symbol &)> &predicate) {
14311450
UltimateComponentIterator ultimates{derived};
@@ -1788,4 +1807,4 @@ bool HadUseError(
17881807
}
17891808
}
17901809

1791-
} // namespace Fortran::semantics
1810+
} // namespace Fortran::semantics
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
! RUN: bbc -emit-hlfir -fcuda %s -o - | FileCheck %s
2+
3+
module m1
4+
type ty_device
5+
integer, device, allocatable, dimension(:) :: x
6+
end type
7+
8+
type t1; real, device, allocatable :: a(:); end type
9+
type t2; type(t1) :: b; end type
10+
end module
11+
12+
program main
13+
use m1
14+
type(ty_device) :: a
15+
type(t2) :: b
16+
end
17+
18+
! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"}
19+
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", uniq_name = "_QFEa"}
20+
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box<!fir.heap<!fir.array<?xf32>>>}>}> {bindc_name = "b", uniq_name = "_QFEb"}

0 commit comments

Comments
 (0)