Skip to content

Commit 60a73cc

Browse files
committed
support f32x2 instructions for Blackwell
This is a rewrite of previous work that legalized v2f32 into an i64 register. Here we keep the type non-legal, and selectively legalize it for certain operations (FADD, FSUB, FMUL, FMA). Additional operations are handled to improve codegen quality.
1 parent 9d5edc9 commit 60a73cc

File tree

10 files changed

+1308
-22
lines changed

10 files changed

+1308
-22
lines changed

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -818,6 +818,10 @@ def step_vector : SDNode<"ISD::STEP_VECTOR", SDTypeProfile<1, 1,
818818
def scalar_to_vector : SDNode<"ISD::SCALAR_TO_VECTOR", SDTypeProfile<1, 1, []>,
819819
[]>;
820820

821+
def build_pair : SDNode<"ISD::BUILD_PAIR", SDTypeProfile<1, 2,
822+
[SDTCisInt<0>, SDTCisInt<1>, SDTCisInt<2>]>, []>;
823+
824+
821825
// vector_extract/vector_insert are deprecated. extractelt/insertelt
822826
// are preferred.
823827
def vector_extract : SDNode<"ISD::EXTRACT_VECTOR_ELT",

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,12 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
121121
case NVPTXISD::SETP_BF16X2:
122122
SelectSETP_BF16X2(N);
123123
return;
124+
case NVPTXISD::FADD_F32X2:
125+
case NVPTXISD::FSUB_F32X2:
126+
case NVPTXISD::FMUL_F32X2:
127+
case NVPTXISD::FMA_F32X2:
128+
SelectF32X2Op(N);
129+
return;
124130
case NVPTXISD::LoadV2:
125131
case NVPTXISD::LoadV4:
126132
if (tryLoadVector(N))
@@ -295,6 +301,30 @@ bool NVPTXDAGToDAGISel::SelectSETP_BF16X2(SDNode *N) {
295301
return true;
296302
}
297303

