@@ -702,6 +702,29 @@ static void instantiateGlobal(Fortran::lower::AbstractConverter &converter,
702
702
mapSymbolAttributes (converter, var, symMap, stmtCtx, cast);
703
703
}
704
704
705
+ bool needCUDAAlloc (const Fortran::semantics::Symbol &sym) {
706
+ if (Fortran::semantics::IsDummy (sym))
707
+ return false ;
708
+ if (const auto *details{
709
+ sym.GetUltimate ()
710
+ .detailsIf <Fortran::semantics::ObjectEntityDetails>()}) {
711
+ if (details->cudaDataAttr () &&
712
+ (*details->cudaDataAttr () == Fortran::common::CUDADataAttr::Device ||
713
+ *details->cudaDataAttr () == Fortran::common::CUDADataAttr::Managed ||
714
+ *details->cudaDataAttr () == Fortran::common::CUDADataAttr::Unified ||
715
+ *details->cudaDataAttr () == Fortran::common::CUDADataAttr::Shared ||
716
+ *details->cudaDataAttr () == Fortran::common::CUDADataAttr::Pinned))
717
+ return true ;
718
+ const Fortran::semantics::DeclTypeSpec *type{details->type ()};
719
+ const Fortran::semantics::DerivedTypeSpec *derived{type ? type->AsDerived ()
720
+ : nullptr };
721
+ if (derived)
722
+ if (FindCUDADeviceAllocatableUltimateComponent (*derived))
723
+ return true ;
724
+ }
725
+ return false ;
726
+ }
727
+
705
728
// ===----------------------------------------------------------------===//
706
729
// Local variables instantiation (not for alias)
707
730
// ===----------------------------------------------------------------===//
@@ -732,7 +755,7 @@ static mlir::Value createNewLocal(Fortran::lower::AbstractConverter &converter,
732
755
if (ultimateSymbol.test (Fortran::semantics::Symbol::Flag::CrayPointee))
733
756
return builder.create <fir::ZeroOp>(loc, fir::ReferenceType::get (ty));
734
757
735
- if (Fortran::semantics::NeedCUDAAlloc (ultimateSymbol)) {
758
+ if (needCUDAAlloc (ultimateSymbol)) {
736
759
cuf::DataAttributeAttr dataAttr =
737
760
Fortran::lower::translateSymbolCUFDataAttribute (builder.getContext (),
738
761
ultimateSymbol);
@@ -1087,7 +1110,7 @@ static void instantiateLocal(Fortran::lower::AbstractConverter &converter,
1087
1110
Fortran::lower::defaultInitializeAtRuntime (converter, var.getSymbol (),
1088
1111
symMap);
1089
1112
auto *builder = &converter.getFirOpBuilder ();
1090
- if (Fortran::semantics::NeedCUDAAlloc (var.getSymbol ()) &&
1113
+ if (needCUDAAlloc (var.getSymbol ()) &&
1091
1114
!cuf::isCUDADeviceContext (builder->getRegion ())) {
1092
1115
cuf::DataAttributeAttr dataAttr =
1093
1116
Fortran::lower::translateSymbolCUFDataAttribute (builder->getContext (),
0 commit comments