Commit 0b53293

[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for the sm_100+ fmax3/fmin3 instructions introduced in PTX 8.8.

This method of tree reduction has a few benefits over the default in DAGTypeLegalizer:
- Produces the optimal number of operations supported by the target. Instead of progressively splitting the vector operand top-down, first scalarize it and then build the tree bottom-up. This uses larger operations when available and leaves smaller ones for the remaining elements.
- Faster compile time. Happens in one pass over the intrinsic, rather than O(N) passes if iteratively splitting the vector operands.
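As a rough illustration of the bottom-up grouping described above, here is a standalone C++ sketch (not part of the patch; the input values and the {max3, max2} operator list are made up, with plain floats standing in for SDNodes). For 8 elements it emits 4 operations (max3, max3, max3, max2) instead of the 7 pairwise operations a top-down split would produce.

#include <algorithm>
#include <cstdio>
#include <utility>
#include <vector>

// Stand-in for an N-input max instruction (max3 / max2).
static float applyMax(const std::vector<float> &Args) {
  float R = Args[0];
  for (size_t I = 1; I < Args.size(); ++I)
    R = std::max(R, Args[I]);
  return R;
}

int main() {
  // The "scalarized" vector operand.
  std::vector<float> Level = {1, 7, 3, 9, 2, 8, 4, 6};
  // Operators tried largest-first: 3-input max, then 2-input max.
  std::vector<std::pair<const char *, size_t>> Ops = {{"max3", 3}, {"max2", 2}};

  for (size_t OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
    const auto [Name, GroupSize] = Ops[OpIdx];
    std::vector<float> Reduced;
    size_t I = 0, E = Level.size();
    for (; I + GroupSize <= E; I += GroupSize) {
      // Reduce a full group with the largest usable operator.
      Reduced.push_back(applyMax(
          std::vector<float>(Level.begin() + I, Level.begin() + I + GroupSize)));
      std::printf("%s\n", Name);
    }
    if (I < E) {
      if (Reduced.empty()) {
        ++OpIdx; // too few operands left for this operator: fall back to a smaller one
        continue;
      }
      // Push the remainder up to the next level unchanged.
      Reduced.insert(Reduced.end(), Level.begin() + I, Level.end());
    }
    Level = std::move(Reduced);
  }
  std::printf("result = %g\n", Level.front()); // prints 9
  return 0;
}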
1 parent 9aff19e commit 0b53293

File tree: 6 files changed (+892, -2 lines)


llvm/lib/Target/NVPTX/NVPTX.td

Lines changed: 8 additions & 2 deletions
@@ -36,17 +36,19 @@ class FeaturePTX<int version>:
 
 foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53,
               60, 61, 62, 70, 72, 75, 80, 86, 87,
-              89, 90, 100, 101, 120] in
+              89, 90, 100, 101, 103, 120, 121] in
   def SM#sm: FeatureSM<""#sm, !mul(sm, 10)>;
 
 def SM90a: FeatureSM<"90a", 901>;
 def SM100a: FeatureSM<"100a", 1001>;
 def SM101a: FeatureSM<"101a", 1011>;
+def SM103a: FeatureSM<"103a", 1031>;
 def SM120a: FeatureSM<"120a", 1201>;
+def SM121a: FeatureSM<"121a", 1211>;
 
 foreach version = [32, 40, 41, 42, 43, 50, 60, 61, 62, 63, 64, 65,
                    70, 71, 72, 73, 74, 75, 76, 77, 78,
-                   80, 81, 82, 83, 84, 85, 86, 87] in
+                   80, 81, 82, 83, 84, 85, 86, 87, 88] in
   def PTX#version: FeaturePTX<version>;
 
 //===----------------------------------------------------------------------===//
@@ -81,8 +83,12 @@ def : Proc<"sm_100", [SM100, PTX86]>;
 def : Proc<"sm_100a", [SM100a, PTX86]>;
 def : Proc<"sm_101", [SM101, PTX86]>;
 def : Proc<"sm_101a", [SM101a, PTX86]>;
+def : Proc<"sm_103", [SM103, PTX88]>;
+def : Proc<"sm_103a", [SM103a, PTX88]>;
 def : Proc<"sm_120", [SM120, PTX87]>;
 def : Proc<"sm_120a", [SM120a, PTX87]>;
+def : Proc<"sm_121", [SM121, PTX88]>;
+def : Proc<"sm_121a", [SM121a, PTX88]>;
 
 def NVPTXInstrInfo : InstrInfo {
 }

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 128 additions & 0 deletions
@@ -85,6 +85,12 @@ static cl::opt<unsigned> FMAContractLevelOpt(
              " 1: do it 2: do it aggressively"),
     cl::init(2));
 
+static cl::opt<bool> DisableFOpTreeReduce(
+    "nvptx-disable-fop-tree-reduce", cl::Hidden,
+    cl::desc("NVPTX Specific: don't emit tree reduction for floating-point "
+             "reduction operations"),
+    cl::init(false));
+
 static cl::opt<int> UsePrecDivF32(
     "nvptx-prec-divf32", cl::Hidden,
     cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -831,6 +837,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These are transformed into a tree evaluation
+  // of nodes which may or may not be legal.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                        ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                        ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                       VT, Custom);
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1082,6 +1097,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2136,6 +2155,108 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction for a vector operand.
+/// This method differs from iterative splitting in DAGTypeLegalizer by
+/// first scalarizing the vector and then progressively grouping elements
+/// bottom-up. This allows easily building the optimal (minimum) number of nodes
+/// with different numbers of operands (eg. max3 vs max2).
+static SDValue BuildTreeReduction(
+    const SDValue &VectorOp,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  EVT VectorTy = VectorOp.getValueType();
+  EVT EltTy = VectorTy.getVectorElementType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+
+  // scalarize vector
+  SmallVector<SDValue> Elements(NumElts);
+  for (unsigned I = 0, E = NumElts; I != E; ++I) {
+    Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
+                              DAG.getConstant(I, DL, MVT::i64));
+  }
+
+  // now build the computation graph in place at each level
+  SmallVector<SDValue> Level = Elements;
+  for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // partially reduce all elements in level
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      if (ReducedLevel.empty()) {
+        // The current operator requires more inputs than there are operands at
+        // this level. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // Otherwise, we just have a remainder, which we push to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
+/// serializes it.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
+  if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
+    return SDValue();
+
+  EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+  SDLoc DL(Op);
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+    Operators = {{ISD::FADD, 2}};
+    break;
+  case ISD::VECREDUCE_FMUL:
+    Operators = {{ISD::FMUL, 2}};
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMAXNUM3, 3});
+    Operators.push_back({ISD::FMAXNUM, 2});
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMINNUM3, 3});
+    Operators.push_back({ISD::FMINNUM, 2});
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
+    Operators.push_back({ISD::FMAXIMUM, 2});
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMINIMUM3, 3});
+    Operators.push_back({ISD::FMINIMUM, 2});
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
+                            DAG);
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2879,6 +3000,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
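Note that LowerVECREDUCE above bails out (returns an empty SDValue) unless the node carries the reassociation fast-math flag, leaving the default DAGTypeLegalizer expansion in place. As a minimal sketch of what feeds this path (not part of the patch, and assuming the usual IRBuilder reduction helpers; the function name is hypothetical), a frontend would tag the reduce intrinsic as reassociable:

#include "llvm/IR/FMF.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"

using namespace llvm;

// Build an fmax reduction over a float vector (e.g. <8 x float>) and mark it
// reassoc so the NVPTX backend is allowed to reorder the partial maxes.
static Value *emitReassocFMaxReduce(IRBuilder<> &B, Value *Vec) {
  FastMathFlags FMF;
  FMF.setAllowReassoc();                  // the flag LowerVECREDUCE checks
  CallInst *R = B.CreateFPMaxReduce(Vec); // llvm.vector.reduce.fmax
  R->setFastMathFlags(FMF);               // without this, the default expansion runs
  return R;
}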

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -296,6 +301,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 54 additions & 0 deletions
@@ -368,6 +368,46 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
               Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
 
+// 3-input min/max (sm_100+) for f32 only
+multiclass FMINIMUMMAXIMUM3<string OpcStr, SDNode OpNode> {
+  def f32rrr_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rri_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rii_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rrr :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rri :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b, fpimm:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rii :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+}
+
 // Template for instructions which take three FP args. The
 // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
 //
@@ -1101,6 +1141,20 @@ defm FMAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
 defm FMINNAN : FMINIMUMMAXIMUM<"min.NaN", /* NaN */ true, fminimum>;
 defm FMAXNAN : FMINIMUMMAXIMUM<"max.NaN", /* NaN */ true, fmaximum>;
 
+def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+
+defm FMIN3 : FMINIMUMMAXIMUM3<"min", nvptx_fminnum3>;
+defm FMAX3 : FMINIMUMMAXIMUM3<"max", nvptx_fmaxnum3>;
+defm FMINNAN3 : FMINIMUMMAXIMUM3<"min.NaN", nvptx_fminimum3>;
+defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max.NaN", nvptx_fmaximum3>;
+
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
 defm FABS_H: F2_Support_Half<"abs", fabs>;

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   }
   unsigned getMinVectorRegisterBitWidth() const { return 32; }
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const { return false; }
+
   // We don't want to prevent inlining because of target-cpu and -features
   // attributes that were added to newer versions of LLVM/Clang: There are
   // no incompatible functions in PTX, ptxas will throw errors in such cases.