Skip to content

Commit c13ebb5

Browse files
Fix bug in gpu.memcpy lowering for dynamically shaped operands. (#128820)
Compute the number of elements to be copied by multiplying dim sizes along all the dimensions.
1 parent 9db72e5 commit c13ebb5

File tree

2 files changed

+30
-8
lines changed

2 files changed

+30
-8
lines changed

mlir/lib/Conversion/GPUCommon/GPUToLLVMConversion.cpp

Lines changed: 10 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -76,14 +76,16 @@ class ConvertOpToGpuRuntimeCallPattern : public ConvertOpToLLVMPattern<OpTy> {
7676
Value getNumElements(ConversionPatternRewriter &rewriter, Location loc,
7777
MemRefType type, MemRefDescriptor desc) const {
7878
Type indexType = ConvertToLLVMPattern::getIndexType();
79-
return type.hasStaticShape()
80-
? ConvertToLLVMPattern::createIndexAttrConstant(
81-
rewriter, loc, indexType, type.getNumElements())
82-
// For identity maps (verified by caller), the number of
83-
// elements is stride[0] * size[0].
84-
: rewriter.create<LLVM::MulOp>(loc,
85-
desc.stride(rewriter, loc, 0),
86-
desc.size(rewriter, loc, 0));
79+
if (type.hasStaticShape())
80+
return ConvertToLLVMPattern::createIndexAttrConstant(
81+
rewriter, loc, indexType, type.getNumElements());
82+
// Compute the number of elements by multiplying all the dim sizes.
83+
uint64_t rank = type.getRank();
84+
Value numElements = desc.size(rewriter, loc, /*pos=*/0);
85+
for (unsigned i = 1; i < rank; i++)
86+
numElements = rewriter.create<LLVM::MulOp>(
87+
loc, numElements, desc.size(rewriter, loc, /*pos=*/i));
88+
return numElements;
8789
}
8890

8991
MLIRContext *context = &this->getTypeConverter()->getContext();

mlir/test/Conversion/GPUCommon/lower-memcpy-to-gpu-runtime-calls.mlir

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -17,3 +17,23 @@ module attributes {gpu.container_module} {
1717
return
1818
}
1919
}
20+
21+
// -----
22+
23+
module attributes {gpu.container_module} {
24+
25+
// CHECK: func @dynamic
26+
func.func @dynamic(%dst : memref<?x?xf32, 1>, %src : memref<?x?xf32>) {
27+
// CHECK: %[[T0:.*]] = llvm.call @mgpuStreamCreate
28+
%t0 = gpu.wait async
29+
%t1 = gpu.memcpy async [%t0] %dst, %src : memref<?x?xf32, 1>, memref<?x?xf32>
30+
// CHECK: %[[DIM_SIZE_0:.*]] = llvm.extractvalue %{{.*}}[3, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
31+
// CHECK-NEXT: %[[DIM_SIZE_1:.*]] = llvm.extractvalue %{{.*}}[3, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
32+
// CHECK: %[[NUM_ELEMENTS:.*]] = llvm.mul %[[DIM_SIZE_0]], %[[DIM_SIZE_1]] : i64
33+
// CHECK: %[[SIZE_PTR:.*]] = llvm.getelementptr %{{.*}}[%[[NUM_ELEMENTS]]] : (!llvm.ptr, i64) -> !llvm.ptr, f32
34+
// CHECK-NEXT: %[[SIZE_INT:.*]] = llvm.ptrtoint %[[SIZE_PTR]] : !llvm.ptr to i64
35+
// CHECK: %[[ADDR_CAST:.*]] = llvm.addrspacecast
36+
// CHECK: llvm.call @mgpuMemcpy(%[[ADDR_CAST]], %{{.*}}, %[[SIZE_INT]], %[[T0]])
37+
return
38+
}
39+
}

0 commit comments

Comments
 (0)