@@ -1290,6 +1290,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1290
1290
EVT OrigType = N->getValueType (0 );
1291
1291
EVT EltVT = Mem->getMemoryVT ();
1292
1292
unsigned NumElts = 1 ;
1293
+
1294
+ std::optional<unsigned > Opcode;
1295
+
1293
1296
if (EltVT.isVector ()) {
1294
1297
NumElts = EltVT.getVectorNumElements ();
1295
1298
EltVT = EltVT.getVectorElementType ();
@@ -1302,6 +1305,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1302
1305
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
1303
1306
assert (NumElts % OrigType.getVectorNumElements () == 0 &&
1304
1307
" NumElts must be divisible by the number of elts in subvectors" );
1308
+ if (N->getOpcode () == ISD::LOAD ||
1309
+ N->getOpcode () == ISD::INTRINSIC_W_CHAIN) {
1310
+ switch (OrigType.getSimpleVT ().SimpleTy ) {
1311
+ case MVT::v2f32:
1312
+ Opcode = N->getOpcode () == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b64
1313
+ : NVPTX::INT_PTX_LDU_GLOBAL_b64;
1314
+ break ;
1315
+ case MVT::v2f16:
1316
+ case MVT::v2bf16:
1317
+ case MVT::v2i16:
1318
+ case MVT::v4i8:
1319
+ Opcode = N->getOpcode () == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b32
1320
+ : NVPTX::INT_PTX_LDU_GLOBAL_b32;
1321
+ break ;
1322
+ default :
1323
+ llvm_unreachable (" Unhandled packed vector type" );
1324
+ }
1325
+ }
1305
1326
EltVT = OrigType;
1306
1327
NumElts /= OrigType.getVectorNumElements ();
1307
1328
}
@@ -1323,50 +1344,51 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
1323
1344
SelectADDR (Op1, Base, Offset);
1324
1345
SDValue Ops[] = {Base, Offset, Chain};
1325
1346
1326
- std::optional<unsigned > Opcode;
1327
- switch (N->getOpcode ()) {
1328
- default :
1329
- return false ;
1330
- case ISD::LOAD:
1331
- Opcode = pickOpcodeForVT (
1332
- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_GLOBAL_i8,
1333
- NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1334
- NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1335
- NVPTX::INT_PTX_LDG_GLOBAL_f64);
1336
- break ;
1337
- case ISD::INTRINSIC_W_CHAIN:
1338
- Opcode = pickOpcodeForVT (
1339
- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_GLOBAL_i8,
1340
- NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1341
- NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1342
- NVPTX::INT_PTX_LDU_GLOBAL_f64);
1343
- break ;
1344
- case NVPTXISD::LoadV2:
1345
- Opcode = pickOpcodeForVT (
1346
- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1347
- NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1348
- NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1349
- NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1350
- break ;
1351
- case NVPTXISD::LDUV2:
1352
- Opcode = pickOpcodeForVT (
1353
- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1354
- NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1355
- NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1356
- NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1357
- break ;
1358
- case NVPTXISD::LoadV4:
1359
- Opcode = pickOpcodeForVT (
1360
- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1361
- NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1362
- std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1363
- break ;
1364
- case NVPTXISD::LDUV4:
1365
- Opcode = pickOpcodeForVT (
1366
- EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1367
- NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1368
- std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1369
- break ;
1347
+ if (!Opcode) {
1348
+ switch (N->getOpcode ()) {
1349
+ default :
1350
+ return false ;
1351
+ case ISD::LOAD:
1352
+ Opcode = pickOpcodeForVT (
1353
+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_GLOBAL_i8,
1354
+ NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1355
+ NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1356
+ NVPTX::INT_PTX_LDG_GLOBAL_f64);
1357
+ break ;
1358
+ case ISD::INTRINSIC_W_CHAIN:
1359
+ Opcode = pickOpcodeForVT (
1360
+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_GLOBAL_i8,
1361
+ NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1362
+ NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1363
+ NVPTX::INT_PTX_LDU_GLOBAL_f64);
1364
+ break ;
1365
+ case NVPTXISD::LoadV2:
1366
+ Opcode = pickOpcodeForVT (
1367
+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1368
+ NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1369
+ NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1370
+ NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1371
+ break ;
1372
+ case NVPTXISD::LDUV2:
1373
+ Opcode = pickOpcodeForVT (
1374
+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1375
+ NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1376
+ NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1377
+ NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1378
+ break ;
1379
+ case NVPTXISD::LoadV4:
1380
+ Opcode = pickOpcodeForVT (
1381
+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1382
+ NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1383
+ std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1384
+ break ;
1385
+ case NVPTXISD::LDUV4:
1386
+ Opcode = pickOpcodeForVT (
1387
+ EltVT.getSimpleVT ().SimpleTy , NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1388
+ NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1389
+ std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1390
+ break ;
1391
+ }
1370
1392
}
1371
1393
if (!Opcode)
1372
1394
return false ;
0 commit comments