304+
void NVPTXDAGToDAGISel::SelectF32X2Op(SDNode *N) {
305+
unsigned Opcode;
306+
switch (N->getOpcode()) {
307+
case NVPTXISD::FADD_F32X2:
308+
Opcode = NVPTX::FADD_F32X2;
309+
break;
310+
case NVPTXISD::FSUB_F32X2:
311+
Opcode = NVPTX::FSUB_F32X2;
312+
break;
313+
case NVPTXISD::FMUL_F32X2:
314+
Opcode = NVPTX::FMUL_F32X2;
315+
break;
316+
case NVPTXISD::FMA_F32X2:
317+
Opcode = NVPTX::FMA_F32X2;
318+
break;
319+
default:
320+
llvm_unreachable("Unexpected opcode!");
321+
}
322+
SDLoc DL(N);
323+
SmallVector<SDValue> NewOps(N->ops());
324+
SDNode *NewNode = CurDAG->getMachineNode(Opcode, DL, MVT::i64, NewOps);
325+
ReplaceNode(N, NewNode);
326+
}
327+
298328
// Find all instances of extract_vector_elt that use this v2f16 vector
299329
// and coalesce them into a scattering move instruction.
300330
bool NVPTXDAGToDAGISel::tryEXTRACT_VECTOR_ELEMENT(SDNode *N) {

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
8888
bool tryConstantFP(SDNode *N);
8989
bool SelectSETP_F16X2(SDNode *N);
9090
bool SelectSETP_BF16X2(SDNode *N);
91+
void SelectF32X2Op(SDNode *N);
9192
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
9293
void SelectV2I64toI128(SDNode *N);
9394
void SelectI128toV2I64(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 117 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -866,6 +866,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
866866
setBF16OperationAction(ISD::FNEG, MVT::v2bf16, Legal, Expand);
867867
// (would be) Library functions.
868868

869+
if (STI.hasF32x2Instructions()) {
870+
// Handle custom lowering for: v2f32 = OP v2f32, v2f32
871+
for (const auto &Op : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMA})
872+
setOperationAction(Op, MVT::v2f32, Custom);
873+
// Handle custom lowering for: i64 = bitcast v2f32
874+
setOperationAction(ISD::BITCAST, MVT::v2f32, Custom);
875+
}
876+
869877
// These map to conversion instructions for scalar FP types.
870878
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
871879
ISD::FROUNDEVEN, ISD::FTRUNC}) {
@@ -1066,6 +1074,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10661074
MAKE_CASE(NVPTXISD::STACKSAVE)
10671075
MAKE_CASE(NVPTXISD::SETP_F16X2)
10681076
MAKE_CASE(NVPTXISD::SETP_BF16X2)
1077+
MAKE_CASE(NVPTXISD::FADD_F32X2)
1078+
MAKE_CASE(NVPTXISD::FSUB_F32X2)
1079+
MAKE_CASE(NVPTXISD::FMUL_F32X2)
1080+
MAKE_CASE(NVPTXISD::FMA_F32X2)
10691081
MAKE_CASE(NVPTXISD::Dummy)
10701082
MAKE_CASE(NVPTXISD::MUL_WIDE_SIGNED)
10711083
MAKE_CASE(NVPTXISD::MUL_WIDE_UNSIGNED)
@@ -2099,24 +2111,58 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
20992111
// Handle bitcasting from v2i8 without hitting the default promotion
21002112
// strategy which goes through stack memory.
21012113
EVT FromVT = Op->getOperand(0)->getValueType(0);
2102-
if (FromVT != MVT::v2i8) {
2103-
return Op;
2104-
}
2105-
2106-
// Pack vector elements into i16 and bitcast to final type
2107-
SDLoc DL(Op);
2108-
SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2109-
Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
2110-
SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2111-
Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
2112-
SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
2113-
SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
2114-
SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
2115-
SDValue AsInt = DAG.getNode(
2116-
ISD::OR, DL, MVT::i16,
2117-
{Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
21182114
EVT ToVT = Op->getValueType(0);
2119-
return MaybeBitcast(DAG, DL, ToVT, AsInt);
2115+
SDLoc DL(Op);
2116+
2117+
if (FromVT == MVT::v2i8) {
2118+
// Pack vector elements into i16 and bitcast to final type
2119+
SDValue Vec0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2120+
Op->getOperand(0), DAG.getIntPtrConstant(0, DL));
2121+
SDValue Vec1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8,
2122+
Op->getOperand(0), DAG.getIntPtrConstant(1, DL));
2123+
SDValue Extend0 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec0);
2124+
SDValue Extend1 = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i16, Vec1);
2125+
SDValue Const8 = DAG.getConstant(8, DL, MVT::i16);
2126+
SDValue AsInt = DAG.getNode(
2127+
ISD::OR, DL, MVT::i16,
2128+
{Extend0, DAG.getNode(ISD::SHL, DL, MVT::i16, {Extend1, Const8})});
2129+
EVT ToVT = Op->getValueType(0);
2130+
return MaybeBitcast(DAG, DL, ToVT, AsInt);
2131+
}
2132+
2133+
if (FromVT == MVT::v2f32) {
2134+
assert(ToVT == MVT::i64);
2135+
2136+
// A bitcast to i64 from v2f32.
2137+
// See if we can legalize the operand.
2138+
const SDValue &Operand = Op->getOperand(0);
2139+
if (Operand.getOpcode() == ISD::BUILD_VECTOR) {
2140+
const SDValue &BVOp0 = Operand.getOperand(0);
2141+
const SDValue &BVOp1 = Operand.getOperand(1);
2142+
2143+
auto CastToAPInt = [](SDValue Op) -> APInt {
2144+
if (Op->isUndef())
2145+
return APInt(64, 0); // undef values default to 0
2146+
return cast<ConstantFPSDNode>(Op)->getValueAPF().bitcastToAPInt().zext(
2147+
64);
2148+
};
2149+
2150+
if ((BVOp0->isUndef() || isa<ConstantFPSDNode>(BVOp0)) &&
2151+
(BVOp1->isUndef() || isa<ConstantFPSDNode>(BVOp1))) {
2152+
// cast two constants
2153+
APInt Value(64, 0);
2154+
Value = CastToAPInt(BVOp0) | CastToAPInt(BVOp1).shl(32);
2155+
SDValue Const = DAG.getConstant(Value, DL, MVT::i64);
2156+
return DAG.getBitcast(ToVT, Const);
2157+
}
2158+
2159+
// otherwise build an i64
2160+
return DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
2161+
DAG.getBitcast(MVT::i32, BVOp0),
2162+
DAG.getBitcast(MVT::i32, BVOp1));
2163+
}
2164+
}
2165+
return Op;
21202166
}
21212167

21222168
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
@@ -3055,6 +3101,13 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
30553101
return false;
30563102
}
30573103

