Skip to content

Commit 1fdbe69

Browse files
authored
[NVPTX] support f32x2 instructions for sm_100+ (#126337)
Lower `fadd`, `fsub`, `fmul`, and `fma` to f32x2 variants introduced in PTX 8.6 for sm_100+. Adds a new register class for v2f32 as a b64 register in PTX. This causes other vector operations like loads and stores to lower as .b64 instead of .v2.b32 as appropriate. Also update test cases to use the autogenerator.
1 parent 0a34309 commit 1fdbe69

27 files changed

+3544
-1125
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -446,11 +446,18 @@ bool NVPTXDAGToDAGISel::tryUNPACK_VECTOR(SDNode *N) {
446446
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
447447
SDValue Vector = N->getOperand(0);
448448

449-
// We only care about 16x2 as it's the only real vector type we
450-
// need to deal with.
451449
MVT VT = Vector.getSimpleValueType();
452-
if (!Isv2x16VT(VT))
450+
if (!(NVPTX::isPackedVectorTy(VT) && VT.getVectorNumElements() == 2))
453451
return false;
452+
453+
unsigned Opcode;
454+
if (VT.is32BitVector())
455+
Opcode = NVPTX::I32toV2I16;
456+
else if (VT.is64BitVector())
457+
Opcode = NVPTX::I64toV2I32;
458+
else
459+
llvm_unreachable("Unhandled packed type");
460+
454461
// Find and record all uses of this vector that extract element 0 or 1.
455462
SmallVector<SDNode *, 4> E0, E1;
456463
for (auto *U : Vector.getNode()->users()) {
@@ -474,11 +481,11 @@ bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {
474481
if (E0.empty() || E1.empty())
475482
return false;
476483

477-
// Merge (f16 extractelt(V, 0), f16 extractelt(V,1))
478-
// into f16,f16 SplitF16x2(V)
484+
// Merge (EltTy extractelt(V, 0), EltTy extractelt(V,1))
485+
// into EltTy,EltTy Split[EltTy]x2(V)
479486
MVT EltVT = VT.getVectorElementType();
480487
SDNode *ScatterOp =
481-
CurDAG->getMachineNode(NVPTX::I32toV2I16, SDLoc(N), EltVT, EltVT, Vector);
488+
CurDAG->getMachineNode(Opcode, SDLoc(N), EltVT, EltVT, Vector);
482489
for (auto *Node : E0)
483490
ReplaceUses(SDValue(Node, 0), SDValue(ScatterOp, 0));
484491
for (auto *Node : E1)
@@ -994,6 +1001,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, std::optional<unsigned> Opcode_i8,
9941001
case MVT::i32:
9951002
case MVT::f32:
9961003
return Opcode_i32;
1004+
case MVT::v2f32:
9971005
case MVT::i64:
9981006
case MVT::f64:
9991007
return Opcode_i64;

0 commit comments

Comments
 (0)