Skip to content

Commit 8cd3d6c

Browse files
committed
[NVPTX] handle more cases for loads and stores
Split unaligned stores and loads of v2f32. Add a DAGCombiner rule for target-independent stores that store a v2f32 BUILD_VECTOR: we scalarize the value and rewrite the store. Fix test cases.
1 parent 12e1c60 commit 8cd3d6c

File tree

5 files changed

+52
-25
lines changed

5 files changed

+52
-25
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -833,7 +833,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
833833
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
834834
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
835835
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
836-
ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
836+
ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD::BITCAST});
837837

838838
// setcc for f16x2 and bf16x2 needs special handling to prevent
839839
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3092,10 +3092,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
30923092
if (Op.getValueType() == MVT::i1)
30933093
return LowerLOADi1(Op, DAG);
30943094

3095-
// v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3096-
// unaligned loads and have to handle it here.
3095+
// v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3096+
// handle unaligned loads and have to handle it here.
30973097
EVT VT = Op.getValueType();
3098-
if (Isv2x16VT(VT) || VT == MVT::v4i8) {
3098+
if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) {
30993099
LoadSDNode *Load = cast<LoadSDNode>(Op);
31003100
EVT MemVT = Load->getMemoryVT();
31013101
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -3139,15 +3139,15 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31393139
if (VT == MVT::i1)
31403140
return LowerSTOREi1(Op, DAG);
31413141

3142-
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
3143-
// stores and have to handle it here.
3144-
if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
3142+
// v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3143+
// handle unaligned stores and have to handle it here.
3144+
if ((Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) &&
31453145
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
31463146
VT, *Store->getMemOperand()))
31473147
return expandUnalignedStore(Store, DAG);
31483148

3149-
// v2f16, v2bf16 and v2i16 don't need special handling.
3150-
if (Isv2x16VT(VT) || VT == MVT::v4i8)
3149+
// v2f16/v2bf16/v2i16/v4i8/v2f32 don't need special handling.
3150+
if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32)
31513151
return SDValue();
31523152

31533153
if (VT.isVector())
@@ -3156,8 +3156,8 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31563156
return SDValue();
31573157
}
31583158

3159-
SDValue
3160-
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
3159+
static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG,
3160+
const SmallVectorImpl<SDValue> &Elements) {
31613161
SDNode *N = Op.getNode();
31623162
SDValue Val = N->getOperand(1);
31633163
SDLoc DL(N);
@@ -3224,6 +3224,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32243224
SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
32253225
Ops.push_back(SubVector);
32263226
}
3227+
} else if (!Elements.empty()) {
3228+
Ops.insert(Ops.end(), Elements.begin(), Elements.end());
32273229
} else {
32283230
for (unsigned i = 0; i < NumElts; ++i) {
32293231
SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
@@ -3241,10 +3243,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32413243
DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
32423244
MemSD->getMemoryVT(), MemSD->getMemOperand());
32433245

3244-
// return DCI.CombineTo(N, NewSt, true);
32453246
return NewSt;
32463247
}
32473248

3249+
// Default variant where we don't pass in elements.
3250+
static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG) {
3251+
return convertVectorStore(Op, DAG, SmallVector<SDValue>{});
3252+
}
3253+
3254+
SDValue NVPTXTargetLowering::LowerSTOREVector(SDValue Op,
3255+
SelectionDAG &DAG) const {
3256+
return convertVectorStore(Op, DAG);
3257+
}
3258+
32483259
// st i1 v, addr
32493260
// =>
32503261
// v1 = zxt v to i16
@@ -5400,6 +5411,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
54005411
// -->
54015412
// StoreRetvalV2 {a, b}
54025413
// likewise for V2 -> V4 case
5414+
//
5415+
// We also handle target-independent stores, which require us to first
5416+
// convert to StoreV2.
54035417

