Skip to content

Commit 668c964

Browse files
authored
[AMDGPU] [MLIR] Add 96 and 128 bit GatherToLDS for gfx950 (#147496)
This PR adds support for 96-bit and 128-bit gather_to_lds transfers on gfx950, updating the lowering, the verifier, and the tests accordingly.
1 parent c223521 commit 668c964

File tree

5 files changed

+107
-11
lines changed

5 files changed

+107
-11
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,18 +1196,23 @@ struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
11961196
// augment it to transfer multiple elements per thread by issuing multiple
11971197
// `global_load_lds` instructions.
11981198
Type transferType = op.getTransferType();
1199-
size_t loadWidth = [&]() -> size_t {
1199+
int loadWidth = [&]() -> int {
12001200
if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
1201-
return transferVectorType.getNumElements() *
1202-
(transferVectorType.getElementTypeBitWidth() / 8);
1201+
return (transferVectorType.getNumElements() *
1202+
transferVectorType.getElementTypeBitWidth()) /
1203+
8;
12031204
}
12041205
return transferType.getIntOrFloatBitWidth() / 8;
12051206
}();
12061207

1207-
// Currently only 1, 2, and 4 byte loads are supported.
1208-
if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
1208+
// Currently only 1, 2, 4, 12 and 16 byte loads are supported.
1209+
if (!llvm::is_contained({1, 2, 4, 12, 16}, loadWidth))
12091210
return op.emitOpError("chipset unsupported element size");
12101211

1212+
if (chipset != kGfx950 && llvm::is_contained({12, 16}, loadWidth))
1213+
return op.emitOpError("Gather to LDS instructions with 12-byte and "
1214+
"16-byte load widths are only supported on gfx950");
1215+
12111216
Value srcPtr =
12121217
getStridedElementPtr(rewriter, loc, srcMemRefType, adaptor.getSrc(),
12131218
(adaptor.getSrcIndices()));

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -502,17 +502,18 @@ LogicalResult GatherToLDSOp::verify() {
502502
if (elemType != dstType.getElementType())
503503
return emitOpError("source and destination element types must match");
504504

505-
// copy type sizes should be 1, 2, or 4 bytes.
505+
// copy type sizes should be 1, 2, 4, 12 or 16 bytes.
506506
auto transferType = getTransferType();
507-
size_t transferSize;
507+
int transferSize;
508508
if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
509509
transferSize = vectorTransfer.getNumElements() *
510510
vectorTransfer.getElementTypeBitWidth();
511511
} else {
512512
transferSize = transferType.getIntOrFloatBitWidth();
513513
}
514-
if (transferSize != 8 && transferSize != 16 && transferSize != 32)
515-
return emitOpError("Transfering type size must be 8, 16, or 32 bits");
514+
if (!llvm::is_contained({8, 16, 32, 96, 128}, transferSize))
515+
return emitOpError(
516+
"Transfering type size must be 8, 16, 32, 96 or 128 bits");
516517

517518
if (!hasGlobalMemorySpace(srcType.getMemorySpace()) &&
518519
!hasFatRawBufferMemorySpace(srcType.getMemorySpace()))
Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx942 2>&1 | FileCheck %s --check-prefix=GFX942
2+
// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s --check-prefix=GFX950
3+
4+
#gpu_global_addrspace = 1
5+
#gpu_lds_addrspace = 3
6+
#amdgpu_fat_buffer_addrspace = 7
7+
8+
// GFX950-LABEL: func @fat_buffer_load_to_rocdl_f96
9+
// GFX950-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 7>)
10+
func.func @fat_buffer_load_to_rocdl_f96(%global : memref<128x72xf32, #amdgpu_fat_buffer_addrspace>) {
11+
%c0 = arith.constant 0 : index
12+
%c12 = arith.constant 12 : index
13+
%c32 = arith.constant 32 : index
14+
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
15+
// GFX950: %[[BUFFER_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
16+
17+
// GFX950: %[[C0:.*]] = arith.constant 0 : index
18+
// GFX950: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
19+
// GFX950: %[[C12:.*]] = arith.constant 12 : index
20+
// GFX950: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
21+
// GFX950: %[[C32:.*]] = arith.constant 32 : index
22+
// GFX950: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
23+
24+
// GFX950: %[[ALLOC:.*]] = memref.alloc()
25+
// GFX950: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
26+
// GFX950: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[BUFFER_DESC]][1]
27+
28+
// GFX950: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
29+
// GFX950: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
30+
// GFX950: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
31+
32+
// GFX950: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
33+
// GFX950: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
34+
35+
// GFX950: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
36+
// GFX950: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
37+
// GFX950: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
38+
39+
// GFX950: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
40+
// GFX950: rocdl.load.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 12
41+
// GFX942: error: 'amdgpu.gather_to_lds' op Gather to LDS instructions with 12-byte and 16-byte load widths are only supported on gfx950
42+
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
43+
: vector<16xf6E3M2FN>, memref<128x72xf32, #amdgpu_fat_buffer_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
44+
func.return
45+
}
46+
47+
// -----
48+
49+
#gpu_global_addrspace = 1
50+
#gpu_lds_addrspace = 3
51+
#amdgpu_fat_buffer_addrspace = 7
52+
53+
// GFX950-LABEL: func @fat_buffer_load_to_rocdl_f128
54+
// GFX950-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 7>)
55+
func.func @fat_buffer_load_to_rocdl_f128(%global : memref<128x72xf32, #amdgpu_fat_buffer_addrspace>) {
56+
%c0 = arith.constant 0 : index
57+
%c12 = arith.constant 12 : index
58+
%c32 = arith.constant 32 : index
59+
%alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
60+
// GFX950: %[[BUFFER_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
61+
62+
// GFX950: %[[C0:.*]] = arith.constant 0 : index
63+
// GFX950: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
64+
// GFX950: %[[C12:.*]] = arith.constant 12 : index
65+
// GFX950: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
66+
// GFX950: %[[C32:.*]] = arith.constant 32 : index
67+
// GFX950: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
68+
69+
// GFX950: %[[ALLOC:.*]] = memref.alloc()
70+
// GFX950: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
71+
// GFX950: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[BUFFER_DESC]][1]
72+
73+
// GFX950: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
74+
// GFX950: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
75+
// GFX950: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
76+
77+
// GFX950: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
78+
// GFX950: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
79+
80+
// GFX950: %[[C64:.*]] = llvm.mlir.constant(64 : index) : i64
81+
// GFX950: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C64]] : i64
82+
// GFX950: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
83+
84+
// GFX950: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
85+
// GFX950: rocdl.load.to.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], 16
86+
// GFX942: error: 'amdgpu.gather_to_lds' op Gather to LDS instructions with 12-byte and 16-byte load widths are only supported on gfx950
87+
amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
88+
: f128, memref<128x72xf32, #amdgpu_fat_buffer_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
89+
func.return
90+
}

mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s
2+
// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
23

34
#gpu_global_addrspace = 1
45
#gpu_lds_addrspace = 3
@@ -118,7 +119,6 @@ func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_add
118119
func.return
119120
}
120121

121-
122122
// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
123123
// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
124124
func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {

mlir/test/Conversion/AMDGPUToROCDL/transpose_load.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
// RUN: mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx950 | FileCheck %s
2-
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx945 2>&1 | FileCheck %s --check-prefix=CHECK-OLD
2+
// RUN: not mlir-opt %s --split-input-file -convert-amdgpu-to-rocdl=chipset=gfx942 2>&1 | FileCheck %s --check-prefix=CHECK-OLD
33

44
// CHECK-LABEL: func @transpose_load_to_rocdl_4xf16
55
func.func @transpose_load_to_rocdl_4xf16(%idx1 : index, %idx2 : index, %wgmem : memref<128x72xf16, 3>) -> vector<4xf16> {

0 commit comments

Comments
 (0)