Skip to content

Commit cea37b9

Browse files
committed
[NVPTX] support generic LDG/LDU for packed data types
Support ld.global.nc.b64/ldu.global.b64 for v2f32 and ld.global.nc.b32/ldu.global.b32 for v2f16/v2bf16/v2i16/v4i8. Update test cases.
1 parent 5af52a1 commit cea37b9

File tree

6 files changed

+106
-50
lines changed

6 files changed

+106
-50
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 66 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1278,6 +1278,9 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12781278
EltVT = MVT::i64;
12791279
NumElts = 2;
12801280
}
1281+
1282+
std::optional<unsigned> Opcode;
1283+
12811284
if (EltVT.isVector()) {
12821285
NumElts = EltVT.getVectorNumElements();
12831286
EltVT = EltVT.getVectorElementType();
@@ -1290,6 +1293,24 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
12901293
(EltVT == MVT::i8 && OrigType == MVT::v4i8)) {
12911294
assert(NumElts % OrigType.getVectorNumElements() == 0 &&
12921295
"NumElts must be divisible by the number of elts in subvectors");
1296+
if (N->getOpcode() == ISD::LOAD ||
1297+
N->getOpcode() == ISD::INTRINSIC_W_CHAIN) {
1298+
switch (OrigType.getSimpleVT().SimpleTy) {
1299+
case MVT::v2f32:
1300+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b64
1301+
: NVPTX::INT_PTX_LDU_GLOBAL_b64;
1302+
break;
1303+
case MVT::v2f16:
1304+
case MVT::v2bf16:
1305+
case MVT::v2i16:
1306+
case MVT::v4i8:
1307+
Opcode = N->getOpcode() == ISD::LOAD ? NVPTX::INT_PTX_LDG_GLOBAL_b32
1308+
: NVPTX::INT_PTX_LDU_GLOBAL_b32;
1309+
break;
1310+
default:
1311+
llvm_unreachable("Unhandled packed vector type");
1312+
}
1313+
}
12931314
EltVT = OrigType;
12941315
NumElts /= OrigType.getVectorNumElements();
12951316
}
@@ -1309,50 +1330,51 @@ bool NVPTXDAGToDAGISel::tryLDGLDU(SDNode *N) {
13091330
SelectADDR(Op1, Base, Offset);
13101331
SDValue Ops[] = {Base, Offset, Chain};
13111332

1312-
std::optional<unsigned> Opcode;
1313-
switch (N->getOpcode()) {
1314-
default:
1315-
return false;
1316-
case ISD::LOAD:
1317-
Opcode = pickOpcodeForVT(
1318-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1319-
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1320-
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1321-
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1322-
break;
1323-
case ISD::INTRINSIC_W_CHAIN:
1324-
Opcode = pickOpcodeForVT(
1325-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1326-
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1327-
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1328-
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1329-
break;
1330-
case NVPTXISD::LoadV2:
1331-
Opcode = pickOpcodeForVT(
1332-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1333-
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1334-
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1335-
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1336-
break;
1337-
case NVPTXISD::LDUV2:
1338-
Opcode = pickOpcodeForVT(
1339-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1340-
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1341-
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1342-
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1343-
break;
1344-
case NVPTXISD::LoadV4:
1345-
Opcode = pickOpcodeForVT(
1346-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1347-
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1348-
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1349-
break;
1350-
case NVPTXISD::LDUV4:
1351-
Opcode = pickOpcodeForVT(
1352-
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1353-
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1354-
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1355-
break;
1333+
if (!Opcode) {
1334+
switch (N->getOpcode()) {
1335+
default:
1336+
return false;
1337+
case ISD::LOAD:
1338+
Opcode = pickOpcodeForVT(
1339+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_GLOBAL_i8,
1340+
NVPTX::INT_PTX_LDG_GLOBAL_i16, NVPTX::INT_PTX_LDG_GLOBAL_i32,
1341+
NVPTX::INT_PTX_LDG_GLOBAL_i64, NVPTX::INT_PTX_LDG_GLOBAL_f32,
1342+
NVPTX::INT_PTX_LDG_GLOBAL_f64);
1343+
break;
1344+
case ISD::INTRINSIC_W_CHAIN:
1345+
Opcode = pickOpcodeForVT(
1346+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_GLOBAL_i8,
1347+
NVPTX::INT_PTX_LDU_GLOBAL_i16, NVPTX::INT_PTX_LDU_GLOBAL_i32,
1348+
NVPTX::INT_PTX_LDU_GLOBAL_i64, NVPTX::INT_PTX_LDU_GLOBAL_f32,
1349+
NVPTX::INT_PTX_LDU_GLOBAL_f64);
1350+
break;
1351+
case NVPTXISD::LoadV2:
1352+
Opcode = pickOpcodeForVT(
1353+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v2i8_ELE,
1354+
NVPTX::INT_PTX_LDG_G_v2i16_ELE, NVPTX::INT_PTX_LDG_G_v2i32_ELE,
1355+
NVPTX::INT_PTX_LDG_G_v2i64_ELE, NVPTX::INT_PTX_LDG_G_v2f32_ELE,
1356+
NVPTX::INT_PTX_LDG_G_v2f64_ELE);
1357+
break;
1358+
case NVPTXISD::LDUV2:
1359+
Opcode = pickOpcodeForVT(
1360+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v2i8_ELE,
1361+
NVPTX::INT_PTX_LDU_G_v2i16_ELE, NVPTX::INT_PTX_LDU_G_v2i32_ELE,
1362+
NVPTX::INT_PTX_LDU_G_v2i64_ELE, NVPTX::INT_PTX_LDU_G_v2f32_ELE,
1363+
NVPTX::INT_PTX_LDU_G_v2f64_ELE);
1364+
break;
1365+
case NVPTXISD::LoadV4:
1366+
Opcode = pickOpcodeForVT(
1367+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDG_G_v4i8_ELE,
1368+
NVPTX::INT_PTX_LDG_G_v4i16_ELE, NVPTX::INT_PTX_LDG_G_v4i32_ELE,
1369+
std::nullopt, NVPTX::INT_PTX_LDG_G_v4f32_ELE, std::nullopt);
1370+
break;
1371+
case NVPTXISD::LDUV4:
1372+
Opcode = pickOpcodeForVT(
1373+
EltVT.getSimpleVT().SimpleTy, NVPTX::INT_PTX_LDU_G_v4i8_ELE,
1374+
NVPTX::INT_PTX_LDU_G_v4i16_ELE, NVPTX::INT_PTX_LDU_G_v4i32_ELE,
1375+
std::nullopt, NVPTX::INT_PTX_LDU_G_v4f32_ELE, std::nullopt);
1376+
break;
1377+
}
13561378
}
13571379
if (!Opcode)
13581380
return false;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2305,7 +2305,9 @@ class LDU_G<string TyStr, NVPTXRegClass regclass>
23052305
def INT_PTX_LDU_GLOBAL_i8 : LDU_G<"u8", Int16Regs>;
23062306
def INT_PTX_LDU_GLOBAL_i16 : LDU_G<"u16", Int16Regs>;
23072307
def INT_PTX_LDU_GLOBAL_i32 : LDU_G<"u32", Int32Regs>;
2308+
def INT_PTX_LDU_GLOBAL_b32 : LDU_G<"b32", Int32Regs>;
23082309
def INT_PTX_LDU_GLOBAL_i64 : LDU_G<"u64", Int64Regs>;
2310+
def INT_PTX_LDU_GLOBAL_b64 : LDU_G<"b64", Int64Regs>;
23092311
def INT_PTX_LDU_GLOBAL_f32 : LDU_G<"f32", Float32Regs>;
23102312
def INT_PTX_LDU_GLOBAL_f64 : LDU_G<"f64", Float64Regs>;
23112313

@@ -2355,7 +2357,9 @@ class LDG_G<string TyStr, NVPTXRegClass regclass>
23552357
def INT_PTX_LDG_GLOBAL_i8 : LDG_G<"u8", Int16Regs>;
23562358
def INT_PTX_LDG_GLOBAL_i16 : LDG_G<"u16", Int16Regs>;
23572359
def INT_PTX_LDG_GLOBAL_i32 : LDG_G<"u32", Int32Regs>;
2360+
def INT_PTX_LDG_GLOBAL_b32 : LDG_G<"b32", Int32Regs>;
23582361
def INT_PTX_LDG_GLOBAL_i64 : LDG_G<"u64", Int64Regs>;
2362+
def INT_PTX_LDG_GLOBAL_b64 : LDG_G<"b64", Int64Regs>;
23592363
def INT_PTX_LDG_GLOBAL_f32 : LDG_G<"f32", Float32Regs>;
23602364
def INT_PTX_LDG_GLOBAL_f64 : LDG_G<"f64", Float64Regs>;
23612365

llvm/test/CodeGen/NVPTX/ldg-invariant.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ define half @ld_global_v2f16(ptr addrspace(1) %ptr) {
3232
; CHECK-EMPTY:
3333
; CHECK-NEXT: // %bb.0:
3434
; CHECK-NEXT: ld.param.u64 %rd1, [ld_global_v2f16_param_0];
35-
; CHECK-NEXT: ld.global.nc.u32 %r1, [%rd1];
35+
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
3636
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r1;
3737
; CHECK-NEXT: cvt.f32.f16 %f1, %rs2;
3838
; CHECK-NEXT: cvt.f32.f16 %f2, %rs1;

llvm/test/CodeGen/NVPTX/ldu-ldg.ll

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ declare float @llvm.nvvm.ldu.global.f.f32.p1(ptr addrspace(1) %ptr, i32 %align)
1212
declare double @llvm.nvvm.ldu.global.f.f64.p1(ptr addrspace(1) %ptr, i32 %align)
1313
declare half @llvm.nvvm.ldu.global.f.f16.p1(ptr addrspace(1) %ptr, i32 %align)
1414
declare <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 %align)
15+
declare <2 x float> @llvm.nvvm.ldu.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 %align)
1516

1617
declare i8 @llvm.nvvm.ldg.global.i.i8.p1(ptr addrspace(1) %ptr, i32 %align)
1718
declare i16 @llvm.nvvm.ldg.global.i.i16.p1(ptr addrspace(1) %ptr, i32 %align)
@@ -22,6 +23,7 @@ declare float @llvm.nvvm.ldg.global.f.f32.p1(ptr addrspace(1) %ptr, i32 %align)
2223
declare double @llvm.nvvm.ldg.global.f.f64.p1(ptr addrspace(1) %ptr, i32 %align)
2324
declare half @llvm.nvvm.ldg.global.f.f16.p1(ptr addrspace(1) %ptr, i32 %align)
2425
declare <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 %align)
26+
declare <2 x float> @llvm.nvvm.ldg.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 %align)
2527

2628
define i8 @test_ldu_i8(ptr addrspace(1) %ptr) {
2729
; CHECK-LABEL: test_ldu_i8(
@@ -154,13 +156,27 @@ define <2 x half> @test_ldu_v2f16(ptr addrspace(1) %ptr) {
154156
; CHECK-EMPTY:
155157
; CHECK-NEXT: // %bb.0:
156158
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldu_v2f16_param_0];
157-
; CHECK-NEXT: ldu.global.u32 %r1, [%rd1];
159+
; CHECK-NEXT: ldu.global.b32 %r1, [%rd1];
158160
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
159161
; CHECK-NEXT: ret;
160162
%val = tail call <2 x half> @llvm.nvvm.ldu.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 4)
161163
ret <2 x half> %val
162164
}
163165

166+
define <2 x float> @test_ldu_v2f32(ptr addrspace(1) %ptr) {
167+
; CHECK-LABEL: test_ldu_v2f32(
168+
; CHECK: {
169+
; CHECK-NEXT: .reg .b64 %rd<3>;
170+
; CHECK-EMPTY:
171+
; CHECK-NEXT: // %bb.0:
172+
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldu_v2f32_param_0];
173+
; CHECK-NEXT: ldu.global.b64 %rd2, [%rd1];
174+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
175+
; CHECK-NEXT: ret;
176+
%val = tail call <2 x float> @llvm.nvvm.ldu.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 8)
177+
ret <2 x float> %val
178+
}
179+
164180
define i8 @test_ldg_i8(ptr addrspace(1) %ptr) {
165181
; CHECK-LABEL: test_ldg_i8(
166182
; CHECK: {
@@ -291,13 +307,27 @@ define <2 x half> @test_ldg_v2f16(ptr addrspace(1) %ptr) {
291307
; CHECK-EMPTY:
292308
; CHECK-NEXT: // %bb.0:
293309
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldg_v2f16_param_0];
294-
; CHECK-NEXT: ld.global.nc.u32 %r1, [%rd1];
310+
; CHECK-NEXT: ld.global.nc.b32 %r1, [%rd1];
295311
; CHECK-NEXT: st.param.b32 [func_retval0], %r1;
296312
; CHECK-NEXT: ret;
297313
%val = tail call <2 x half> @llvm.nvvm.ldg.global.f.v2f16.p1(ptr addrspace(1) %ptr, i32 4)
298314
ret <2 x half> %val
299315
}
300316

317+
define <2 x float> @test_ldg_v2f32(ptr addrspace(1) %ptr) {
318+
; CHECK-LABEL: test_ldg_v2f32(
319+
; CHECK: {
320+
; CHECK-NEXT: .reg .b64 %rd<3>;
321+
; CHECK-EMPTY:
322+
; CHECK-NEXT: // %bb.0:
323+
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldg_v2f32_param_0];
324+
; CHECK-NEXT: ld.global.nc.b64 %rd2, [%rd1];
325+
; CHECK-NEXT: st.param.b64 [func_retval0], %rd2;
326+
; CHECK-NEXT: ret;
327+
%val = tail call <2 x float> @llvm.nvvm.ldg.global.f.v2f32.p1(ptr addrspace(1) %ptr, i32 8)
328+
ret <2 x float> %val
329+
}
330+
301331
@g = addrspace(1) global i32 0
302332

303333
define i32 @test_ldg_asi() {

llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ define ptx_kernel void @foo7(ptr noalias readonly %from, ptr %to) {
8080
; SM20-LABEL: .visible .entry foo8(
8181
; SM20: ld.global.u32
8282
; SM35-LABEL: .visible .entry foo8(
83-
; SM35: ld.global.nc.u32
83+
; SM35: ld.global.nc.b32
8484
define ptx_kernel void @foo8(ptr noalias readonly %from, ptr %to) {
8585
%1 = load <2 x i16>, ptr %from
8686
store <2 x i16> %1, ptr %to
@@ -130,7 +130,7 @@ define ptx_kernel void @foo12(ptr noalias readonly %from, ptr %to) {
130130
; SM20-LABEL: .visible .entry foo13(
131131
; SM20: ld.global.u32
132132
; SM35-LABEL: .visible .entry foo13(
133-
; SM35: ld.global.nc.u32
133+
; SM35: ld.global.nc.b32
134134
define ptx_kernel void @foo13(ptr noalias readonly %from, ptr %to) {
135135
%1 = load <4 x i8>, ptr %from
136136
store <4 x i8> %1, ptr %to

llvm/test/CodeGen/NVPTX/read-global-variable-constant.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ define float @test_gv_float() {
1717

1818
; CHECK-LABEL: test_gv_float2()
1919
define <2 x float> @test_gv_float2() {
20-
; CHECK: ld.global.nc.v2.f32
20+
; CHECK: ld.global.nc.b64
2121
%v = load <2 x float>, ptr @gv_float2
2222
ret <2 x float> %v
2323
}

0 commit comments

Comments
 (0)