54045418
std::optional<NVPTXISD::NodeType> NewOpcode;
54055419
switch (N->getOpcode()) {
@@ -5425,8 +5439,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
54255439
SDValue CurrentOp = N->getOperand(I);
54265440
if (CurrentOp->getOpcode() == ISD::BUILD_VECTOR) {
54275441
assert(CurrentOp.getValueType() == MVT::v2f32);
5428-
NewOps.push_back(CurrentOp.getNode()->getOperand(0));
5429-
NewOps.push_back(CurrentOp.getNode()->getOperand(1));
5442+
NewOps.push_back(CurrentOp.getOperand(0));
5443+
NewOps.push_back(CurrentOp.getOperand(1));
54305444
} else {
54315445
NewOps.clear();
54325446
break;
@@ -6197,6 +6211,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
61976211
return SDValue();
61986212
}
61996213

6214+
static SDValue PerformStoreCombine(SDNode *N,
6215+
TargetLowering::DAGCombinerInfo &DCI) {
6216+
// Check if the stored value can be scalarized.
6217+
SDValue StoredVal = N->getOperand(1);
6218+
if (StoredVal.getValueType() == MVT::v2f32 &&
6219+
StoredVal.getOpcode() == ISD::BUILD_VECTOR) {
6220+
SmallVector<SDValue> Elements(StoredVal->op_values());
6221+
return convertVectorStore(SDValue(N, 0), DCI.DAG, Elements);
6222+
}
6223+
return SDValue();
6224+
}
6225+
62006226
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
62016227
DAGCombinerInfo &DCI) const {
62026228
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -6226,6 +6252,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
62266252
case NVPTXISD::LoadParam:
62276253
case NVPTXISD::LoadParamV2:
62286254
return PerformLoadCombine(N, DCI);
6255+
case ISD::STORE:
6256+
return PerformStoreCombine(N, DCI);
62296257
case NVPTXISD::StoreParam:
62306258
case NVPTXISD::StoreParamV2:
62316259
case NVPTXISD::StoreParamV4:

llvm/test/CodeGen/NVPTX/aggregate-return.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
1010
; CHECK-LABEL: @test_v2f32
1111
%call = tail call <2 x float> @barv(<2 x float> %input)
1212
; CHECK: .param .align 8 .b8 retval0[8];
13-
; CHECK: ld.param.v2.f32 {[[E0:%f[0-9]+]], [[E1:%f[0-9]+]]}, [retval0];
13+
; CHECK: ld.param.b64 [[E0_1:%rd[0-9]+]], [retval0];
1414
store <2 x float> %call, ptr %output, align 8
15-
; CHECK: st.v2.f32 [{{%rd[0-9]+}}], {[[E0]], [[E1]]}
15+
; CHECK: st.b64 [{{%rd[0-9]+}}], [[E0_1]]
1616
ret void
1717
}
1818

llvm/test/CodeGen/NVPTX/f32x2-instructions.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,13 @@ define <2 x float> @test_frem_ftz(<2 x float> %a, <2 x float> %b) #2 {
512512
define void @test_ldst_v2f32(ptr %a, ptr %b) #0 {
513513
; CHECK-LABEL: test_ldst_v2f32(
514514
; CHECK: {
515-
; CHECK-NEXT: .reg .f32 %f<3>;
516-
; CHECK-NEXT: .reg .b64 %rd<3>;
515+
; CHECK-NEXT: .reg .b64 %rd<4>;
517516
; CHECK-EMPTY:
518517
; CHECK-NEXT: // %bb.0:
519518
; CHECK-NEXT: ld.param.u64 %rd2, [test_ldst_v2f32_param_1];
520519
; CHECK-NEXT: ld.param.u64 %rd1, [test_ldst_v2f32_param_0];
521-
; CHECK-NEXT: ld.v2.f32 {%f1, %f2}, [%rd1];
522-
; CHECK-NEXT: st.v2.f32 [%rd2], {%f1, %f2};
520+
; CHECK-NEXT: ld.b64 %rd3, [%rd1];
521+
; CHECK-NEXT: st.b64 [%rd2], %rd3;
523522
; CHECK-NEXT: ret;
524523
%t1 = load <2 x float>, ptr %a
525524
store <2 x float> %t1, ptr %b, align 32

llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,9 @@ define ptx_kernel void @foo10(ptr noalias readonly %from, ptr %to) {
110110
}
111111

112112
; SM20-LABEL: .visible .entry foo11(
113-
; SM20: ld.global.v2.f32
113+
; SM20: ld.global.b64
114114
; SM35-LABEL: .visible .entry foo11(
115-
; SM35: ld.global.nc.v2.f32
115+
; SM35: ld.global.nc.b64
116116
define ptx_kernel void @foo11(ptr noalias readonly %from, ptr %to) {
117117
%1 = load <2 x float>, ptr %from
118118
store <2 x float> %1, ptr %to

llvm/test/CodeGen/NVPTX/misaligned-vector-ldst.ll

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,15 @@ define <4 x float> @t1(ptr %p1) {
1818
define <4 x float> @t2(ptr %p1) {
1919
; CHECK-NOT: ld.v4
2020
; CHECK-NOT: ld.v2
21-
; CHECK: ld.f32
21+
; CHECK: ld.u32
2222
%r = load <4 x float>, ptr %p1, align 4
2323
ret <4 x float> %r
2424
}
2525

2626
; CHECK-LABEL: t3
2727
define <4 x float> @t3(ptr %p1) {
2828
; CHECK-NOT: ld.v4
29-
; CHECK: ld.v2
29+
; CHECK: ld.b64
3030
%r = load <4 x float>, ptr %p1, align 8
3131
ret <4 x float> %r
3232
}
@@ -111,7 +111,7 @@ define void @s1(ptr %p1, <4 x float> %v) {
111111
define void @s2(ptr %p1, <4 x float> %v) {
112112
; CHECK-NOT: st.v4
113113
; CHECK-NOT: st.v2
114-
; CHECK: st.f32
114+
; CHECK: st.u32
115115
store <4 x float> %v, ptr %p1, align 4
116116
ret void
117117
}

0 commit comments

Comments
 (0)