Skip to content

Commit 925588c

Browse files
authored
[flang][cuda] Allocate derived-type with CUDA component in managed memory (#146797)
Similarly to descriptors for device data, place derived types that hold a device descriptor in managed memory.
1 parent e873ce3 commit 925588c

File tree

3 files changed

+50
-6
lines changed

3 files changed

+50
-6
lines changed

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,29 @@ static void instantiateGlobal(Fortran::lower::AbstractConverter &converter,
702702
mapSymbolAttributes(converter, var, symMap, stmtCtx, cast);
703703
}
704704

705+
/// Decide whether \p sym requires CUDA-specific allocation.
///
/// Returns true when the ultimate symbol is an object entity that either
///   - carries one of the CUDA data attributes Device, Managed, Unified,
///     Shared or Pinned, or
///   - has a derived type for which
///     FindCUDADeviceAllocatableUltimateComponent finds a component.
/// Returns false for dummy arguments, for symbols that are not object
/// entities, and for everything else.
///
/// Used by createNewLocal and instantiateLocal in this file to gate the
/// CUDA handling of a local variable.
bool needCUDAAlloc(const Fortran::semantics::Symbol &sym) {
706+
// Dummy arguments are never allocated here.
if (Fortran::semantics::IsDummy(sym))
707+
return false;
708+
// Inspect the ultimate symbol so that the attributes and type of the
// underlying object entity are examined.
if (const auto *details{
709+
sym.GetUltimate()
710+
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
711+
// Case 1: the object itself carries one of the listed CUDA data attributes.
// Any other attribute value falls through to the derived-type check below.
if (details->cudaDataAttr() &&
712+
(*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Device ||
713+
*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Managed ||
714+
*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Unified ||
715+
*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Shared ||
716+
*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Pinned))
717+
return true;
718+
// Case 2: the object has no qualifying attribute of its own, but its derived
// type contains a component reported by
// FindCUDADeviceAllocatableUltimateComponent. `type` may be null, and
// AsDerived() yields null for non-derived types, hence the guarded lookup.
const Fortran::semantics::DeclTypeSpec *type{details->type()};
719+
const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
720+
: nullptr};
721+
if (derived)
722+
if (FindCUDADeviceAllocatableUltimateComponent(*derived))
723+
return true;
724+
}
725+
return false;
726+
}
727+
705728
//===----------------------------------------------------------------===//
706729
// Local variables instantiation (not for alias)
707730
//===----------------------------------------------------------------===//
@@ -732,7 +755,7 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
732755
if (ultimateSymbol.test(Fortran::semantics::Symbol::Flag::CrayPointee))
733756
return builder.create<fir::ZeroOp>(loc, fir::ReferenceType::get(ty));
734757

735-
if (Fortran::semantics::NeedCUDAAlloc(ultimateSymbol)) {
758+
if (needCUDAAlloc(ultimateSymbol)) {
736759
cuf::DataAttributeAttr dataAttr =
737760
Fortran::lower::translateSymbolCUFDataAttribute(builder.getContext(),
738761
ultimateSymbol);
@@ -1087,7 +1110,7 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
10871110
Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
10881111
symMap);
10891112
auto *builder = &converter.getFirOpBuilder();
1090-
if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol()) &&
1113+
if (needCUDAAlloc(var.getSymbol()) &&
10911114
!cuf::isCUDADeviceContext(builder->getRegion())) {
10921115
cuf::DataAttributeAttr dataAttr =
10931116
Fortran::lower::translateSymbolCUFDataAttribute(builder->getContext(),

flang/lib/Semantics/tools.cpp

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,9 +1095,20 @@ bool IsDeviceAllocatable(const Symbol &symbol) {
10951095
}
10961096

10971097
// Returns the CUDA data attribute to use for `symbol`.
//   - std::nullopt when `symbol` is null or is not an object entity.
//   - CUDADataAttr::Managed when the object's type is a derived type for which
//     FindCUDADeviceAllocatableUltimateComponent finds a component; this
//     result is returned before the object's own attribute is consulted.
//   - Otherwise the object's own cudaDataAttr(), which may itself be empty.
// NOTE(review): unlike needCUDAAlloc in ConvertVariable.cpp, this does not
// call GetUltimate() on the symbol -- confirm that callers already pass the
// ultimate symbol.
std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *symbol) {
1098-
const auto *object{
1098+
const auto *details{
10991099
symbol ? symbol->detailsIf<ObjectEntityDetails>() : nullptr};
1100-
return object ? object->cudaDataAttr() : std::nullopt;
1100+
if (details) {
1101+
// `type` may be null and AsDerived() yields null for non-derived types, so
// `derived` is only non-null for objects of derived type.
const Fortran::semantics::DeclTypeSpec *type{details->type()};
1102+
const Fortran::semantics::DerivedTypeSpec *derived{
1103+
type ? type->AsDerived() : nullptr};
1104+
if (derived) {
1105+
// The derived-type check takes precedence over any attribute set on the
// object itself.
if (FindCUDADeviceAllocatableUltimateComponent(*derived)) {
1106+
return common::CUDADataAttr::Managed;
1107+
}
1108+
}
1109+
return details->cudaDataAttr();
1110+
}
1111+
return std::nullopt;
11011112
}
11021113

11031114
bool IsAccessible(const Symbol &original, const Scope &scope) {

flang/test/Lower/CUDA/cuda-derived.cuf

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ module m1
77

88
type t1; real, device, allocatable :: a(:); end type
99
type t2; type(t1) :: b; end type
10+
contains
11+
! Declares a local variable of type ty_device inside a module procedure; the
! CHECK lines that follow expect it to be lowered to cuf.alloc / cuf.free with
! the managed data attribute.
subroutine sub1()
12+
type(ty_device) :: a
13+
end subroutine
14+
15+
! CHECK-LABEL: func.func @_QMm1Psub1()
16+
! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
17+
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
18+
! CHECK: cuf.free %[[DECL]]#0 : !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>> {data_attr = #cuf.cuda<managed>}
19+
1020
end module
1121

1222
program main
@@ -16,5 +26,5 @@ program main
1626
end
1727

1828
! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"}
19-
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", uniq_name = "_QFEa"}
20-
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box<!fir.heap<!fir.array<?xf32>>>}>}> {bindc_name = "b", uniq_name = "_QFEb"}
29+
! CHECK: %{{.*}} = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEa"}
30+
! CHECK: %{{.*}} = cuf.alloc !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box<!fir.heap<!fir.array<?xf32>>>}>}> {bindc_name = "b", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEb"}

0 commit comments

Comments
 (0)