Skip to content

Commit baef6cd

Browse files
committed
promote extract_vector_elt nodes to unpacking mov
Also update the test cases.
1 parent 992f236 commit baef6cd

File tree

4 files changed

+2093
-13
lines changed

4 files changed

+2093
-13
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -464,10 +464,14 @@ bool NVPTXDAGToDAGISel::tryUNPACK_VECTOR(SDNode *N) {
464464
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
465465
SDValue Vector = N->getOperand(0);
466466

467-
// We only care about 16x2 as it's the only real vector type we
468-
// need to deal with.
467+
// We only care about packed vector types: 16x2 and 32x2.
469468
MVT VT = Vector.getSimpleValueType();
470-
if (!Isv2x16VT(VT))
469+
unsigned NewOpcode = 0;
470+
if (Isv2x16VT(VT))
471+
NewOpcode = NVPTX::I32toV2I16;
472+
else if (VT == MVT::v2f32)
473+
NewOpcode = NVPTX::I64toV2F32;
474+
else
471475
return false;
472476
// Find and record all uses of this vector that extract element 0 or 1.
473477
SmallVector<SDNode *, 4> E0, E1;
@@ -487,16 +491,19 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
487491
}
488492
}
489493

490-
// There's no point scattering f16x2 if we only ever access one
494+
// There's no point scattering f16x2 or f32x2 if we only ever access one
491495
// element of it.
492496
if (E0.empty() || E1.empty())
493497
return false;
494498

495-
// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
496-
// into f16,f16 SplitF16x2(V)
499+
// Merge:
500+
// (f16 extractelt(V, 0), f16 extractelt(V,1))
501+
// -> f16,f16 SplitF16x2(V)
502+
// (f32 extractelt(V, 0), f32 extractelt(V,1))
503+
// -> f32,f32 SplitF32x2(V)
497504
MVT EltVT = VT.getVectorElementType();
498505
SDNode *ScatterOp =
499-
CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
506+
CurDAG->getMachineNode(NewOpcode, SDLoc(N), EltVT, EltVT, Vector);
500507
for (auto *Node : E0)
501508
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
502509
for (auto *Node : E1)

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5545,10 +5545,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
55455545
IsPTXVectorType(VectorVT.getSimpleVT()))
55465546
return SDValue(); // Native vector loads already combine nicely w/
55475547
// extract_vector_elt.
5548-
// Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
5548+
// Don't mess with singletons or v2*16, v2f32, v4i8 and v8i8 types, we already
55495549
// handle them OK.
55505550
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
5551-
VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5551+
VectorVT == MVT::v2f32 || VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
55525552
return SDValue();
55535553

55545554
// Don't mess with undef values as sra may be simplified to 0, not undef.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3028,6 +3028,9 @@ let hasSideEffects = false in {
30283028
def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
30293029
(ins Int64Regs:$s),
30303030
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
3031+
def I64toV2F32 : NVPTXInst<(outs Float32Regs:$d1, Float32Regs:$d2),
3032+
(ins Int64Regs:$s),
3033+
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
30313034
def I128toV2I64: NVPTXInst<(outs Int64Regs:$d1, Int64Regs:$d2),
30323035
(ins Int128Regs:$s),
30333036
"mov.b128 \t{{$d1, $d2}}, $s;", []>;

0 commit comments

Comments
 (0)