Skip to content

Commit 7d76d9b

Browse files
committed
[NVPTX] support generic LDG/LDU for packed data types
Support ld.global.nc.b64/ldu.global.b64 for v2f32 and ld.global.nc.b32/ldu.global.b32 for v2f16/v2bf16/v2i16/v4i8. Update test cases.
1 parent 00f5b69 commit 7d76d9b

File tree

6 files changed

+106
-50
lines changed

6 files changed

+106
-50
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1276,6 +1276,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12761276
EVT OrigType = N->getValueType(0);
12771277
EVT EltVT = Mem->getMemoryVT();
12781278
unsigned NumElts = 1;
1279+
1280+
std::optional<unsigned> Opcode;
1281+
12791282
if (EltVT.isVector()) {
12801283
NumElts = EltVT.getVectorNumElements();
12811284
EltVT = EltVT.getVectorElementType();
@@ -1288,6 +1291,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12881291
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
12891292
assert(NumElts % OrigType.getVectorNumElements() == 0 &&
12901293
"NumElts must be divisible by the number of elts in subvectors");
1294+
if (N->getOpcode() == ISD::LOAD ||
1295+
N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
1296+
switch (OrigType.getSimpleVT().SimpleTy) {
1297+
case MVT::v2f32:
1298+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b64
1299+
: NVPTX::INT_PTX_LDU_GLOBAL_b64;
1300+
break;
1301+
case MVT::v2f16:
1302+
case MVT::v2bf16:
1303+
case MVT::v2i16:
1304+
case MVT::v4i8:
1305+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b32
1306+
: NVPTX::INT_PTX_LDU_GLOBAL_b32;
1307+
break;
1308+
default:
1309+
llvm_unreachable("Unhandled packed vector type");
1310+
}
1311+
}
12911312
EltVT = OrigType;
12921313
NumElts /= OrigType.getVectorNumElements();
12931314
}
@@ -1309,50 +1330,51 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13091330
SelectADDR(Op1, Base, Offset);
13101331
SDValue Ops[] = {Base, Offset, Chain};
13111332