3104+
const TargetRegisterClass *
3105+
NVPTXTargetLowering::getRegClassFor(MVT VT, bool isDivergent) const {
3106+
if (VT == MVT::v2f32)
3107+
return &NVPTX::Int64RegsRegClass;
3108+
return TargetLowering::getRegClassFor(VT, isDivergent);
3109+
}
3110+
30583111
// This creates target external symbol for a function parameter.
30593112
// Name of the symbol is composed from its index and the function name.
30603113
// Negative index corresponds to special parameter (unsized array) used for
@@ -5055,10 +5108,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
50555108
IsPTXVectorType(VectorVT.getSimpleVT()))
50565109
return SDValue(); // Native vector loads already combine nicely w/
50575110
// extract_vector_elt.
5058-
// Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
5111+
// Don't mess with singletons or v2*16, v4i8, v8i8, or v2f32 types, we already
50595112
// handle them OK.
50605113
if (VectorVT.getVectorNumElements() == 1 || Isv2x16VT(VectorVT) ||
5061-
VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5114+
VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8 || VectorVT == MVT::v2f32)
50625115
return SDValue();
50635116

50645117
// Don't mess with undef values as sra may be simplified to 0, not undef.
@@ -5478,6 +5531,45 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
54785531
Results.push_back(NewValue.getValue(3));
54795532
}
54805533

