Skip to content

Commit 0eec8d6

Browse files
committed
[NVPTX] support generic LDG/LDU for packed data types
Support ld.global.nc.b64/ldu.global.b64 for v2f32 and ld.global.nc.b32/ldu.global.b32 for v2f16/v2bf16/v2i16/v4i8 Update test cases.
1 parent fb0dc77 commit 0eec8d6

File tree

4 files changed

+278
-226
lines changed

4 files changed

+278
-226
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,6 +1245,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12451245
EltVT = MVT::i64;
12461246
NumElts = 2;
12471247
}
1248+
1249+
std::optional<unsigned> Opcode;
1250+
12481251
if (EltVT.isVector()) {
12491252
NumElts = EltVT.getVectorNumElements();
12501253
EltVT = EltVT.getVectorElementType();
@@ -1257,6 +1260,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12571260
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
12581261
assert(NumElts % OrigType.getVectorNumElements() == 0 &&
12591262
"NumElts must be divisible by the number of elts in subvectors");
1263+
if (N->getOpcode() == ISD::LOAD ||
1264+
N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
1265+
switch (OrigType.getSimpleVT().SimpleTy) {
1266+
case MVT::v2f32:
1267+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_i64
1268+
: NVPTX::INT_PTX_LDU_GLOBAL_i64;
1269+
break;
1270+
case MVT::v2f16:
1271+
case MVT::v2bf16:
1272+
case MVT::v2i16:
1273+
case MVT::v4i8:
1274+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_i32
1275+
: NVPTX::INT_PTX_LDU_GLOBAL_i32;
1276+
break;
1277+
default:
1278+
llvm_unreachable("Unhandled packed vector type");
1279+
}
1280+
}
12601281
EltVT = OrigType;
12611282
NumElts /= OrigType.getVectorNumElements();
12621283
}
@@ -1276,50 +1297,51 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12761297
SelectADDR(Op1, Base, Offset);
12771298
SDValue Ops[] = {Base, Offset, Chain};
12781299

1279-
std::optional<unsigned> Opcode;
1280-
switch (N->getOpcode()) {
1281-
default:
1282-
return false;
1283-
case ISD::LOAD:
1284-
Opcode = pickOpcodeForVT(
1285-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1286-
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1287-
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1288-
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1289-
break;
1290-
case ISD::INTRINSIC_W_CHAIN:
1291-
Opcode = pickOpcodeForVT(
1292-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1293-
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1294-
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1295-
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1296-
break;
1297-
case NVPTXISD::LoadV2:
1298-
Opcode = pickOpcodeForVT(
1299-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1300-
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1301-
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1302-
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1303-
break;
1304-
case NVPTXISD::LDUV2:
1305-
Opcode = pickOpcodeForVT(
1306-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1307-
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1308-
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1309-
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1310-
break;
1311-
case NVPTXISD::LoadV4:
1312-
Opcode = pickOpcodeForVT(
1313-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1314-
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1315-
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1316-
break;
1317-
case NVPTXISD::LDUV4:
1318-
Opcode = pickOpcodeForVT(
1319-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1320-
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1321-
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1322-
break;
1300+
if (!Opcode) {
1301+
switch (N->getOpcode()) {
1302+
default:
1303+
return false;
1304+
case ISD::LOAD:
1305+
Opcode = pickOpcodeForVT(
1306+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1307+
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1308+
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1309+
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1310+
break;
1311+
case ISD::INTRINSIC_W_CHAIN:
1312+
Opcode = pickOpcodeForVT(
1313+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1314+
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1315+
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1316+
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1317+
break;
1318+
case NVPTXISD::LoadV2:
1319+
Opcode = pickOpcodeForVT(
1320+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1321+
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1322+
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1323+
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1324+
break;
1325+
case NVPTXISD::LDUV2:
1326+
Opcode = pickOpcodeForVT(
1327+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1328+
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1329+
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1330+
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1331+
break;
1332+
case NVPTXISD::LoadV4:
1333+
Opcode = pickOpcodeForVT(
1334+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1335+
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1336+
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1337+
break;
1338+
case NVPTXISD::LDUV4:
1339+
Opcode = pickOpcodeForVT(
1340+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1341+
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1342+
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1343+
break;
1344+
}
13231345
}
13241346
if (!Opcode)
13251347
return false;

0 commit comments

Comments
 (0)