Skip to content

Commit bfdca9c

Browse files
committed
[NVPTX] support generic LDG/LDU for packed data types
Support ld.global.nc.b64/ldu.global.b64 for v2f32 and ld.global.nc.b32/ldu.global.b32 for v2f16/v2bf16/v2i16/v4i8 Update test cases.
1 parent 99938fa commit bfdca9c

File tree

6 files changed

+106
-50
lines changed

6 files changed

+106
-50
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1290,6 +1290,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12901290
EVT OrigType = N->getValueType(0);
12911291
EVT EltVT = Mem->getMemoryVT();
12921292
unsigned NumElts = 1;
1293+
1294+
std::optional<unsigned> Opcode;
1295+
12931296
if (EltVT.isVector()) {
12941297
NumElts = EltVT.getVectorNumElements();
12951298
EltVT = EltVT.getVectorElementType();
@@ -1302,6 +1305,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13021305
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
13031306
assert(NumElts % OrigType.getVectorNumElements() == 0 &&
13041307
"NumElts must be divisible by the number of elts in subvectors");
1308+
if (N->getOpcode() == ISD::LOAD ||
1309+
N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
1310+
switch (OrigType.getSimpleVT().SimpleTy) {
1311+
case MVT::v2f32:
1312+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b64
1313+
: NVPTX::INT_PTX_LDU_GLOBAL_b64;
1314+
break;
1315+
case MVT::v2f16:
1316+
case MVT::v2bf16:
1317+
case MVT::v2i16:
1318+
case MVT::v4i8:
1319+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b32
1320+
: NVPTX::INT_PTX_LDU_GLOBAL_b32;
1321+
break;
1322+
default:
1323+
llvm_unreachable("Unhandled packed vector type");
1324+
}
1325+
}
13051326
EltVT = OrigType;
13061327
NumElts /= OrigType.getVectorNumElements();
13071328
}
@@ -1323,50 +1344,51 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13231344
SelectADDR(Op1, Base, Offset);
13241345
SDValue Ops[] = {Base, Offset, Chain};
13251346

