Skip to content

Commit 9f83c4e

Browse files
authored
[flang][cuda] Allocate descriptor in managed memory on rebox block argument (llvm#123971)
Another case where the descriptor must be allocated with the CUF runtime and not a simple alloca instruction.
1 parent afcbcae commit 9f83c4e

File tree

2 files changed

+31
-18
lines changed

2 files changed

+31
-18
lines changed

flang/lib/Optimizer/CodeGen/CodeGen.cpp

Lines changed: 20 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2040,19 +2040,20 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
20402040
getBaseAddrFromBox(loc, inputBoxTyPair, loweredBox, rewriter);
20412041

20422042
if (!rebox.getSlice().empty() || !rebox.getSubcomponent().empty())
2043-
return sliceBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
2044-
operands, rewriter);
2045-
return reshapeBox(rebox, boxTy, dest, baseAddr, inputExtents, inputStrides,
2046-
operands, rewriter);
2043+
return sliceBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
2044+
inputStrides, operands, rewriter);
2045+
return reshapeBox(rebox, adaptor, boxTy, dest, baseAddr, inputExtents,
2046+
inputStrides, operands, rewriter);
20472047
}
20482048

20492049
private:
20502050
/// Write resulting shape and base address in descriptor, and replace rebox
20512051
/// op.
20522052
llvm::LogicalResult
2053-
finalizeRebox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
2054-
mlir::Value base, mlir::ValueRange lbounds,
2055-
mlir::ValueRange extents, mlir::ValueRange strides,
2053+
finalizeRebox(fir::cg::XReboxOp rebox, OpAdaptor adaptor,
2054+
mlir::Type destBoxTy, mlir::Value dest, mlir::Value base,
2055+
mlir::ValueRange lbounds, mlir::ValueRange extents,
2056+
mlir::ValueRange strides,
20562057
mlir::ConversionPatternRewriter &rewriter) const {
20572058
mlir::Location loc = rebox.getLoc();
20582059
mlir::Value zero =
@@ -2075,15 +2076,15 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
20752076
dest = insertBaseAddress(rewriter, loc, dest, base);
20762077
mlir::Value result = placeInMemoryIfNotGlobalInit(
20772078
rewriter, rebox.getLoc(), destBoxTy, dest,
2078-
isDeviceAllocation(rebox.getBox(), rebox.getBox()));
2079+
isDeviceAllocation(rebox.getBox(), adaptor.getBox()));
20792080
rewriter.replaceOp(rebox, result);
20802081
return mlir::success();
20812082
}
20822083

20832084
// Apply slice given the base address, extents and strides of the input box.
20842085
llvm::LogicalResult
2085-
sliceBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
2086-
mlir::Value base, mlir::ValueRange inputExtents,
2086+
sliceBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
2087+
mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
20872088
mlir::ValueRange inputStrides, mlir::ValueRange operands,
20882089
mlir::ConversionPatternRewriter &rewriter) const {
20892090
mlir::Location loc = rebox.getLoc();
@@ -2109,7 +2110,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21092110
if (rebox.getSlice().empty())
21102111
// The array section is of the form array[%component][substring], keep
21112112
// the input array extents and strides.
2112-
return finalizeRebox(rebox, destBoxTy, dest, base,
2113+
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
21132114
/*lbounds*/ std::nullopt, inputExtents, inputStrides,
21142115
rewriter);
21152116

@@ -2158,15 +2159,16 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21582159
slicedStrides.emplace_back(stride);
21592160
}
21602161
}
2161-
return finalizeRebox(rebox, destBoxTy, dest, base, /*lbounds*/ std::nullopt,
2162-
slicedExtents, slicedStrides, rewriter);
2162+
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base,
2163+
/*lbounds*/ std::nullopt, slicedExtents, slicedStrides,
2164+
rewriter);
21632165
}
21642166

21652167
/// Apply a new shape to the data described by a box given the base address,
21662168
/// extents and strides of the box.
21672169
llvm::LogicalResult
2168-
reshapeBox(fir::cg::XReboxOp rebox, mlir::Type destBoxTy, mlir::Value dest,
2169-
mlir::Value base, mlir::ValueRange inputExtents,
2170+
reshapeBox(fir::cg::XReboxOp rebox, OpAdaptor adaptor, mlir::Type destBoxTy,
2171+
mlir::Value dest, mlir::Value base, mlir::ValueRange inputExtents,
21702172
mlir::ValueRange inputStrides, mlir::ValueRange operands,
21712173
mlir::ConversionPatternRewriter &rewriter) const {
21722174
mlir::ValueRange reboxShifts{
@@ -2175,7 +2177,7 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21752177
rebox.getShift().size()};
21762178
if (rebox.getShape().empty()) {
21772179
// Only setting new lower bounds.
2178-
return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts,
2180+
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
21792181
inputExtents, inputStrides, rewriter);
21802182
}
21812183

@@ -2199,8 +2201,8 @@ struct XReboxOpConversion : public EmboxCommonConversion<fir::cg::XReboxOp> {
21992201
// nextStride = extent * stride;
22002202
stride = rewriter.create<mlir::LLVM::MulOp>(loc, idxTy, extent, stride);
22012203
}
2202-
return finalizeRebox(rebox, destBoxTy, dest, base, reboxShifts, newExtents,
2203-
newStrides, rewriter);
2204+
return finalizeRebox(rebox, adaptor, destBoxTy, dest, base, reboxShifts,
2205+
newExtents, newStrides, rewriter);
22042206
}
22052207

22062208
/// Return scalar element type of the input box.

flang/test/Fir/CUDA/cuda-code-gen.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,3 +187,14 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<!llvm.ptr<270> = dense<32> : vec
187187

188188
// CHECK-LABEL: llvm.func @_QPouter
189189
// CHECK: _FortranACUFAllocDescriptor
190+
191+
// -----
192+
193+
func.func @_QMm1Psub1(%arg0: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "da"}, %arg1: !fir.box<!fir.array<?xi32>> {cuf.data_attr = #cuf.cuda<device>, fir.bindc_name = "db"}, %arg2: !fir.ref<i32> {fir.bindc_name = "n"}) {
194+
%0 = fircg.ext_rebox %arg0 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
195+
%1 = fircg.ext_rebox %arg1 : (!fir.box<!fir.array<?xi32>>) -> !fir.box<!fir.array<?xi32>>
196+
return
197+
}
198+
199+
// CHECK-LABEL: llvm.func @_QMm1Psub1
200+
// CHECK-COUNT-2: _FortranACUFAllocDescriptor

0 commit comments

Comments
 (0)