Commit 7c58476

[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for the sm_100+ fmax3/fmin3 instructions introduced in PTX 8.8.

This method of tree reduction has a few benefits over the default in DAGTypeLegalizer:

- Produces the optimal number of operations supported by the target. Instead of progressively splitting the vector operand top-down, first scalarize it and then build the tree bottom-up. This uses larger operations when available and leaves smaller ones for the remaining elements (see the sketch below).
- Faster compile time. Happens in one pass over the intrinsic, rather than O(N) passes when iteratively splitting the vector operand.
1 parent 213d0d2 commit 7c58476
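
For intuition, here is a minimal standalone C++ sketch of the bottom-up grouping strategy described in the commit message, applied to an 8-element max reduction when both a 3-input and a 2-input max are available. It models only the grouping idea; the names max3, max2 and reduceBottomUp are invented for illustration, and this is not the SelectionDAG code from the patch.

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Illustrative stand-ins for the 3-input and 2-input max operations.
static float max3(float a, float b, float c) { return std::max({a, b, c}); }
static float max2(float a, float b) { return std::max(a, b); }

// Bottom-up tree reduction: at each level, greedily group elements with the
// widest operator (max3), fall back to the narrower one (max2), and let any
// lone leftover element ride along to the next level.
static float reduceBottomUp(std::vector<float> Level) {
  while (Level.size() > 1) {
    std::vector<float> Next;
    size_t I = 0;
    for (; I + 3 <= Level.size(); I += 3)
      Next.push_back(max3(Level[I], Level[I + 1], Level[I + 2]));
    if (I + 2 <= Level.size()) {
      Next.push_back(max2(Level[I], Level[I + 1]));
      I += 2;
    }
    for (; I < Level.size(); ++I) // remainder carried to the next level
      Next.push_back(Level[I]);
    Level = std::move(Next);
  }
  return Level.front();
}

int main() {
  std::vector<float> V = {1, 7, 3, 9, 5, 2, 8, 4};
  // For 8 elements this sketch emits two max3 and one max2 at the first
  // level, then one max3 at the second level: 4 operations total, versus
  // the 7 two-input max operations of a purely pairwise reduction.
  std::printf("max = %f\n", reduceBottomUp(V));
  return 0;
}

The BuildTreeReduction routine in this commit follows the same idea on SDNodes, except that its operator list (e.g. {FMAXNUM3, 3} then {FMAXNUM, 2}) is supplied by LowerVECREDUCE and it only switches to the narrower operator once the wider one can no longer form any group at the current level.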

5 files changed: +1601, -516 lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 128 additions & 0 deletions
@@ -85,6 +85,12 @@ static cl::opt<unsigned> FMAContractLevelOpt(
              " 1: do it 2: do it aggressively"),
     cl::init(2));
 
+static cl::opt<bool> DisableFOpTreeReduce(
+    "nvptx-disable-fop-tree-reduce", cl::Hidden,
+    cl::desc("NVPTX Specific: don't emit tree reduction for floating-point "
+             "reduction operations"),
+    cl::init(false));
+
 static cl::opt<NVPTX::DivPrecisionLevel> UsePrecDivF32(
     "nvptx-prec-divf32", cl::Hidden,
     cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -863,6 +869,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These are transformed into a tree evaluation
+  // of nodes which may or may not be legal.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                        ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                        ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                       VT, Custom);
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1120,6 +1135,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2194,6 +2213,108 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction for a vector operand.
+/// This method differs from iterative splitting in DAGTypeLegalizer by
+/// first scalarizing the vector and then progressively grouping elements
+/// bottom-up. This allows easily building the optimal (minimum) number of nodes
+/// with different numbers of operands (eg. max3 vs max2).
+static SDValue BuildTreeReduction(
+    const SDValue &VectorOp,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  EVT VectorTy = VectorOp.getValueType();
+  EVT EltTy = VectorTy.getVectorElementType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+
+  // scalarize vector
+  SmallVector<SDValue> Elements(NumElts);
+  for (unsigned I = 0, E = NumElts; I != E; ++I) {
+    Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
+                              DAG.getConstant(I, DL, MVT::i64));
+  }
+
+  // now build the computation graph in place at each level
+  SmallVector<SDValue> Level = Elements;
+  for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // partially reduce all elements in level
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      if (ReducedLevel.empty()) {
+        // The current operator requires more inputs than there are operands at
+        // this level. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // Otherwise, we just have a remainder, which we push to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
+/// serializes it.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
+  if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
+    return SDValue();
+
+  EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+  SDLoc DL(Op);
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+    Operators = {{ISD::FADD, 2}};
+    break;
+  case ISD::VECREDUCE_FMUL:
+    Operators = {{ISD::FMUL, 2}};
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMAXNUM3, 3});
+    Operators.push_back({ISD::FMAXNUM, 2});
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMINNUM3, 3});
+    Operators.push_back({ISD::FMINNUM, 2});
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
+    Operators.push_back({ISD::FMAXIMUM, 2});
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMINIMUM3, 3});
+    Operators.push_back({ISD::FMINIMUM, 2});
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
+                            DAG);
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -3026,6 +3147,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
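
As a usage-level note on the lowering above: LowerVECREDUCE bails out unless the reduction node carries the reassoc (allow-reassociation) fast-math flag, and the three-input min/max nodes are only considered for f32 on sm_100+ with PTX 8.8. Below is a sketch, under the assumption that the standard LLVM IRBuilder reduction helpers are used, of how a frontend could produce the kind of reduction this path targets; the function name emitReassocFMaxReduce and the symbol reduce_fmax_v8f32 are invented for illustration and are not part of this commit.

#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"

using namespace llvm;

// Builds: float reduce_fmax_v8f32(<8 x float> %v), whose body is a single
// call to llvm.vector.reduce.fmax.v8f32 carrying the 'reassoc' flag.
static Function *emitReassocFMaxReduce(Module &M) {
  LLVMContext &Ctx = M.getContext();
  auto *VecTy = FixedVectorType::get(Type::getFloatTy(Ctx), 8);
  Type *Params[] = {VecTy};
  auto *FnTy = FunctionType::get(Type::getFloatTy(Ctx), Params,
                                 /*isVarArg=*/false);
  Function *F = Function::Create(FnTy, Function::ExternalLinkage,
                                 "reduce_fmax_v8f32", &M);
  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));

  FastMathFlags FMF;
  FMF.setAllowReassoc(); // without this flag the tree lowering is skipped
  B.setFastMathFlags(FMF);

  Value *Red = B.CreateFPMaxReduce(F->getArg(0));
  B.CreateRet(Red);
  return F;
}

With nvptx-disable-fop-tree-reduce left at its default (false), such a reduction intrinsic is what reaches the custom lowering rather than being pre-expanded, thanks to the shouldExpandReduction hook added in NVPTXTargetTransformInfo.h further down.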

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -299,6 +304,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 54 additions & 0 deletions
@@ -403,6 +403,46 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
     Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
 
+// 3-input min/max (sm_100+) for f32 only
+multiclass FMINIMUMMAXIMUM3<string OpcStr, SDNode OpNode> {
+  def f32rrr_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+    Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rri_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
+    Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rii_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+    Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rrr :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+    Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rri :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b, fpimm:$c))]>,
+    Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rii :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+    Requires<[hasPTX<88>, hasSM<100>]>;
+}
+
 // Template for instructions which take three FP args. The
 // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
 //
@@ -1181,6 +1221,20 @@ defm FMAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
 defm FMINNAN : FMINIMUMMAXIMUM<"min.NaN", /* NaN */ true, fminimum>;
 defm FMAXNAN : FMINIMUMMAXIMUM<"max.NaN", /* NaN */ true, fmaximum>;
 
+def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+
+defm FMIN3 : FMINIMUMMAXIMUM3<"min", nvptx_fminnum3>;
+defm FMAX3 : FMINIMUMMAXIMUM3<"max", nvptx_fmaxnum3>;
+defm FMINNAN3 : FMINIMUMMAXIMUM3<"min.NaN", nvptx_fminimum3>;
+defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max.NaN", nvptx_fmaximum3>;
+
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
 defm FABS_H: F2_Support_Half<"abs", fabs>;

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   }
   unsigned getMinVectorRegisterBitWidth() const override { return 32; }
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const { return false; }
+
   // We don't want to prevent inlining because of target-cpu and -features
   // attributes that were added to newer versions of LLVM/Clang: There are
   // no incompatible functions in PTX, ptxas will throw errors in such cases.