5534+
static void ReplaceF32x2Op(SDNode *N, SelectionDAG &DAG,
5535+
SmallVectorImpl<SDValue> &Results,
5536+
bool UseFTZ) {
5537+
SDLoc DL(N);
5538+
EVT OldResultTy = N->getValueType(0); // <2 x float>
5539+
assert(OldResultTy == MVT::v2f32 && "Unexpected result type for F32x2 op!");
5540+
5541+
SmallVector<SDValue> NewOps;
5542+
5543+
// whether we use FTZ (TODO)
5544+
5545+
// replace with NVPTX F32x2 op:
5546+
unsigned Opcode;
5547+
switch (N->getOpcode()) {
5548+
case ISD::FADD:
5549+
Opcode = NVPTXISD::FADD_F32X2;
5550+
break;
5551+
case ISD::FSUB:
5552+
Opcode = NVPTXISD::FSUB_F32X2;
5553+
break;
5554+
case ISD::FMUL:
5555+
Opcode = NVPTXISD::FMUL_F32X2;
5556+
break;
5557+
case ISD::FMA:
5558+
Opcode = NVPTXISD::FMA_F32X2;
5559+
break;
5560+
default:
5561+
llvm_unreachable("Unexpected opcode");
5562+
}
5563+
5564+
// bitcast operands: <2 x float> -> i64
5565+
for (const SDValue &Op : N->ops())
5566+
NewOps.push_back(DAG.getNode(ISD::BITCAST, DL, MVT::i64, Op));
5567+
5568+
// cast i64 result of new op back to <2 x float>
5569+
SDValue NewValue = DAG.getNode(Opcode, DL, MVT::i64, NewOps);
5570+
Results.push_back(DAG.getBitcast(OldResultTy, NewValue));
5571+
}
5572+
54815573
void NVPTXTargetLowering::ReplaceNodeResults(
54825574
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
54835575
switch (N->getOpcode()) {
@@ -5495,6 +5587,12 @@ void NVPTXTargetLowering::ReplaceNodeResults(
54955587
case ISD::CopyFromReg:
54965588
ReplaceCopyFromReg_128(N, DAG, Results);
54975589
return;
5590+
case ISD::FADD:
5591+
case ISD::FSUB:
5592+
case ISD::FMUL:
5593+
case ISD::FMA:
5594+
ReplaceF32x2Op(N, DAG, Results, useF32FTZ(DAG.getMachineFunction()));
5595+
return;
54985596
}
54995597
}
55005598

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,10 @@ enum NodeType : unsigned {
5555
FSHR_CLAMP,
5656
MUL_WIDE_SIGNED,
5757
MUL_WIDE_UNSIGNED,
58+
FADD_F32X2,
59+
FMUL_F32X2,
60+
FSUB_F32X2,
61+
FMA_F32X2,
5862
SETP_F16X2,
5963
SETP_BF16X2,
6064
BFE,
@@ -311,6 +315,9 @@ class NVPTXTargetLowering : public TargetLowering {
311315
SDValue *Parts, unsigned NumParts, MVT PartVT,
312316
std::optional<CallingConv::ID> CC) const override;
313317

318+
const TargetRegisterClass *getRegClassFor(MVT VT,
319+
bool isDivergent) const override;
320+
314321
void ReplaceNodeResults(SDNode *N, SmallVectorImpl<SDValue> &Results,
315322
SelectionDAG &DAG) const override;
316323
SDValue PerformDAGCombine(SDNode *N, DAGCombinerInfo &DCI) const override;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,6 +165,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
165165
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
166166
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
167167
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
168+
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
168169

169170
def True : Predicate<"true">;
170171
def False : Predicate<"false">;
@@ -2638,13 +2639,13 @@ class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
26382639
NVPTXInst<(outs), (ins regclass:$a), "$a",
26392640
[(LastCallArg (i32 0), vt:$a)]>;
26402641

2641-
def CallArgI64 : CallArgInst<Int64Regs>;
2642+
def CallArgI64 : CallArgInstVT<Int64Regs, i64>;
26422643
def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
26432644
def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
26442645
def CallArgF64 : CallArgInst<Float64Regs>;
26452646
def CallArgF32 : CallArgInst<Float32Regs>;
26462647

2647-
def LastCallArgI64 : LastCallArgInst<Int64Regs>;
2648+
def LastCallArgI64 : LastCallArgInstVT<Int64Regs, i64>;
26482649
def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
26492650
def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
26502651
def LastCallArgF64 : LastCallArgInst<Float64Regs>;
@@ -3371,6 +3372,9 @@ let hasSideEffects = false in {
33713372
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
33723373
(ins Float32Regs:$s1, Float32Regs:$s2),
33733374
"mov.b64 \t$d, {{$s1, $s2}};", []>;
3375+
def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
3376+
(ins Float32Regs:$s1, Float32Regs:$s2),
3377+
"mov.b64 \t$d, {{$s1, $s2}};", []>;
33743378

33753379
// unpack a larger int register to a set of smaller int registers
33763380
def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -3435,6 +3439,10 @@ def : Pat<(v2bf16 (build_vector bf16:$a, bf16:$b)),
34353439
(V2I16toI32 $a, $b)>;
34363440
def : Pat<(v2i16 (build_vector i16:$a, i16:$b)),
34373441
(V2I16toI32 $a, $b)>;
3442+
def : Pat<(v2f32 (build_vector f32:$a, f32:$b)),
3443+
(V2F32toI64 $a, $b)>;
3444+
def : Pat<(i64 (build_pair i32:$a, i32:$b)),
3445+
(V2I32toI64 $a, $b)>;
34383446

34393447
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
34403448
(CVT_u32_u16 $a, CvtNONE)>;

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,6 +1581,28 @@ def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64 \t$dst, $src0, $src1;",
15811581
def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64 \t$dst, $src0, $src1;",
15821582
Float64Regs, Float64Regs, Float64Regs, int_nvvm_add_rp_d>;
15831583

1584+
// F32x2 ops (sm_100+)
1585+
1586+
def FADD_F32X2 : NVPTXInst<(outs Int64Regs:$res),
1587+
(ins Int64Regs:$a, Int64Regs:$b),
1588+
"add.rn.f32x2 \t$res, $a, $b;", []>,
1589+
Requires<[hasF32x2Instructions]>;
1590+
1591+
def FSUB_F32X2 : NVPTXInst<(outs Int64Regs:$res),
1592+
(ins Int64Regs:$a, Int64Regs:$b),
1593+
"sub.rn.f32x2 \t$res, $a, $b;", []>,
1594+
Requires<[hasF32x2Instructions]>;
1595+
1596+
def FMUL_F32X2 : NVPTXInst<(outs Int64Regs:$res),
1597+
(ins Int64Regs:$a, Int64Regs:$b),
1598+
"mul.rn.f32x2 \t$res, $a, $b;", []>,
1599+
Requires<[hasF32x2Instructions]>;
1600+
1601+
def FMA_F32X2 : NVPTXInst<(outs Int64Regs:$res),
1602+
(ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
1603+
"fma.rn.f32x2 \t$res, $a, $b;", []>,
1604+
Requires<[hasF32x2Instructions]>;
1605+
15841606
//
15851607
// BFIND
15861608
//

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4)
6262
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
6363
(add (sequence "R%u", 0, 4),
6464
VRFrame32, VRFrameLocal32)>;
65-
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
65+
def Int64Regs : NVPTXRegClass<[i64, v2f32], 64,
66+
(add (sequence "RL%u", 0, 4),
67+
VRFrame64, VRFrameLocal64)>;
6668
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6769
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
6870
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
9797
bool hasDotInstructions() const {
9898
return SmVersion >= 61 && PTXVersion >= 50;
9999
}
100+
100101
// Tcgen05 instructions in Blackwell family
101102
bool hasTcgen05Instructions() const {
102103
bool HasTcgen05 = false;
@@ -112,6 +113,8 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
112113
return HasTcgen05 && PTXVersion >= 86;
113114
}
114115

116+
bool hasF32x2Instructions() const { return SmVersion >= 100; }
117+
115118
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
116119
// terminates a basic block. Instead, it would assume that control flow
117120
// continued to the next instruction. The next instruction could be in the

0 commit comments

Comments
 (0)