Skip to content

Commit 659c810

Browse files
authored
Reland [flang][cuda] Allocate derived-type with CUDA component in anaged memory (#147416)
1 parent 02f60fd commit 659c810

File tree

4 files changed

+56
-10
lines changed

4 files changed

+56
-10
lines changed

flang/include/flang/Semantics/tools.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,8 @@ DirectComponentIterator::const_iterator FindAllocatableOrPointerDirectComponent(
656656
const DerivedTypeSpec &);
657657
PotentialComponentIterator::const_iterator
658658
FindPolymorphicAllocatablePotentialComponent(const DerivedTypeSpec &);
659+
UltimateComponentIterator::const_iterator
660+
FindCUDADeviceAllocatableUltimateComponent(const DerivedTypeSpec &);
659661

660662
// The LabelEnforce class (given a set of labels) provides an error message if
661663
// there is a branch to a label which is not in the given set.

flang/lib/Lower/ConvertVariable.cpp

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -702,6 +702,29 @@ static void instantiateGlobal(Fortran::lower::AbstractConverter &converter,
702702
mapSymbolAttributes(converter, var, symMap, stmtCtx, cast);
703703
}
704704

705+
bool needCUDAAlloc(const Fortran::semantics::Symbol &sym) {
706+
if (Fortran::semantics::IsDummy(sym))
707+
return false;
708+
if (const auto *details{
709+
sym.GetUltimate()
710+
.detailsIf<Fortran::semantics::ObjectEntityDetails>()}) {
711+
if (details->cudaDataAttr() &&
712+
(*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Device ||
713+
*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Managed ||
714+
*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Unified ||
715+
*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Shared ||
716+
*details->cudaDataAttr() == Fortran::common::CUDADataAttr::Pinned))
717+
return true;
718+
const Fortran::semantics::DeclTypeSpec *type{details->type()};
719+
const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived()
720+
: nullptr};
721+
if (derived)
722+
if (FindCUDADeviceAllocatableUltimateComponent(*derived))
723+
return true;
724+
}
725+
return false;
726+
}
727+
705728
//===----------------------------------------------------------------===//
706729
// Local variables instantiation (not for alias)
707730
//===----------------------------------------------------------------===//
@@ -732,7 +755,7 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
732755
if (ultimateSymbol.test(Fortran::semantics::Symbol::Flag::CrayPointee))
733756
return builder.create<fir::ZeroOp>(loc, fir::ReferenceType::get(ty));
734757

735-
if (Fortran::semantics::NeedCUDAAlloc(ultimateSymbol)) {
758+
if (needCUDAAlloc(ultimateSymbol)) {
736759
cuf::DataAttributeAttr dataAttr =
737760
Fortran::lower::translateSymbolCUFDataAttribute(builder.getContext(),
738761
ultimateSymbol);
@@ -1087,7 +1110,7 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
10871110
Fortran::lower::defaultInitializeAtRuntime(converter, var.getSymbol(),
10881111
symMap);
10891112
auto *builder = &converter.getFirOpBuilder();
1090-
if (Fortran::semantics::NeedCUDAAlloc(var.getSymbol()) &&
1113+
if (needCUDAAlloc(var.getSymbol()) &&
10911114
!cuf::isCUDADeviceContext(builder->getRegion())) {
10921115
cuf::DataAttributeAttr dataAttr =
10931116
Fortran::lower::translateSymbolCUFDataAttribute(builder->getContext(),

flang/lib/Semantics/tools.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1081,12 +1081,6 @@ const Scope *FindCUDADeviceContext(const Scope *scope) {
10811081
});
10821082
}
10831083

1084-
std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *symbol) {
1085-
const auto *object{
1086-
symbol ? symbol->detailsIf<ObjectEntityDetails>() : nullptr};
1087-
return object ? object->cudaDataAttr() : std::nullopt;
1088-
}
1089-
10901084
bool IsDeviceAllocatable(const Symbol &symbol) {
10911085
if (IsAllocatable(symbol)) {
10921086
if (const auto *details{
@@ -1133,6 +1127,23 @@ bool CanCUDASymbolBeGlobal(const Symbol &sym) {
11331127
return true;
11341128
}
11351129

1130+
std::optional<common::CUDADataAttr> GetCUDADataAttr(const Symbol *symbol) {
1131+
const auto *details{
1132+
symbol ? symbol->detailsIf<ObjectEntityDetails>() : nullptr};
1133+
if (details) {
1134+
const Fortran::semantics::DeclTypeSpec *type{details->type()};
1135+
const Fortran::semantics::DerivedTypeSpec *derived{
1136+
type ? type->AsDerived() : nullptr};
1137+
if (derived) {
1138+
if (FindCUDADeviceAllocatableUltimateComponent(*derived)) {
1139+
return common::CUDADataAttr::Managed;
1140+
}
1141+
}
1142+
return details->cudaDataAttr();
1143+
}
1144+
return std::nullopt;
1145+
}
1146+
11361147
bool IsAccessible(const Symbol &original, const Scope &scope) {
11371148
const Symbol &ultimate{original.GetUltimate()};
11381149
if (ultimate.attrs().test(Attr::PRIVATE)) {

flang/test/Lower/CUDA/cuda-derived.cuf

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ module m1
77

88
type t1; real, device, allocatable :: a(:); end type
99
type t2; type(t1) :: b; end type
10+
contains
11+
subroutine sub1()
12+
type(ty_device) :: a
13+
end subroutine
14+
15+
! CHECK-LABEL: func.func @_QMm1Psub1()
16+
! CHECK: %[[ALLOC:.*]] = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} -> !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>
17+
! CHECK: %[[DECL:.*]]:2 = hlfir.declare %[[ALLOC]] {data_attr = #cuf.cuda<managed>, uniq_name = "_QMm1Fsub1Ea"} : (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>) -> (!fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>, !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>>)
18+
! CHECK: cuf.free %[[DECL]]#0 : !fir.ref<!fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}>> {data_attr = #cuf.cuda<managed>}
19+
1020
end module
1121

1222
program main
@@ -16,5 +26,5 @@ program main
1626
end
1727

1828
! CHECK-LABEL: func.func @_QQmain() attributes {fir.bindc_name = "main"}
19-
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", uniq_name = "_QFEa"}
20-
! CHECK: %{{.*}} = fir.alloca !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box<!fir.heap<!fir.array<?xf32>>>}>}> {bindc_name = "b", uniq_name = "_QFEb"}
29+
! CHECK: %{{.*}} = cuf.alloc !fir.type<_QMm1Tty_device{x:!fir.box<!fir.heap<!fir.array<?xi32>>>}> {bindc_name = "a", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEa"}
30+
! CHECK: %{{.*}} = cuf.alloc !fir.type<_QMm1Tt2{b:!fir.type<_QMm1Tt1{a:!fir.box<!fir.heap<!fir.array<?xf32>>>}>}> {bindc_name = "b", data_attr = #cuf.cuda<managed>, uniq_name = "_QFEb"}

0 commit comments

Comments
 (0)