Skip to content

Commit 34ae98c

Browse files
committed
[NVPTX] handle more cases for loads and stores
Split unaligned stores and loads of v2f32. Add DAGCombiner rules for target-independent stores that store a v2f32 BUILD_VECTOR: we scalarize the value and rewrite the store. Fix test cases.
1 parent 183ed41 commit 34ae98c

File tree

5 files changed

+52
-26
lines changed

5 files changed

+52
-26
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 42 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -829,7 +829,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
829829
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
830830
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
831831
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
832-
ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
832+
ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD::BITCAST});
833833

834834
// setcc for f16x2 and bf16x2 needs special handling to prevent
835835
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3143,10 +3143,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
31433143
if (Op.getValueType() == MVT::i1)
31443144
return LowerLOADi1(Op, DAG);
31453145

3146-
// v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3147-
// unaligned loads and have to handle it here.
3146+
// v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3147+
// handle unaligned loads and have to handle it here.
31483148
EVT VT = Op.getValueType();
3149-
if (Isv2x16VT(VT) || VT == MVT::v4i8) {
3149+
if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) {
31503150
LoadSDNode *Load = cast<LoadSDNode>(Op);
31513151
EVT MemVT = Load->getMemoryVT();
31523152
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
@@ -3190,22 +3190,22 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31903190
if (VT == MVT::i1)
31913191
return LowerSTOREi1(Op, DAG);
31923192

3193-
// v2f16 is legal, so we can't rely on legalizer to handle unaligned
3194-
// stores and have to handle it here.
3195-
if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
3193+
// v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3194+
// handle unaligned stores and have to handle it here.
3195+
if ((Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) &&
31963196
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
31973197
VT, *Store->getMemOperand()))
31983198
return expandUnalignedStore(Store, DAG);
31993199

3200-
// v2f16, v2bf16 and v2i16 don't need special handling.
3201-
if (Isv2x16VT(VT) || VT == MVT::v4i8)
3200+
// v2f16/v2bf16/v2i16/v4i8/v2f32 don't need special handling.
3201+
if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32)
32023202
return SDValue();
32033203

32043204
return LowerSTOREVector(Op, DAG);
32053205
}
32063206

3207-
SDValue
3208-
NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
3207+
static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG,
3208+
const SmallVectorImpl<SDValue> &Elements) {
32093209
MemSDNode *N = cast<MemSDNode>(Op.getNode());
32103210
SDValue Val = N->getOperand(1);
32113211
SDLoc DL(N);
@@ -3266,6 +3266,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32663266
NumEltsPerSubVector);
32673267
Ops.push_back(DAG.getBuildVector(EltVT, DL, SubVectorElts));
32683268
}
3269+
} else if (!Elements.empty()) {
3270+
Ops.insert(Ops.end(), Elements.begin(), Elements.end());
32693271
} else {
32703272
SDValue V = DAG.getBitcast(MVT::getVectorVT(EltVT, NumElts), Val);
32713273
for (const unsigned I : llvm::seq(NumElts)) {
@@ -3289,10 +3291,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
32893291
DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
32903292
N->getMemoryVT(), N->getMemOperand());
32913293

3292-
// return DCI.CombineTo(N, NewSt, true);
32933294
return NewSt;
32943295
}
32953296

3297+
// Default variant where we don't pass in elements.
3298+
static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG) {
3299+
return convertVectorStore(Op, DAG, SmallVector<SDValue>{});
3300+
}
3301+
3302+
SDValue NVPTXTargetLowering::LowerSTOREVector(SDValue Op,
3303+
SelectionDAG &DAG) const {
3304+
return convertVectorStore(Op, DAG);
3305+
}
3306+
32963307
// st i1 v, addr
32973308
// =>
32983309
// v1 = zxt v to i16
@@ -5413,6 +5424,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
54135424
// -->
54145425
// StoreRetvalV2 {a, b}
54155426
// likewise for V2 -> V4 case
5427+
//
5428+
// We also handle target-independent stores, which require us to first
5429+
// convert to StoreV2.
54165430

