Skip to content

Commit 600af45

Browse files
committed
[NVPTX] support generic LDG/LDU for packed data types
Support ld.global.nc.b64/ldu.global.b64 for v2f32 and ld.global.nc.b32/ldu.global.b32 for v2f16/v2bf16/v2i16/v4i8. Update test cases.
1 parent 592bd16 commit 600af45

File tree

2 files changed

+105
-53
lines changed

2 files changed

+105
-53
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 75 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1256,18 +1256,39 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12561256
EltVT = MVT::i64;
12571257
NumElts = 2;
12581258
}
1259+
1260+
std::optional<unsigned> Opcode;
1261+
12591262
if (EltVT.isVector()) {
12601263
NumElts = EltVT.getVectorNumElements();
12611264
EltVT = EltVT.getVectorElementType();
1262-
// vectors of 8/16bits type are loaded/stored as multiples of v4i8/v2x16
1263-
// elements.
1265+
// vectors of 8/16/32bits type are loaded/stored as multiples of
1266+
// v4i8/v2x16/v2x32 elements.
12641267
if ((EltVT == MVT::f32 && OrigType == MVT::v2f32) ||
12651268
(EltVT == MVT::f16 && OrigType == MVT::v2f16) ||
12661269
(EltVT == MVT::bf16 && OrigType == MVT::v2bf16) ||
12671270
(EltVT == MVT::i16 && OrigType == MVT::v2i16) ||
12681271
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
12691272
assert(NumElts % OrigType.getVectorNumElements() == 0 &&
12701273
"NumElts must be divisible by the number of elts in subvectors");
1274+
if (N->getOpcode() == ISD::LOAD ||
1275+
N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
1276+
switch (OrigType.getSimpleVT().SimpleTy) {
1277+
case MVT::v2f32:
1278+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_i64
1279+
: NVPTX::INT_PTX_LDU_GLOBAL_i64;
1280+
break;
1281+
case MVT::v2f16:
1282+
case MVT::v2bf16:
1283+
case MVT::v2i16:
1284+
case MVT::v4i8:
1285+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_i32
1286+
: NVPTX::INT_PTX_LDU_GLOBAL_i32;
1287+
break;
1288+
default:
1289+
llvm_unreachable("Unhandled packed vector type");
1290+
}
1291+
}
12711292
EltVT = OrigType;
12721293
NumElts /= OrigType.getVectorNumElements();
12731294
}
@@ -1287,57 +1308,58 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12871308
SelectADDR(Op1, Base, Offset);
12881309
SDValue Ops[] = {Base, Offset, Chain};
12891310

1290-
std::optional<unsigned> Opcode;
1291-
switch (N->getOpcode()) {
1292-
default:
1293-
return false;
1294-
case ISD::LOAD:
1295-
Opcode = pickOpcodeForVT(
1296-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1297-
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1298-
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1299-
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1300-
break;
1301-
case ISD::INTRINSIC_W_CHAIN:
1302-
Opcode = pickOpcodeForVT(
1303-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1304-
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1305-
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1306-
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1307-
break;
1308-
case NVPTXISD::LoadV2:
1309-
Opcode = pickOpcodeForVT(
1310-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1311-
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1312-
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1313-
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1314-
break;
1315-
case NVPTXISD::LDUV2:
1316-
Opcode = pickOpcodeForVT(
1317-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1318-
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1319-
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1320-
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1321-
break;
1322-
case NVPTXISD::LoadV4:
1323-
Opcode = pickOpcodeForVT(
1324-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1325-
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1326-
NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
1327-
NVPTX::INT_PTX_LDG_G_v4f64_ELE);
1328-
break;
1329-
case NVPTXISD::LDUV4:
1330-
Opcode = pickOpcodeForVT(
1331-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1332-
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1333-
{/* no v4i64 */}, NVPTX::INT_PTX_LDU_G_v4f32_ELE, {/* no v4f64 */});
1334-
break;
1335-
case NVPTXISD::LoadV8:
1336-
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
1337-
{/* no v8i16 */}, NVPTX::INT_PTX_LDG_G_v8i32_ELE,
1338-
{/* no v8i64 */}, NVPTX::INT_PTX_LDG_G_v8f32_ELE,
1339-
{/* no v8f64 */});
1340-
break;
1311+
if (!Opcode) {
1312+
switch (N->getOpcode()) {
1313+
default:
1314+
return false;
1315+
case ISD::LOAD:
1316+
Opcode = pickOpcodeForVT(
1317+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1318+
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1319+
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1320+
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1321+
break;
1322+
case ISD::INTRINSIC_W_CHAIN:
1323+
Opcode = pickOpcodeForVT(
1324+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1325+
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1326+
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1327+
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1328+
break;
1329+
case NVPTXISD::LoadV2:
1330+
Opcode = pickOpcodeForVT(
1331+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1332+
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1333+
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1334+
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1335+
break;
1336+
case NVPTXISD::LDUV2:
1337+
Opcode = pickOpcodeForVT(
1338+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1339+
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1340+
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1341+
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1342+
break;
1343+
case NVPTXISD::LoadV4:
1344+
Opcode = pickOpcodeForVT(
1345+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1346+
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1347+
NVPTX::INT_PTX_LDG_G_v4i64_ELE, NVPTX::INT_PTX_LDG_G_v4f32_ELE,
1348+
NVPTX::INT_PTX_LDG_G_v4f64_ELE);
1349+
break;
1350+
case NVPTXISD::LDUV4:
1351+
Opcode = pickOpcodeForVT(
1352+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1353+
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1354+
{/* no v4i64 */}, NVPTX::INT_PTX_LDU_G_v4f32_ELE, {/* no v4f64 */});
1355+
break;
1356+
case NVPTXISD::LoadV8:
1357+
Opcode = pickOpcodeForVT(EltVT.getSimpleVT().SimpleTy, {/* no v8i8 */},
1358+
{/* no v8i16 */}, NVPTX::INT_PTX_LDG_G_v8i32_ELE,
1359+
{/* no v8i64 */}, NVPTX::INT_PTX_LDG_G_v8f32_ELE,
1360+
{/* no v8f64 */});
1361+
break;
1362+
}
13411363
}
13421364
if (!Opcode)
13431365
return false;

llvm/test/CodeGen/NVPTX/ldu-ldg.ll

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ declare float @llvm.nvvm.ldu.global.f.f32.p1(ptr addrspace(1) %ptr, i32 %align)
1212
declare double @llvm.nvvm.ldu.global.f.f64.p1(ptr addrspace(1) %ptr, i32 %align)
1313
declare half @llvm.nvvm.ldu.global.f.f16.p1(ptr addrspace(1) %ptr, i32 %align)
1414
declare <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 %align)
15+
declare <2 x float> @llvm.nvvm.ldu.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 %align)
1516

1617
declare i8 @llvm.nvvm.ldg.global.i.i8.p1(ptr addrspace(1) %ptr, i32 %align)
1718
declare i16 @llvm.nvvm.ldg.global.i.i16.p1(ptr addrspace(1) %ptr, i32 %align)
@@ -22,6 +23,7 @@ declare float @llvm.nvvm.ldg.global.f.f32.p1(ptr addrspace(1) %ptr, i32 %align)
2223
declare double @llvm.nvvm.ldg.global.f.f64.p1(ptr addrspace(1) %ptr, i32 %align)
2324
declare half @llvm.nvvm.ldg.global.f.f16.p1(ptr addrspace(1) %ptr, i32 %align)
2425
declare <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 %align)
26+
declare <2 x float> @llvm.nvvm.ldg.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 %align)
2527

2628
define i8 @test_ldu_i8(ptr addrspace(1) %ptr) {
2729
; CHECK-LABEL: test_ldu_i8(
@@ -160,6 +162,20 @@ define <2 x half> @test_ldu_v2f16(ptr addrspace(1) %ptr) {
160162
ret <2 x half> %val
161163
}
162164

165+
define <2 x float> @test_ldu_v2f32(ptr addrspace(1) %ptr) {
166+
; CHECK-LABEL: test_ldu_v2f32(
167+
; CHECK: {
168+
; CHECK-NEXT: .reg .b64 %rd<3>;
169+
; CHECK-EMPTY:
170+
; CHECK-NEXT: // %bb.0:
171+
; CHECK-NEXT: ld.param.b64 %rd1, [test_ldu_v2f32_param_0];
172+
; CHECK-NEXT: ldu.global.b64 %rd2, [%rd1];
173+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
174+
; CHECK-NEXT: ret;
175+
%val = tail call <2 x float> @llvm.nvvm.ldu.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 8)
176+
ret <2 x float> %val
177+
}
178+
163179
define i8 @test_ldg_i8(ptr addrspace(1) %ptr) {
164180
; CHECK-LABEL: test_ldg_i8(
165181
; CHECK: {
@@ -296,6 +312,20 @@ define <2 x half> @test_ldg_v2f16(ptr addrspace(1) %ptr) {
296312
ret <2 x half> %val
297313
}
298314

315+
define <2 x float> @test_ldg_v2f32(ptr addrspace(1) %ptr) {
316+
; CHECK-LABEL: test_ldg_v2f32(
317+
; CHECK: {
318+
; CHECK-NEXT: .reg .b64 %rd<3>;
319+
; CHECK-EMPTY:
320+
; CHECK-NEXT: // %bb.0:
321+
; CHECK-NEXT: ld.param.b64 %rd1, [test_ldg_v2f32_param_0];
322+
; CHECK-NEXT: ld.global.nc.b64 %rd2, [%rd1];
323+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
324+
; CHECK-NEXT: ret;
325+
%val = tail call <2 x float> @llvm.nvvm.ldg.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 8)
326+
ret <2 x float> %val
327+
}
328+
299329
@g = addrspace(1) global i32 0
300330

301331
define i32 @test_ldg_asi() {

0 commit comments

Comments
 (0)