Skip to content

Commit c26e1a2

Browse files
authored
[flang][cuda] Allocate descriptor in managed memory when memref is a block argument (llvm#123829)
1 parent e45de3d commit c26e1a2

File tree

2 files changed

+46
-8
lines changed

2 files changed

+46
-8
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 29 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -1725,15 +1725,35 @@ struct EmboxOpConversion : public EmboxCommonConversion<fir::EmboxOp> {
17251725
}
17261726
};
17271727

1728-
static bool isDeviceAllocation(mlir::Value val) {
1728+
static bool isDeviceAllocation(mlir::Value val, mlir::Value adaptorVal) {
17291729
if (auto loadOp = mlir::dyn_cast_or_null<fir::LoadOp>(val.getDefiningOp()))
1730-
return isDeviceAllocation(loadOp.getMemref());
1730+
return isDeviceAllocation(loadOp.getMemref(), {});
17311731
if (auto boxAddrOp =
17321732
mlir::dyn_cast_or_null<fir::BoxAddrOp>(val.getDefiningOp()))
1733-
return isDeviceAllocation(boxAddrOp.getVal());
1733+
return isDeviceAllocation(boxAddrOp.getVal(), {});
17341734
if (auto convertOp =
17351735
mlir::dyn_cast_or_null<fir::ConvertOp>(val.getDefiningOp()))
1736-
return isDeviceAllocation(convertOp.getValue());
1736+
return isDeviceAllocation(convertOp.getValue(), {});
1737+
if (!val.getDefiningOp() && adaptorVal) {
1738+
if (auto blockArg = llvm::cast<mlir::BlockArgument>(adaptorVal)) {
1739+
if (blockArg.getOwner() && blockArg.getOwner()->getParentOp() &&
1740+
blockArg.getOwner()->isEntryBlock()) {
1741+
if (auto func = mlir::dyn_cast_or_null<mlir::FunctionOpInterface>(
1742+
*blockArg.getOwner()->getParentOp())) {
1743+
auto argAttrs = func.getArgAttrs(blockArg.getArgNumber());
1744+
for (auto attr : argAttrs) {
1745+
if (attr.getName().getValue().ends_with(cuf::getDataAttrName())) {
1746+
auto dataAttr =
1747+
mlir::dyn_cast<cuf::DataAttributeAttr>(attr.getValue());
1748+
if (dataAttr.getValue() != cuf::DataAttribute::Pinned &&
1749+
dataAttr.getValue() != cuf::DataAttribute::Unified)
1750+
return true;
1751+
}
1752+
}
1753+
}
1754+
}
1755+
}
1756+
}
17371757
if (auto callOp = mlir::dyn_cast_or_null<fir::CallOp>(val.getDefiningOp()))
17381758
if (callOp.getCallee() &&
17391759
(callOp.getCallee().value().getRootReference().getValue().starts_with(
@@ -1928,7 +1948,8 @@ struct XEmboxOpConversion : public EmboxCommonConversion<fir::cg::XEmboxOp> {
19281948
if (fir::isDerivedTypeWithLenParams(boxTy))
19291949
TODO(loc, "fir.embox codegen of derived with length parameters");
19301950
mlir::Value result = placeInMemoryIfNotGlobalInit(
1931-
rewriter, loc, boxTy, dest, isDeviceAllocation(xbox.getMemref()));
1951+
rewriter, loc, boxTy, dest,
1952+
isDeviceAllocation(xbox.getMemref(), adaptor.getMemref()));
19321953
rewriter.replaceOp(xbox, result);
19331954
return mlir::success();
19341955
}
@@ -2052,9 +2073,9 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
20522073
dest = insertStride(rewriter, loc, dest, dim, std::get<1>(iter.value()));
20532074
}
20542075
dest = insertBaseAddress(rewriter, loc, dest, base);
2055-
mlir::Value result =
2056-
placeInMemoryIfNotGlobalInit(rewriter, rebox.getLoc(), destBoxTy, dest,
2057-
isDeviceAllocation(rebox.getBox()));
2076+
mlir::Value result = placeInMemoryIfNotGlobalInit(
2077+
rewriter, rebox.getLoc(), destBoxTy, dest,
2078+
isDeviceAllocation(rebox.getBox(), rebox.getBox()));
20582079
rewriter.replaceOp(rebox, result);
20592080
return mlir::success();
20602081
}

flang/test/Fir/CUDA/cuda-code-gen.mlir

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -170,3 +170,20 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
170170

171171
// CHECK-LABEL: llvm.func @_QQmain()
172172
// CHECK-COUNT-3: llvm.call @_FortranACUFAllocDescriptor
173+
174+
// -----
175+
176+
module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vector<4xi64>, f128 = dense<128> : vector<2xi64>, f64 = dense<64> : vector<2xi64>, f16 = dense<16> : vector<2xi64>, i32 = dense<32> : vector<2xi64>, i64 = dense<64> : vector<2xi64>, !llvm.ptr<272> = dense<64> : vector<4xi64>, !llvm.ptr<271> = dense<32> : vector<4xi64>, f80 = dense<128> : vector<2xi64>, i128 = dense<128> : vector<2xi64>, i16 = dense<16> : vector<2xi64>, i8 = dense<8> : vector<2xi64>, !llvm.ptr = dense<64> : vector<4xi64>, i1 = dense<8> : vector<2xi64>, "dlti.endianness" = "little", "dlti.stack_alignment" = 128 : i64>, fir.defaultkind = "a1c4d8i4l4r4", fir.kindmap = "", gpu.container_module, llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-i128:128-f80:128-n8:16:32:64-S128", llvm.ident = "flang version 20.0.0 (git@github.com:clementval/llvm-project.git efc2415bcce8e8a9e73e77aa122c8aba1c1fbbd2)", llvm.target_triple = "x86_64-unknown-linux-gnu"} {
177+
func.func @_QPouter(%arg0: !fir.ref<!fir.array<100x100xf64>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "a"}) {
178+
%c0_i32 = arith.constant 0 : i32
179+
%c100 = arith.constant 100 : index
180+
%0 = fir.alloca tuple<!fir.box<!fir.array<100x100xf64>>>
181+
%1 = fir.coordinate_of %0, %c0_i32 : (!fir.ref<tuple<!fir.box<!fir.array<100x100xf64>>>>, i32) -> !fir.ref<!fir.box<!fir.array<100x100xf64>>>
182+
%2 = fircg.ext_embox %arg0(%c100, %c100) : (!fir.ref<!fir.array<100x100xf64>>, index, index) -> !fir.box<!fir.array<100x100xf64>>
183+
fir.store %2 to %1 : !fir.ref<!fir.box<!fir.array<100x100xf64>>>
184+
return
185+
}
186+
}
187+
188+
// CHECK-LABEL: llvm.func @_QPouter
189+
// CHECK: _FortranACUFAllocDescriptor

0 commit comments

Comments
 (0)