54175431
std::optional<NVPTXISD::NodeType> NewOpcode;
54185432
switch (N->getOpcode()) {
@@ -5438,8 +5452,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
54385452
SDValue CurrentOp = N->getOperand(I);
54395453
if (CurrentOp->getOpcode() == ISD::BUILD_VECTOR) {
54405454
assert(CurrentOp.getValueType() == MVT::v2f32);
5441-
NewOps.push_back(CurrentOp.getNode()->getOperand(0));
5442-
NewOps.push_back(CurrentOp.getNode()->getOperand(1));
5455+
NewOps.push_back(CurrentOp.getOperand(0));
5456+
NewOps.push_back(CurrentOp.getOperand(1));
54435457
} else {
54445458
NewOps.clear();
54455459
break;
@@ -6216,6 +6230,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
62166230
return SDValue();
62176231
}
62186232

6233+
static SDValue PerformStoreCombine(SDNode *N,
6234+
TargetLowering::DAGCombinerInfo &DCI) {
6235+
// Check if the stored value can be scalarized.
6236+
SDValue StoredVal = N->getOperand(1);
6237+
if (StoredVal.getValueType() == MVT::v2f32 &&
6238+
StoredVal.getOpcode() == ISD::BUILD_VECTOR) {
6239+
SmallVector<SDValue> Elements(StoredVal->op_values());
6240+
return convertVectorStore(SDValue(N, 0), DCI.DAG, Elements);
6241+
}
6242+
return SDValue();
6243+
}
6244+
62196245
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
62206246
DAGCombinerInfo &DCI) const {
62216247
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -6245,6 +6271,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
62456271
case NVPTXISD::LoadParam:
62466272
case NVPTXISD::LoadParamV2:
62476273
return PerformLoadCombine(N, DCI);
6274+
case ISD::STORE:
6275+
return PerformStoreCombine(N, DCI);
62486276
case NVPTXISD::StoreParam:
62496277
case NVPTXISD::StoreParamV2:
62506278
case NVPTXISD::StoreParamV4:

llvm/test/CodeGen/NVPTX/aggregate-return.ll

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,7 @@ declare {float, float} @bars({float, float} %input)
1010
define void @test_v2f32(<2 x float> %input, ptr %output) {
1111
; CHECK-LABEL: test_v2f32(
1212
; CHECK: {
13-
; CHECK-NEXT: .reg .b32 %f<5>;
14-
; CHECK-NEXT: .reg .b64 %rd<3>;
13+
; CHECK-NEXT: .reg .b64 %rd<5>;
1514
; CHECK-EMPTY:
1615
; CHECK-NEXT: // %bb.0:
1716
; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_0];
@@ -24,10 +23,10 @@ define void @test_v2f32(<2 x float> %input, ptr %output) {
2423
; CHECK-NEXT: (
2524
; CHECK-NEXT: param0
2625
; CHECK-NEXT: );
27-
; CHECK-NEXT: ld.param.v2.b32 {%f1, %f2}, [retval0];
26+
; CHECK-NEXT: ld.param.b64 %rd2, [retval0];
2827
; CHECK-NEXT: } // callseq 0
29-
; CHECK-NEXT: ld.param.b64 %rd2, [test_v2f32_param_1];
30-
; CHECK-NEXT: st.v2.b32 [%rd2], {%f1, %f2};
28+
; CHECK-NEXT: ld.param.b64 %rd4, [test_v2f32_param_1];
29+
; CHECK-NEXT: st.b64 [%rd4], %rd2;
3130
; CHECK-NEXT: ret;
3231
%call = tail call <2 x float> @barv(<2 x float> %input)
3332
store <2 x float> %call, ptr %output, align 8

llvm/test/CodeGen/NVPTX/f32x2-instructions.ll

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -512,14 +512,13 @@ define <2 x float> @test_frem_ftz(<2 x float> %a, <2 x float> %b) #2 {
512512
define void @test_ldst_v2f32(ptr %a, ptr %b) #0 {
513513
; CHECK-LABEL: test_ldst_v2f32(
514514
; CHECK: {
515-
; CHECK-NEXT: .reg .b32 %f<3>;
516-
; CHECK-NEXT: .reg .b64 %rd<3>;
515+
; CHECK-NEXT: .reg .b64 %rd<4>;
517516
; CHECK-EMPTY:
518517
; CHECK-NEXT: // %bb.0:
519518
; CHECK-NEXT: ld.param.b64 %rd2, [test_ldst_v2f32_param_1];
520519
; CHECK-NEXT: ld.param.b64 %rd1, [test_ldst_v2f32_param_0];
521-
; CHECK-NEXT: ld.v2.b32 {%f1, %f2}, [%rd1];
522-
; CHECK-NEXT: st.v2.b32 [%rd2], {%f1, %f2};
520+
; CHECK-NEXT: ld.b64 %rd3, [%rd1];
521+
; CHECK-NEXT: st.b64 [%rd2], %rd3;
523522
; CHECK-NEXT: ret;
524523
%t1 = load <2 x float>, ptr %a
525524
store <2 x float> %t1, ptr %b, align 32

llvm/test/CodeGen/NVPTX/load-with-non-coherent-cache.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,9 +108,9 @@ define ptx_kernel void @foo10(ptr noalias readonly %from, ptr %to) {
108108
}
109109

110110
; SM20-LABEL: .visible .entry foo11(
111-
; SM20: ld.global.v2.b32
111+
; SM20: ld.global.b64
112112
; SM35-LABEL: .visible .entry foo11(
113-
; SM35: ld.global.nc.v2.b32
113+
; SM35: ld.global.nc.b64
114114
define ptx_kernel void @foo11(ptr noalias readonly %from, ptr %to) {
115115
%1 = load <2 x float>, ptr %from
116116
store <2 x float> %1, ptr %to

llvm/test/CodeGen/NVPTX/misaligned-vector-ldst.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ define <4 x float> @t2(ptr %p1) {
2626
; CHECK-LABEL: t3
2727
define <4 x float> @t3(ptr %p1) {
2828
; CHECK-NOT: ld.v4
29-
; CHECK: ld.v2
29+
; CHECK: ld.b64
3030
%r = load <4 x float>, ptr %p1, align 8
3131
ret <4 x float> %r
3232
}

0 commit comments

Comments
 (0)