1312-
std::optional<unsigned> Opcode;
1313-
switch (N->getOpcode()) {
1314-
default:
1315-
return false;
1316-
case ISD::LOAD:
1317-
Opcode = pickOpcodeForVT(
1318-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1319-
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1320-
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1321-
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1322-
break;
1323-
case ISD::INTRINSIC_W_CHAIN:
1324-
Opcode = pickOpcodeForVT(
1325-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1326-
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1327-
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1328-
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1329-
break;
1330-
case NVPTXISD::LoadV2:
1331-
Opcode = pickOpcodeForVT(
1332-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1333-
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1334-
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1335-
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1336-
break;
1337-
case NVPTXISD::LDUV2:
1338-
Opcode = pickOpcodeForVT(
1339-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1340-
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1341-
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1342-
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1343-
break;
1344-
case NVPTXISD::LoadV4:
1345-
Opcode = pickOpcodeForVT(
1346-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1347-
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1348-
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1349-
break;
1350-
case NVPTXISD::LDUV4:
1351-
Opcode = pickOpcodeForVT(
1352-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1353-
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1354-
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1355-
break;
1333+
if (!Opcode) {
1334+
switch (N->getOpcode()) {
1335+
default:
1336+
return false;
1337+
case ISD::LOAD:
1338+
Opcode = pickOpcodeForVT(
1339+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1340+
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1341+
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1342+
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1343+
break;
1344+
case ISD::INTRINSIC_W_CHAIN:
1345+
Opcode = pickOpcodeForVT(
1346+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1347+
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1348+
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1349+
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1350+
break;
1351+
case NVPTXISD::LoadV2:
1352+
Opcode = pickOpcodeForVT(
1353+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1354+
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1355+
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1356+
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1357+
break;
1358+
case NVPTXISD::LDUV2:
1359+
Opcode = pickOpcodeForVT(
1360+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1361+
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1362+
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1363+
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1364+
break;
1365+
case NVPTXISD::LoadV4:
1366+
Opcode = pickOpcodeForVT(
1367+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1368+
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1369+
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1370+
break;
1371+
case NVPTXISD::LDUV4:
1372+
Opcode = pickOpcodeForVT(
1373+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1374+
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1375+
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1376+
break;
1377+
}
13561378
}
13571379
if (!Opcode)
13581380
return false;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,7 +2702,9 @@ class LDU_G<string TyStr, NVPTXRegClass regclass>
27022702
def INT_PTX_LDU_GLOBAL_i8 : LDU_G<"u8", Int16Regs>;
27032703
def INT_PTX_LDU_GLOBAL_i16 : LDU_G<"u16", Int16Regs>;
27042704
def INT_PTX_LDU_GLOBAL_i32 : LDU_G<"u32", Int32Regs>;
2705+
def INT_PTX_LDU_GLOBAL_b32 : LDU_G<"b32", Int32Regs>;
27052706
def INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64", Int64Regs>;
2707+
def INT_PTX_LDU_GLOBAL_b64 : LDU_G<"b64", Int64Regs>;
27062708
def INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32", Float32Regs>;
27072709
def INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64", Float64Regs>;
27082710

@@ -2752,7 +2754,9 @@ class LDG_G<string TyStr, NVPTXRegClass regclass>
27522754
def INT_PTX_LDG_GLOBAL_i8 : LDG_G<"u8", Int16Regs>;
27532755
def INT_PTX_LDG_GLOBAL_i16 : LDG_G<"u16", Int16Regs>;
27542756
def INT_PTX_LDG_GLOBAL_i32 : LDG_G<"u32", Int32Regs>;
2757+
def INT_PTX_LDG_GLOBAL_b32 : LDG_G<"b32", Int32Regs>;
27552758
def INT_PTX_LDG_GLOBAL_i64 : LDG_G<"u64", Int64Regs>;
2759+
def INT_PTX_LDG_GLOBAL_b64 : LDG_G<"b64", Int64Regs>;
27562760
def INT_PTX_LDG_GLOBAL_f32 : LDG_G<"f32", Float32Regs>;
27572761
def INT_PTX_LDG_GLOBAL_f64 : LDG_G<"f64", Float64Regs>;
27582762

llvm/test/CodeGen/NVPTX/ldg-invariant.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ define half @ld_global_v2f16(ptr addrspace(1) %ptr) {
3232
; CHECK-EMPTY:
3333
; CHECK-NEXT: // %bb.0:
3434
; CHECK-NEXT: ld.param.u64 %rd1, [ld_global_v2f16_param_0];
35-
; CHECK-NEXT: ld.global.nc.u32 %r1, [%rd1];
35+
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
3636
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r1;
3737
; CHECK-NEXT: cvt.f32.f16 %f1, %rs2;
3838
; CHECK-NEXT: cvt.f32.f16 %f2, %rs1;

llvm/test/CodeGen/NVPTX/ldu-ldg.ll

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ declare float @llvm.nvvm.ldu.global.f.f32.p1(ptr addrspace(1) %ptr, i32 %align)
1212
declare double @llvm.nvvm.ldu.global.f.f64.p1(ptr addrspace(1) %ptr, i32 %align)
1313
declare half @llvm.nvvm.ldu.global.f.f16.p1(ptr addrspace(1) %ptr, i32 %align)
1414
declare <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 %align)
15+
declare <2 x float> @llvm.nvvm.ldu.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 %align)
1516

1617
declare i8 @llvm.nvvm.ldg.global.i.i8.p1(ptr addrspace(1) %ptr, i32 %align)
1718
declare i16 @llvm.nvvm.ldg.global.i.i16.p1(ptr addrspace(1) %ptr, i32 %align)
@@ -22,6 +23,7 @@ declare float @llvm.nvvm.ldg.global.f.f32.p1(ptr addrspace(1) %ptr, i32 %align)
2223
declare double @llvm.nvvm.ldg.global.f.f64.p1(ptr addrspace(1) %ptr, i32 %align)
2324
declare half @llvm.nvvm.ldg.global.f.f16.p1(ptr addrspace(1) %ptr, i32 %align)
2425
declare <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 %align)
26+
declare <2 x float> @llvm.nvvm.ldg.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 %align)
2527

2628
define i8 @test_ldu_i8(ptr addrspace(1) %ptr) {
2729
; CHECK-LABEL: test_ldu_i8(
@@ -154,13 +156,27 @@ define <2 x half> @test_ldu_v2f16(ptr addrspace(1) %ptr) {
154156
; CHECK-EMPTY:
155157
; CHECK-NEXT: // %bb.0:
156158
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldu_v2f16_param_0];
157-
; CHECK-NEXT: ldu.global.u32 %r1, [%rd1];
159+
; CHECK-NEXT: ldu.global.b32 %r1, [%rd1];
158160
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
159161
; CHECK-NEXT: ret;
160162
%val = tail call <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 4)
161163
ret <2 x half> %val
162164
}
163165

166+
define <2 x float> @test_ldu_v2f32(ptr addrspace(1) %ptr) {
167+
; CHECK-LABEL: test_ldu_v2f32(
168+
; CHECK: {
169+
; CHECK-NEXT: .reg .b64 %rd<3>;
170+
; CHECK-EMPTY:
171+
; CHECK-NEXT: // %bb.0:
172+
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldu_v2f32_param_0];
173+
; CHECK-NEXT: ldu.global.b64 %rd2, [%rd1];
174+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
175+
; CHECK-NEXT: ret;
176+
%val = tail call <2 x float> @llvm.nvvm.ldu.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 8)
177+
ret <2 x float> %val
178+
}
179+
164180
define i8 @test_ldg_i8(ptr addrspace(1) %ptr) {
165181
; CHECK-LABEL: test_ldg_i8(
166182
; CHECK: {
@@ -291,13 +307,27 @@ define <2 x half> @test_ldg_v2f16(ptr addrspace(1) %ptr) {
291307
; CHECK-EMPTY:
292308
; CHECK-NEXT: // %bb.0:
293309
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldg_v2f16_param_0];
294-
; CHECK-NEXT: ld.global.nc.u32 %r1, [%rd1];
310+
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
295311
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
296312
; CHECK-NEXT: ret;
297313
%val = tail call <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 4)
298314
ret <2 x half> %val
299315
}
300316

317+
define <2 x float> @test_ldg_v2f32(ptr addrspace(1) %ptr) {
318+
; CHECK-LABEL: test_ldg_v2f32(
319+
; CHECK: {
320+
; CHECK-NEXT: .reg .b64 %rd<3>;
321+
; CHECK-EMPTY:
322+
; CHECK-NEXT: // %bb.0:
323+
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldg_v2f32_param_0];
324+
; CHECK-NEXT: ld.global.nc.b64 %rd2, [%rd1];
325+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
326+
; CHECK-NEXT: ret;
327+
%val = tail call <2 x float> @llvm.nvvm.ldg.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 8)
328+
ret <2 x float> %val
329+
}
330+
301331
@g = addrspace(1) global i32 0
302332

303333
define i32 @test_ldg_asi() {

llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ define ptx_kernel void @foo7(ptr noalias readonly %from, ptr %to) {
8282
; SM20-LABEL: .visible .entry foo8(
8383
; SM20: ld.global.u32
8484
; SM35-LABEL: .visible .entry foo8(
85-
; SM35: ld.global.nc.u32
85+
; SM35: ld.global.nc.b32
8686
define ptx_kernel void @foo8(ptr noalias readonly %from, ptr %to) {
8787
%1 = load <2 x i16>, ptr %from
8888
store <2 x i16> %1, ptr %to
@@ -132,7 +132,7 @@ define ptx_kernel void @foo12(ptr noalias readonly %from, ptr %to) {
132132
; SM20-LABEL: .visible .entry foo13(
133133
; SM20: ld.global.u32
134134
; SM35-LABEL: .visible .entry foo13(
135-
; SM35: ld.global.nc.u32
135+
; SM35: ld.global.nc.b32
136136
define ptx_kernel void @foo13(ptr noalias readonly %from, ptr %to) {
137137
%1 = load <4 x i8>, ptr %from
138138
store <4 x i8> %1, ptr %to

llvm/test/CodeGen/NVPTX/read-global-variable-constant.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ define float @test_gv_float() {
1717

1818
; CHECK-LABEL: test_gv_float2()
1919
define <2 x float> @test_gv_float2() {
20-
; CHECK: ld.global.nc.v2.f32
20+
; CHECK: ld.global.nc.b64
2121
%v = load <2 x float>, ptr @gv_float2
2222
ret <2 x float> %v
2323
}

0 commit comments

Comments
 (0)