Skip to content

Commit 0f3fdc2

Browse files
committed
promote extract_vector_elt nodes to unpacking mov
Also update the test cases.
1 parent 5479c5a commit 0f3fdc2

File tree

4 files changed

+2093
-13
lines changed

4 files changed

+2093
-13
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -468,10 +468,14 @@ bool NVPTXDAGToDAGISel::tryUNPACK_VECTOR(SDNode *N) {
468468
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
469469
SDValue Vector = N->getOperand(0);
470470

471-
// We only care about 16x2 as it's the only real vector type we
472-
// need to deal with.
471+
// We only care about packed vector types: 16x2 and v2f32.
473472
MVT VT = Vector.getSimpleValueType();
474-
if (!Isv2x16VT(VT))
473+
unsigned NewOpcode;
474+
if (Isv2x16VT(VT))
475+
NewOpcode = NVPTX::I32toV2I16;
476+
else if (VT == MVT::v2f32)
477+
NewOpcode = NVPTX::I64toV2F32;
478+
else
475479
return false;
476480
// Find and record all uses of this vector that extract element 0 or 1.
477481
SmallVector<SDNode *, 4> E0, E1;
@@ -491,16 +495,19 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
491495
}
492496
}
493497

494-
// There's no point scattering f16x2 if we only ever access one
498+
// There's no point scattering f16x2 or f32x2 if we only ever access one
495499
// element of it.
496500
if (E0.empty() || E1.empty())
497501
return false;
498502

499-
// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
500-
// into f16,f16 SplitF16x2(V)
503+
// Merge:
504+
// (f16 extractelt(V, 0), f16 extractelt(V,1))
505+
// -> f16,f16 SplitF16x2(V)
506+
// (f32 extractelt(V, 0), f32 extractelt(V,1))
507+
// -> f32,f32 SplitF32x2(V)
501508
MVT EltVT = VT.getVectorElementType();
502509
SDNode *ScatterOp =
503-
CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
510+
CurDAG->getMachineNode(NewOpcode, SDLoc(N), EltVT, EltVT, Vector);
504511
for (auto *Node : E0)
505512
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
506513
for (auto *Node : E1)

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5665,10 +5665,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
56655665
IsPTXVectorType(VectorVT.getSimpleVT()))
56665666
return SDValue(); // Native vector loads already combine nicely w/
56675667
// extract_vector_elt.
5668-
// Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
5668+
// Don't mess with singletons or v2*16, v2f32, v4i8 and v8i8 types, we already
56695669
// handle them OK.
56705670
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
5671-
VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5671+
VectorVT == MVT::v2f32 || VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
56725672
return SDValue();
56735673

56745674
// Don't mess with undef values as sra may be simplified to 0, not undef.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2896,6 +2896,9 @@ let hasSideEffects = false in {
28962896
def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
28972897
(ins Int64Regs:$s),
28982898
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
2899+
def I64toV2F32 : NVPTXInst<(outs Float32Regs:$d1, Float32Regs:$d2),
2900+
(ins Int64Regs:$s),
2901+
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
28992902
def I128toV2I64: NVPTXInst<(outs Int64Regs:$d1, Int64Regs:$d2),
29002903
(ins Int128Regs:$s),
29012904
"mov.b128 \t{{$d1, $d2}}, $s;", []>;

0 commit comments

Comments
 (0)