Skip to content

Commit e016449

Browse files
committed
Promote extract_vector_elt nodes on v2f32 to an unpacking mov
Also update the test cases.
1 parent d264e73 commit e016449

File tree

4 files changed

+2093
-13
lines changed

4 files changed

+2093
-13
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -450,10 +450,14 @@ bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
450450
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
451451
SDValue Vector = N->getOperand(0);
452452

453-
// We only care about 16x2 as it's the only real vector type we
454-
// need to deal with.
453+
// We only care about packed vector types: 16x2 and v2f32.
455454
MVT VT = Vector.getSimpleValueType();
456-
if (!Isv2x16VT(VT))
455+
unsigned NewOpcode = 0;
456+
if (Isv2x16VT(VT))
457+
NewOpcode = NVPTX::I32toV2I16;
458+
else if (VT == MVT::v2f32)
459+
NewOpcode = NVPTX::I64toV2F32;
460+
else
457461
return false;
458462
// Find and record all uses of this vector that extract element 0 or 1.
459463
SmallVector<SDNode *, 4> E0, E1;
@@ -473,16 +477,19 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
473477
}
474478
}
475479

476-
// There's no point scattering f16x2 if we only ever access one
480+
// There's no point scattering f16x2 or f32x2 if we only ever access one
477481
// element of it.
478482
if (E0.empty() || E1.empty())
479483
return false;
480484

481-
// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
482-
// into f16,f16 SplitF16x2(V)
485+
// Merge:
486+
// (f16 extractelt(V, 0), f16 extractelt(V,1))
487+
// -> f16,f16 SplitF16x2(V)
488+
// (f32 extractelt(V, 0), f32 extractelt(V,1))
489+
// -> f32,f32 SplitF32x2(V)
483490
MVT EltVT = VT.getVectorElementType();
484491
SDNode *ScatterOp =
485-
CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
492+
CurDAG->getMachineNode(NewOpcode, SDLoc(N), EltVT, EltVT, Vector);
486493
for (auto *Node : E0)
487494
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
488495
for (auto *Node : E1)

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5434,10 +5434,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
54345434
IsPTXVectorType(VectorVT.getSimpleVT()))
54355435
return SDValue(); // Native vector loads already combine nicely w/
54365436
// extract_vector_elt.
5437-
// Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
5437+
// Don't mess with singletons or v2*16, v2f32, v4i8 and v8i8 types; we already
54385438
// handle them OK.
54395439
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
5440-
VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5440+
VectorVT == MVT::v2f32 || VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
54415441
return SDValue();
54425442

54435443
// Don't mess with undef values as sra may be simplified to 0, not undef.

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3212,6 +3212,9 @@ let hasSideEffects = false in {
32123212
def I64toV2I32 : NVPTXInst<(outs Int32Regs:$d1, Int32Regs:$d2),
32133213
(ins Int64Regs:$s),
32143214
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
3215+
def I64toV2F32 : NVPTXInst<(outs Float32Regs:$d1, Float32Regs:$d2),
3216+
(ins Int64Regs:$s),
3217+
"mov.b64 \t{{$d1, $d2}}, $s;", []>;
32153218
def I128toV2I64: NVPTXInst<(outs Int64Regs:$d1, Int64Regs:$d2),
32163219
(ins Int128Regs:$s),
32173220
"mov.b128 \t{{$d1, $d2}}, $s;", []>;

0 commit comments

Comments
 (0)