Skip to content

Commit fb0dc77

Browse files
committed
[NVPTX] update combiner rule for more types of loads
Handle more loads, including ones with multiple proxy registers:
- i64 = LOAD
- i64 = LoadParam
- v2f32,v2f32 = LoadParamV2

Also update the test cases. Because this combiner rule is an optimization, it is not triggered for the tests that compile with optimizations disabled.
1 parent d409c9d commit fb0dc77

File tree

7 files changed

+304
-201
lines changed

7 files changed

+304
-201
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 96 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -5089,11 +5089,13 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
50895089
return SDValue();
50905090
}
50915091

5092+
/// OverrideVT - allows overriding the result and memory types
50925093
static std::optional<std::pair<SDValue, SDValue>>
5093-
convertVectorLoad(SDNode *N, SelectionDAG &DAG, bool BuildVector) {
5094+
convertVectorLoad(SDNode *N, SelectionDAG &DAG, bool BuildVector,
5095+
std::optional<EVT> OverrideVT = std::nullopt) {
50945096
LoadSDNode *LD = cast<LoadSDNode>(N);
5095-
const EVT ResVT = LD->getValueType(0);
5096-
const EVT MemVT = LD->getMemoryVT();
5097+
const EVT ResVT = OverrideVT.value_or(LD->getValueType(0));
5098+
const EVT MemVT = OverrideVT.value_or(LD->getMemoryVT());
50975099

50985100
// If we're doing sign/zero extension as part of the load, avoid lowering to
50995101
// a LoadV node. TODO: consider relaxing this restriction.
@@ -5147,33 +5149,31 @@ convertVectorLoad(SDNode *N, SelectionDAG &DAG, bool BuildVector) {
51475149
// pass along the extension information
51485150
OtherOps.push_back(DAG.getIntPtrConstant(LD->getExtensionType(), DL));
51495151

5150-
SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps,
5151-
LD->getMemoryVT(),
5152+
SDValue NewLD = DAG.getMemIntrinsicNode(Opcode, DL, LdResVTs, OtherOps, MemVT,
51525153
LD->getMemOperand());
5153-
5154-
SmallVector<SDValue> ScalarRes;
5155-
if (EltVT.isVector()) {
5156-
assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
5157-
assert(NumElts * EltVT.getVectorNumElements() ==
5158-
ResVT.getVectorNumElements());
5159-
// Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5160-
// into individual elements.
5161-
for (const unsigned I : llvm::seq(NumElts)) {
5162-
SDValue SubVector = NewLD.getValue(I);
5163-
DAG.ExtractVectorElements(SubVector, ScalarRes);
5164-
}
5165-
} else {
5166-
for (const unsigned I : llvm::seq(NumElts)) {
5167-
SDValue Res = NewLD.getValue(I);
5168-
if (LoadEltVT != EltVT)
5169-
Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
5170-
ScalarRes.push_back(Res);
5171-
}
5172-
}
5173-
51745154
SDValue LoadChain = NewLD.getValue(NumElts);
51755155

51765156
if (BuildVector) {
5157+
SmallVector<SDValue> ScalarRes;
5158+
if (EltVT.isVector()) {
5159+
assert(EVT(EltVT.getVectorElementType()) == ResVT.getVectorElementType());
5160+
assert(NumElts * EltVT.getVectorNumElements() ==
5161+
ResVT.getVectorNumElements());
5162+
// Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5163+
// into individual elements.
5164+
for (const unsigned I : llvm::seq(NumElts)) {
5165+
SDValue SubVector = NewLD.getValue(I);
5166+
DAG.ExtractVectorElements(SubVector, ScalarRes);
5167+
}
5168+
} else {
5169+
for (const unsigned I : llvm::seq(NumElts)) {
5170+
SDValue Res = NewLD.getValue(I);
5171+
if (LoadEltVT != EltVT)
5172+
Res = DAG.getNode(ISD::TRUNCATE, DL, EltVT, Res);
5173+
ScalarRes.push_back(Res);
5174+
}
5175+
}
5176+
51775177
const MVT BuildVecVT =
51785178
MVT::getVectorVT(EltVT.getScalarType(), ScalarRes.size());
51795179
SDValue BuildVec = DAG.getBuildVector(BuildVecVT, DL, ScalarRes);
@@ -5188,23 +5188,20 @@ convertVectorLoad(SDNode *N, SelectionDAG &DAG, bool BuildVector) {
51885188
static SDValue PerformLoadCombine(SDNode *N,
51895189
TargetLowering::DAGCombinerInfo &DCI) {
51905190
auto *MemN = cast<MemSDNode>(N);
5191-
EVT MemVT = MemN->getMemoryVT();
5192-
5193-
// ignore volatile loads
5194-
if (MemN->isVolatile())
5195-
return SDValue();
5196-
51975191
// only operate on vectors of f32s / i64s
5198-
if (!MemVT.isVector())
5192+
if (EVT MemVT = MemN->getMemoryVT();
5193+
!(MemVT == MVT::i64 ||
5194+
(MemVT.isVector() && (MemVT.getVectorElementType() == MVT::f32 ||
5195+
MemVT.getVectorElementType() == MVT::i64))))
51995196
return SDValue();
52005197

5201-
EVT ElementVT = MemVT.getVectorElementType();
5202-
if (!(ElementVT == MVT::f32 ||
5203-
(ElementVT == MVT::i64 && N->getOpcode() != ISD::LOAD)))
5204-
return SDValue();
5198+
const unsigned OrigNumResults =
5199+
llvm::count_if(N->values(), [](const auto &VT) {
5200+
return VT == MVT::i64 || VT == MVT::f32 || VT.isVector();
5201+
});
52055202

52065203
SmallDenseMap<SDNode *, unsigned> ExtractElts;
5207-
SDNode *ProxyReg = nullptr;
5204+
SmallVector<SDNode *> ProxyRegs(OrigNumResults, nullptr);
52085205
SmallVector<std::pair<SDNode *, unsigned /*offset*/>> WorkList{{N, 0}};
52095206
while (!WorkList.empty()) {
52105207
auto [V, Offset] = WorkList.pop_back_val();
@@ -5217,8 +5214,14 @@ static SDValue PerformLoadCombine(SDNode *N,
52175214

52185215
SDNode *User = U.getUser();
52195216
if (User->getOpcode() == NVPTXISD::ProxyReg) {
5217+
Offset = U.getResNo() * 2;
5218+
SDNode *&ProxyReg = ProxyRegs[Offset / 2];
5219+
5220+
// We shouldn't have multiple proxy regs for the same value from the
5221+
// load, but bail out anyway since we don't handle this.
52205222
if (ProxyReg)
5221-
return SDValue(); // bail out if we've seen a proxy reg?
5223+
return SDValue();
5224+
52225225
ProxyReg = User;
52235226
} else if (User->getOpcode() == ISD::BITCAST &&
52245227
User->getValueType(0) == MVT::v2f32 &&
@@ -5308,9 +5311,18 @@ static SDValue PerformLoadCombine(SDNode *N,
53085311
if (NewGlueIdx)
53095312
NewGlue = NewLoad.getValue(*NewGlueIdx);
53105313
} else if (N->getOpcode() == ISD::LOAD) { // rewrite a load
5311-
if (auto Result = convertVectorLoad(N, DCI.DAG, /*BuildVector=*/false)) {
5314+
std::optional<EVT> CastToType;
5315+
EVT ResVT = N->getValueType(0);
5316+
if (ResVT == MVT::i64) {
5317+
// ld.b64 is treated as a vector by subsequent code
5318+
CastToType = MVT::v2f32;
5319+
}
5320+
if (auto Result =
5321+
convertVectorLoad(N, DCI.DAG, /*BuildVector=*/false, CastToType)) {
53125322
std::tie(NewLoad, NewChain) = *Result;
5313-
NumElts = MemVT.getVectorNumElements();
5323+
NumElts =
5324+
CastToType.value_or(cast<MemSDNode>(NewLoad.getNode())->getMemoryVT())
5325+
.getVectorNumElements();
53145326
if (NewLoad->getValueType(NewLoad->getNumValues() - 1) == MVT::Glue)
53155327
NewGlue = NewLoad.getValue(NewLoad->getNumValues() - 1);
53165328
}
@@ -5322,54 +5334,65 @@ static SDValue PerformLoadCombine(SDNode *N,
53225334
// (3) begin rewriting uses
53235335
SmallVector<SDValue> NewOutputsF32;
53245336

5325-
if (ProxyReg) {
5326-
// scalarize proxyreg, but first rewrite all uses of chain and glue from the
5327-
// old load to the new load
5337+
if (llvm::any_of(ProxyRegs, [](const SDNode *PR) { return PR != nullptr; })) {
5338+
// scalarize proxy regs, but first rewrite all uses of chain and glue from
5339+
// the old load to the new load
53285340
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
53295341
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
53305342

5331-
// Update the new chain and glue to be old inputs to the proxyreg, if they
5332-
// came from an intervening instruction between this proxyreg and the
5333-
// original load (ex: callseq_end). Other than bitcasts and extractelts, we
5334-
// followed all other nodes by chain and glue accesses.
5335-
if (SDValue OldInChain = ProxyReg->getOperand(0); OldInChain.getNode() != N)
5343+
for (unsigned ProxyI = 0, ProxyE = ProxyRegs.size(); ProxyI != ProxyE;
5344+
++ProxyI) {
5345+
SDNode *ProxyReg = ProxyRegs[ProxyI];
5346+
5347+
// no proxy reg might mean this result is unused
5348+
if (!ProxyReg)
5349+
continue;
5350+
5351+
// Update the new chain and glue to be old inputs to the proxyreg, if they
5352+
// came from an intervening instruction between this proxyreg and the
5353+
// original load (ex: callseq_end). Other than bitcasts and extractelts,
5354+
// we followed all other nodes by chain and glue accesses.
5355+
if (SDValue OldInChain = ProxyReg->getOperand(0);
5356+
OldInChain.getNode() != N)
53365357
NewChain = OldInChain;
5337-
if (SDValue OldInGlue = ProxyReg->getOperand(2); OldInGlue.getNode() != N)
5358+
if (SDValue OldInGlue = ProxyReg->getOperand(2); OldInGlue.getNode() != N)
53385359
NewGlue = OldInGlue;
53395360

5340-
// update OldChain, OldGlue to the outputs of ProxyReg, which we will
5341-
// replace later
5342-
OldChain = SDValue(ProxyReg, 1);
5343-
OldGlue = SDValue(ProxyReg, 2);
5344-
5345-
// generate the scalar proxy regs
5346-
for (unsigned I = 0, E = NumElts; I != E; ++I) {
5347-
SDValue ProxyRegElem =
5348-
DCI.DAG.getNode(NVPTXISD::ProxyReg, SDLoc(ProxyReg),
5349-
DCI.DAG.getVTList(MVT::f32, MVT::Other, MVT::Glue),
5350-
{NewChain, NewLoad.getValue(I), NewGlue});
5351-
NewChain = ProxyRegElem.getValue(1);
5352-
NewGlue = ProxyRegElem.getValue(2);
5353-
NewOutputsF32.push_back(ProxyRegElem);
5361+
// update OldChain, OldGlue to the outputs of ProxyReg, which we will
5362+
// replace later
5363+
OldChain = SDValue(ProxyReg, 1);
5364+
OldGlue = SDValue(ProxyReg, 2);
5365+
5366+
// generate the scalar proxy regs
5367+
for (unsigned I = 0, E = 2; I != E; ++I) {
5368+
SDValue ProxyRegElem = DCI.DAG.getNode(
5369+
NVPTXISD::ProxyReg, SDLoc(ProxyReg),
5370+
DCI.DAG.getVTList(MVT::f32, MVT::Other, MVT::Glue),
5371+
{NewChain, NewLoad.getValue(ProxyI * 2 + I), NewGlue});
5372+
NewChain = ProxyRegElem.getValue(1);
5373+
NewGlue = ProxyRegElem.getValue(2);
5374+
NewOutputsF32.push_back(ProxyRegElem);
5375+
}
5376+
5377+
// replace all uses of the glue and chain from the old proxy reg
5378+
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
5379+
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
53545380
}
53555381
} else {
53565382
for (unsigned I = 0, E = NumElts; I != E; ++I)
53575383
if (NewLoad->getValueType(I) == MVT::f32)
53585384
NewOutputsF32.push_back(NewLoad.getValue(I));
5385+
5386+
// replace all glue and chain nodes
5387+
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
5388+
if (OldGlue)
5389+
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
53595390
}
53605391

5361-
// now, for all extractelts, replace them with one of the new outputs
5392+
// replace all extractelts with the new outputs
53625393
for (auto &[Extract, Index] : ExtractElts)
53635394
DCI.CombineTo(Extract, NewOutputsF32[Index], false);
53645395

5365-
// now replace all glue and chain nodes
5366-
DCI.DAG.ReplaceAllUsesOfValueWith(OldChain, NewChain);
5367-
if (OldGlue)
5368-
DCI.DAG.ReplaceAllUsesOfValueWith(OldGlue, NewGlue);
5369-
5370-
// cleanup
5371-
if (ProxyReg)
5372-
DCI.recursivelyDeleteUnusedNodes(ProxyReg);
53735396
return SDValue();
53745397
}
53755398

Lines changed: 97 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
12
; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_35 | FileCheck %s
23
; RUN: %if ptxas %{ llc < %s -mtriple=nvptx64 -mcpu=sm_35 | %ptxas-verify %}
34

@@ -7,57 +8,122 @@ declare [2 x float] @bara([2 x float] %input)
78
declare {float, float} @bars({float, float} %input)
89

910
define void @test_v2f32(<2 x float> %input, ptr %output) {
10-
; CHECK-LABEL: @test_v2f32
11+
; CHECK-LABEL: test_v2f32(
12+
; CHECK: {
13+
; CHECK-NEXT: .reg .b32 %f<5>;
14+
; CHECK-NEXT: .reg .b64 %rd<3>;
15+
; CHECK-EMPTY:
16+
; CHECK-NEXT: // %bb.0:
17+
; CHECK-NEXT: ld.param.b64 %rd1, [test_v2f32_param_0];
18+
; CHECK-NEXT: { // callseq 0, 0
19+
; CHECK-NEXT: .param .align 8 .b8 param0[8];
20+
; CHECK-NEXT: st.param.b64 [param0], %rd1;
21+
; CHECK-NEXT: .param .align 8 .b8 retval0[8];
22+
; CHECK-NEXT: call.uni (retval0),
23+
; CHECK-NEXT: barv,
24+
; CHECK-NEXT: (
25+
; CHECK-NEXT: param0
26+
; CHECK-NEXT: );
27+
; CHECK-NEXT: ld.param.v2.b32 {%f1, %f2}, [retval0];
28+
; CHECK-NEXT: } // callseq 0
29+
; CHECK-NEXT: ld.param.b64 %rd2, [test_v2f32_param_1];
30+
; CHECK-NEXT: st.v2.b32 [%rd2], {%f1, %f2};
31+
; CHECK-NEXT: ret;
1132
%call = tail call <2 x float> @barv(<2 x float> %input)
12-
; CHECK: .param .align 8 .b8 retval0[8];
13-
; CHECK: ld.param.v2.b32 {[[E0:%f[0-9]+]], [[E1:%f[0-9]+]]}, [retval0];
1433
store <2 x float> %call, ptr %output, align 8
15-
; CHECK: st.v2.b32 [{{%rd[0-9]+}}], {[[E0]], [[E1]]}
1634
ret void
1735
}
1836

1937
define void @test_v3f32(<3 x float> %input, ptr %output) {
20-
; CHECK-LABEL: @test_v3f32
21-
;
38+
; CHECK-LABEL: test_v3f32(
39+
; CHECK: {
40+
; CHECK-NEXT: .reg .b32 %f<10>;
41+
; CHECK-NEXT: .reg .b64 %rd<2>;
42+
; CHECK-EMPTY:
43+
; CHECK-NEXT: // %bb.0:
44+
; CHECK-NEXT: ld.param.v2.b32 {%f1, %f2}, [test_v3f32_param_0];
45+
; CHECK-NEXT: ld.param.b32 %f3, [test_v3f32_param_0+8];
46+
; CHECK-NEXT: { // callseq 1, 0
47+
; CHECK-NEXT: .param .align 16 .b8 param0[16];
48+
; CHECK-NEXT: st.param.v2.b32 [param0], {%f1, %f2};
49+
; CHECK-NEXT: st.param.b32 [param0+8], %f3;
50+
; CHECK-NEXT: .param .align 16 .b8 retval0[16];
51+
; CHECK-NEXT: call.uni (retval0),
52+
; CHECK-NEXT: barv3,
53+
; CHECK-NEXT: (
54+
; CHECK-NEXT: param0
55+
; CHECK-NEXT: );
56+
; CHECK-NEXT: ld.param.v2.b32 {%f4, %f5}, [retval0];
57+
; CHECK-NEXT: ld.param.b32 %f6, [retval0+8];
58+
; CHECK-NEXT: } // callseq 1
59+
; CHECK-NEXT: ld.param.b64 %rd1, [test_v3f32_param_1];
60+
; CHECK-NEXT: st.b32 [%rd1+8], %f6;
61+
; CHECK-NEXT: st.v2.b32 [%rd1], {%f4, %f5};
62+
; CHECK-NEXT: ret;
2263
%call = tail call <3 x float> @barv3(<3 x float> %input)
23-
; CHECK: .param .align 16 .b8 retval0[16];
24-
; CHECK-DAG: ld.param.v2.b32 {[[E0:%f[0-9]+]], [[E1:%f[0-9]+]]}, [retval0];
25-
; CHECK-DAG: ld.param.b32 [[E2:%f[0-9]+]], [retval0+8];
2664
; Make sure we don't load more values than we need to.
27-
; CHECK-NOT: ld.param.b32 [[E3:%f[0-9]+]], [retval0+12];
2865
store <3 x float> %call, ptr %output, align 8
29-
; CHECK-DAG: st.b32 [{{%rd[0-9]}}+8],
30-
; -- This is suboptimal. We should do st.v2.f32 instead
31-
; of combining 2xf32 info i64.
32-
; CHECK-DAG: st.b64 [{{%rd[0-9]}}],
33-
; CHECK: ret;
3466
ret void
3567
}
3668

3769
define void @test_a2f32([2 x float] %input, ptr %output) {
38-
; CHECK-LABEL: @test_a2f32
70+
; CHECK-LABEL: test_a2f32(
71+
; CHECK: {
72+
; CHECK-NEXT: .reg .b32 %f<7>;
73+
; CHECK-NEXT: .reg .b64 %rd<2>;
74+
; CHECK-EMPTY:
75+
; CHECK-NEXT: // %bb.0:
76+
; CHECK-NEXT: ld.param.b32 %f1, [test_a2f32_param_0];
77+
; CHECK-NEXT: ld.param.b32 %f2, [test_a2f32_param_0+4];
78+
; CHECK-NEXT: { // callseq 2, 0
79+
; CHECK-NEXT: .param .align 4 .b8 param0[8];
80+
; CHECK-NEXT: st.param.b32 [param0], %f1;
81+
; CHECK-NEXT: st.param.b32 [param0+4], %f2;
82+
; CHECK-NEXT: .param .align 4 .b8 retval0[8];
83+
; CHECK-NEXT: call.uni (retval0),
84+
; CHECK-NEXT: bara,
85+
; CHECK-NEXT: (
86+
; CHECK-NEXT: param0
87+
; CHECK-NEXT: );
88+
; CHECK-NEXT: ld.param.b32 %f3, [retval0];
89+
; CHECK-NEXT: ld.param.b32 %f4, [retval0+4];
90+
; CHECK-NEXT: } // callseq 2
91+
; CHECK-NEXT: ld.param.b64 %rd1, [test_a2f32_param_1];
92+
; CHECK-NEXT: st.b32 [%rd1+4], %f4;
93+
; CHECK-NEXT: st.b32 [%rd1], %f3;
94+
; CHECK-NEXT: ret;
3995
%call = tail call [2 x float] @bara([2 x float] %input)
40-
; CHECK: .param .align 4 .b8 retval0[8];
41-
; CHECK-DAG: ld.param.b32 [[ELEMA1:%f[0-9]+]], [retval0];
42-
; CHECK-DAG: ld.param.b32 [[ELEMA2:%f[0-9]+]], [retval0+4];
4396
store [2 x float] %call, ptr %output, align 4
44-
; CHECK: }
45-
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMA1]]
46-
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}+4], [[ELEMA2]]
4797
ret void
48-
; CHECK: ret
4998
}
5099

51100
define void @test_s2f32({float, float} %input, ptr %output) {
52-
; CHECK-LABEL: @test_s2f32
101+
; CHECK-LABEL: test_s2f32(
102+
; CHECK: {
103+
; CHECK-NEXT: .reg .b32 %f<7>;
104+
; CHECK-NEXT: .reg .b64 %rd<2>;
105+
; CHECK-EMPTY:
106+
; CHECK-NEXT: // %bb.0:
107+
; CHECK-NEXT: ld.param.b32 %f1, [test_s2f32_param_0];
108+
; CHECK-NEXT: ld.param.b32 %f2, [test_s2f32_param_0+4];
109+
; CHECK-NEXT: { // callseq 3, 0
110+
; CHECK-NEXT: .param .align 4 .b8 param0[8];
111+
; CHECK-NEXT: st.param.b32 [param0], %f1;
112+
; CHECK-NEXT: st.param.b32 [param0+4], %f2;
113+
; CHECK-NEXT: .param .align 4 .b8 retval0[8];
114+
; CHECK-NEXT: call.uni (retval0),
115+
; CHECK-NEXT: bars,
116+
; CHECK-NEXT: (
117+
; CHECK-NEXT: param0
118+
; CHECK-NEXT: );
119+
; CHECK-NEXT: ld.param.b32 %f3, [retval0];
120+
; CHECK-NEXT: ld.param.b32 %f4, [retval0+4];
121+
; CHECK-NEXT: } // callseq 3
122+
; CHECK-NEXT: ld.param.b64 %rd1, [test_s2f32_param_1];
123+
; CHECK-NEXT: st.b32 [%rd1+4], %f4;
124+
; CHECK-NEXT: st.b32 [%rd1], %f3;
125+
; CHECK-NEXT: ret;
53126
%call = tail call {float, float} @bars({float, float} %input)
54-
; CHECK: .param .align 4 .b8 retval0[8];
55-
; CHECK-DAG: ld.param.b32 [[ELEMS1:%f[0-9]+]], [retval0];
56-
; CHECK-DAG: ld.param.b32 [[ELEMS2:%f[0-9]+]], [retval0+4];
57127
store {float, float} %call, ptr %output, align 4
58-
; CHECK: }
59-
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}], [[ELEMS1]]
60-
; CHECK-DAG: st.b32 [{{%rd[0-9]+}}+4], [[ELEMS2]]
61128
ret void
62-
; CHECK: ret
63129
}

0 commit comments

Comments
 (0)