1326-
std::optional<unsigned> Opcode;
1327-
switch (N->getOpcode()) {
1328-
default:
1329-
return false;
1330-
case ISD::LOAD:
1331-
Opcode = pickOpcodeForVT(
1332-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1333-
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1334-
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1335-
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1336-
break;
1337-
case ISD::INTRINSIC_W_CHAIN:
1338-
Opcode = pickOpcodeForVT(
1339-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1340-
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1341-
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1342-
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1343-
break;
1344-
case NVPTXISD::LoadV2:
1345-
Opcode = pickOpcodeForVT(
1346-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1347-
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1348-
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1349-
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1350-
break;
1351-
case NVPTXISD::LDUV2:
1352-
Opcode = pickOpcodeForVT(
1353-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1354-
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1355-
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1356-
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1357-
break;
1358-
case NVPTXISD::LoadV4:
1359-
Opcode = pickOpcodeForVT(
1360-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1361-
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1362-
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1363-
break;
1364-
case NVPTXISD::LDUV4:
1365-
Opcode = pickOpcodeForVT(
1366-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1367-
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1368-
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1369-
break;
1347+
if (!Opcode) {
1348+
switch (N->getOpcode()) {
1349+
default:
1350+
return false;
1351+
case ISD::LOAD:
1352+
Opcode = pickOpcodeForVT(
1353+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1354+
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1355+
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1356+
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1357+
break;
1358+
case ISD::INTRINSIC_W_CHAIN:
1359+
Opcode = pickOpcodeForVT(
1360+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1361+
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1362+
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1363+
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1364+
break;
1365+
case NVPTXISD::LoadV2:
1366+
Opcode = pickOpcodeForVT(
1367+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1368+
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1369+
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1370+
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1371+
break;
1372+
case NVPTXISD::LDUV2:
1373+
Opcode = pickOpcodeForVT(
1374+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1375+
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1376+
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1377+
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1378+
break;
1379+
case NVPTXISD::LoadV4:
1380+
Opcode = pickOpcodeForVT(
1381+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1382+
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1383+
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1384+
break;
1385+
case NVPTXISD::LDUV4:
1386+
Opcode = pickOpcodeForVT(
1387+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1388+
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1389+
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1390+
break;
1391+
}
13701392
}
13711393
if (!Opcode)
13721394
return false;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2239,7 +2239,9 @@ class LDU_G<string TyStr, NVPTXRegClass regclass>
22392239
def INT_PTX_LDU_GLOBAL_i8 : LDU_G<"u8", Int16Regs>;
22402240
def INT_PTX_LDU_GLOBAL_i16 : LDU_G<"u16", Int16Regs>;
22412241
def INT_PTX_LDU_GLOBAL_i32 : LDU_G<"u32", Int32Regs>;
2242+
def INT_PTX_LDU_GLOBAL_b32 : LDU_G<"b32", Int32Regs>;
22422243
def INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64", Int64Regs>;
2244+
def INT_PTX_LDU_GLOBAL_b64 : LDU_G<"b64", Int64Regs>;
22432245
def INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32", Float32Regs>;
22442246
def INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64", Float64Regs>;
22452247

@@ -2289,7 +2291,9 @@ class LDG_G<string TyStr, NVPTXRegClass regclass>
22892291
def INT_PTX_LDG_GLOBAL_i8 : LDG_G<"u8", Int16Regs>;
22902292
def INT_PTX_LDG_GLOBAL_i16 : LDG_G<"u16", Int16Regs>;
22912293
def INT_PTX_LDG_GLOBAL_i32 : LDG_G<"u32", Int32Regs>;
2294+
def INT_PTX_LDG_GLOBAL_b32 : LDG_G<"b32", Int32Regs>;
22922295
def INT_PTX_LDG_GLOBAL_i64 : LDG_G<"u64", Int64Regs>;
2296+
def INT_PTX_LDG_GLOBAL_b64 : LDG_G<"b64", Int64Regs>;
22932297
def INT_PTX_LDG_GLOBAL_f32 : LDG_G<"f32", Float32Regs>;
22942298
def INT_PTX_LDG_GLOBAL_f64 : LDG_G<"f64", Float64Regs>;
22952299

llvm/test/CodeGen/NVPTX/ldg-invariant.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ define half @ld_global_v2f16(ptr addrspace(1) %ptr) {
3232
; CHECK-EMPTY:
3333
; CHECK-NEXT: // %bb.0:
3434
; CHECK-NEXT: ld.param.u64 %rd1, [ld_global_v2f16_param_0];
35-
; CHECK-NEXT: ld.global.nc.u32 %r1, [%rd1];
35+
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
3636
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r1;
3737
; CHECK-NEXT: cvt.f32.f16 %f1, %rs2;
3838
; CHECK-NEXT: cvt.f32.f16 %f2, %rs1;

llvm/test/CodeGen/NVPTX/ldu-ldg.ll

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ declare float @llvm.nvvm.ldu.global.f.f32.p1(ptr addrspace(1) %ptr, i32 %align)
1212
declare double @llvm.nvvm.ldu.global.f.f64.p1(ptr addrspace(1) %ptr, i32 %align)
1313
declare half @llvm.nvvm.ldu.global.f.f16.p1(ptr addrspace(1) %ptr, i32 %align)
1414
declare <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 %align)
15+
declare <2 x float> @llvm.nvvm.ldu.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 %align)
1516

1617
declare i8 @llvm.nvvm.ldg.global.i.i8.p1(ptr addrspace(1) %ptr, i32 %align)
1718
declare i16 @llvm.nvvm.ldg.global.i.i16.p1(ptr addrspace(1) %ptr, i32 %align)
@@ -22,6 +23,7 @@ declare float @llvm.nvvm.ldg.global.f.f32.p1(ptr addrspace(1) %ptr, i32 %align)
2223
declare double @llvm.nvvm.ldg.global.f.f64.p1(ptr addrspace(1) %ptr, i32 %align)
2324
declare half @llvm.nvvm.ldg.global.f.f16.p1(ptr addrspace(1) %ptr, i32 %align)
2425
declare <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 %align)
26+
declare <2 x float> @llvm.nvvm.ldg.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 %align)
2527

2628
define i8 @test_ldu_i8(ptr addrspace(1) %ptr) {
2729
; CHECK-LABEL: test_ldu_i8(
@@ -154,13 +156,27 @@ define <2 x half> @test_ldu_v2f16(ptr addrspace(1) %ptr) {
154156
; CHECK-EMPTY:
155157
; CHECK-NEXT: // %bb.0:
156158
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldu_v2f16_param_0];
157-
; CHECK-NEXT: ldu.global.u32 %r1, [%rd1];
159+
; CHECK-NEXT: ldu.global.b32 %r1, [%rd1];
158160
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
159161
; CHECK-NEXT: ret;
160162
%val = tail call <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 4)
161163
ret <2 x half> %val
162164
}
163165

166+
define <2 x float> @test_ldu_v2f32(ptr addrspace(1) %ptr) {
167+
; CHECK-LABEL: test_ldu_v2f32(
168+
; CHECK: {
169+
; CHECK-NEXT: .reg .b64 %rd<3>;
170+
; CHECK-EMPTY:
171+
; CHECK-NEXT: // %bb.0:
172+
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldu_v2f32_param_0];
173+
; CHECK-NEXT: ldu.global.b64 %rd2, [%rd1];
174+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
175+
; CHECK-NEXT: ret;
176+
%val = tail call <2 x float> @llvm.nvvm.ldu.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 8)
177+
ret <2 x float> %val
178+
}
179+
164180
define i8 @test_ldg_i8(ptr addrspace(1) %ptr) {
165181
; CHECK-LABEL: test_ldg_i8(
166182
; CHECK: {
@@ -291,13 +307,27 @@ define <2 x half> @test_ldg_v2f16(ptr addrspace(1) %ptr) {
291307
; CHECK-EMPTY:
292308
; CHECK-NEXT: // %bb.0:
293309
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldg_v2f16_param_0];
294-
; CHECK-NEXT: ld.global.nc.u32 %r1, [%rd1];
310+
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
295311
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
296312
; CHECK-NEXT: ret;
297313
%val = tail call <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 4)
298314
ret <2 x half> %val
299315
}
300316

317+
define <2 x float> @test_ldg_v2f32(ptr addrspace(1) %ptr) {
318+
; CHECK-LABEL: test_ldg_v2f32(
319+
; CHECK: {
320+
; CHECK-NEXT: .reg .b64 %rd<3>;
321+
; CHECK-EMPTY:
322+
; CHECK-NEXT: // %bb.0:
323+
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldg_v2f32_param_0];
324+
; CHECK-NEXT: ld.global.nc.b64 %rd2, [%rd1];
325+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
326+
; CHECK-NEXT: ret;
327+
%val = tail call <2 x float> @llvm.nvvm.ldg.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 8)
328+
ret <2 x float> %val
329+
}
330+
301331
@g = addrspace(1) global i32 0
302332

303333
define i32 @test_ldg_asi() {

llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ define ptx_kernel void @foo7(ptr noalias readonly %from, ptr %to) {
8282
; SM20-LABEL: .visible .entry foo8(
8383
; SM20: ld.global.u32
8484
; SM35-LABEL: .visible .entry foo8(
85-
; SM35: ld.global.nc.u32
85+
; SM35: ld.global.nc.b32
8686
define ptx_kernel void @foo8(ptr noalias readonly %from, ptr %to) {
8787
%1 = load <2 x i16>, ptr %from
8888
store <2 x i16> %1, ptr %to
@@ -132,7 +132,7 @@ define ptx_kernel void @foo12(ptr noalias readonly %from, ptr %to) {
132132
; SM20-LABEL: .visible .entry foo13(
133133
; SM20: ld.global.u32
134134
; SM35-LABEL: .visible .entry foo13(
135-
; SM35: ld.global.nc.u32
135+
; SM35: ld.global.nc.b32
136136
define ptx_kernel void @foo13(ptr noalias readonly %from, ptr %to) {
137137
%1 = load <4 x i8>, ptr %from
138138
store <4 x i8> %1, ptr %to

llvm/test/CodeGen/NVPTX/read-global-variable-constant.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ define float @test_gv_float() {
1717

1818
; CHECK-LABEL: test_gv_float2()
1919
define <2 x float> @test_gv_float2() {
20-
; CHECK: ld.global.nc.v2.f32
20+
; CHECK: ld.global.nc.b64
2121
%v = load <2 x float>, ptr @gv_float2
2222
ret <2 x float> %v
2323
}

0 commit comments

Comments
 (0)