
Commit 92f6f6d

[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for sm_100+ fmax3/fmin3 instructions, introduced in PTX 8.8.

This method of tree reduction has a few benefits over the default in DAGTypeLegalizer:

- Produces the optimal number of operations supported by the target. Instead of progressively splitting the vector operand top-down, first scalarize it and then build the tree bottom-up. This uses larger operations when available and leaves smaller ones for the remaining elements.
- Faster compile time. Happens in one pass over the intrinsic, rather than O(N) passes if iteratively splitting the vector operands.
1 parent 34a4c58 · commit 92f6f6d
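To make the bottom-up grouping described in the commit message concrete, here is a small standalone C++ sketch. It is illustrative only: it reduces plain floats rather than SelectionDAG nodes, and its max3/max2 operator table merely mirrors the shape of what the lowering uses for f32 max reductions on sm_100+.

#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <utility>
#include <vector>

// Model of the bottom-up tree reduction: greedily group values with the widest
// available operator, carry any remainder up to the next level, and fall back
// to a narrower operator once a level is too small for the wide one.
static float reduceMaxTree(std::vector<float> Level) {
  using Reducer = float (*)(const float *);
  // Operator table, widest first: a 3-input max, then the ordinary 2-input max.
  const std::pair<unsigned, Reducer> Ops[] = {
      {3, [](const float *G) { return std::max({G[0], G[1], G[2]}); }},
      {2, [](const float *G) { return std::max(G[0], G[1]); }},
  };

  unsigned OpIdx = 0;
  while (Level.size() > 1) {
    const auto [GroupSize, Reduce] = Ops[OpIdx];
    std::vector<float> Next;
    std::size_t I = 0;
    for (; I + GroupSize <= Level.size(); I += GroupSize)
      Next.push_back(Reduce(&Level[I]));   // one wide op per full group
    if (Next.empty()) {
      ++OpIdx;                             // level too small: try a narrower op
      continue;
    }
    for (; I < Level.size(); ++I)
      Next.push_back(Level[I]);            // remainder moves up a level
    Level = std::move(Next);
  }
  return Level.front();
}

int main() {
  // 8 elements: level 0 -> max3, max3 + 2 leftovers; level 1 -> max3 + 1
  // leftover; level 2 -> max2. Four operations total, versus seven for a
  // plain pairwise reduction.
  std::vector<float> V = {1, 7, 3, 9, 2, 8, 4, 6};
  std::printf("%g\n", reduceMaxTree(V));   // prints 9
}

With a 3-input max available, the 8-element case takes four operations instead of the seven a pairwise tree needs, which is the "optimal number of operations" point above.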

File tree

6 files changed: +515 −1756 lines


llvm/lib/Target/NVPTX/NVPTX.td

Lines changed: 8 additions & 2 deletions
@@ -36,17 +36,19 @@ class FeaturePTX<int version>:
 
 foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53,
               60, 61, 62, 70, 72, 75, 80, 86, 87,
-              89, 90, 100, 101, 120] in
+              89, 90, 100, 101, 103, 120, 121] in
   def SM#sm: FeatureSM<""#sm, !mul(sm, 10)>;
 
 def SM90a: FeatureSM<"90a", 901>;
 def SM100a: FeatureSM<"100a", 1001>;
 def SM101a: FeatureSM<"101a", 1011>;
+def SM103a: FeatureSM<"103a", 1031>;
 def SM120a: FeatureSM<"120a", 1201>;
+def SM121a: FeatureSM<"121a", 1211>;
 
 foreach version = [32, 40, 41, 42, 43, 50, 60, 61, 62, 63, 64, 65,
                    70, 71, 72, 73, 74, 75, 76, 77, 78,
-                   80, 81, 82, 83, 84, 85, 86, 87] in
+                   80, 81, 82, 83, 84, 85, 86, 87, 88] in
   def PTX#version: FeaturePTX<version>;
 
 //===----------------------------------------------------------------------===//

@@ -81,8 +83,12 @@ def : Proc<"sm_100", [SM100, PTX86]>;
 def : Proc<"sm_100a", [SM100a, PTX86]>;
 def : Proc<"sm_101", [SM101, PTX86]>;
 def : Proc<"sm_101a", [SM101a, PTX86]>;
+def : Proc<"sm_103", [SM103, PTX88]>;
+def : Proc<"sm_103a", [SM103a, PTX88]>;
 def : Proc<"sm_120", [SM120, PTX87]>;
 def : Proc<"sm_120a", [SM120a, PTX87]>;
+def : Proc<"sm_121", [SM121, PTX88]>;
+def : Proc<"sm_121a", [SM121a, PTX88]>;
 
 def NVPTXInstrInfo : InstrInfo {
 }

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 128 additions & 0 deletions
@@ -85,6 +85,12 @@ static cl::opt<unsigned> FMAContractLevelOpt(
              " 1: do it 2: do it aggressively"),
     cl::init(2));
 
+static cl::opt<bool> DisableFOpTreeReduce(
+    "nvptx-disable-fop-tree-reduce", cl::Hidden,
+    cl::desc("NVPTX Specific: don't emit tree reduction for floating-point "
+             "reduction operations"),
+    cl::init(false));
+
 static cl::opt<int> UsePrecDivF32(
     "nvptx-prec-divf32", cl::Hidden,
     cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -834,6 +840,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These are transformed into a tree evaluation
+  // of nodes which may or may not be legal.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                        ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                        ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                       VT, Custom);
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1087,6 +1102,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2147,6 +2166,108 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction for a vector operand.
+/// This method differs from iterative splitting in DAGTypeLegalizer by
+/// first scalarizing the vector and then progressively grouping elements
+/// bottom-up. This allows easily building the optimal (minimum) number of nodes
+/// with different numbers of operands (eg. max3 vs max2).
+static SDValue BuildTreeReduction(
+    const SDValue &VectorOp,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  EVT VectorTy = VectorOp.getValueType();
+  EVT EltTy = VectorTy.getVectorElementType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+
+  // scalarize vector
+  SmallVector<SDValue> Elements(NumElts);
+  for (unsigned I = 0, E = NumElts; I != E; ++I) {
+    Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
+                              DAG.getConstant(I, DL, MVT::i64));
+  }
+
+  // now build the computation graph in place at each level
+  SmallVector<SDValue> Level = Elements;
+  for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // partially reduce all elements in level
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      if (ReducedLevel.empty()) {
+        // The current operator requires more inputs than there are operands at
+        // this level. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // Otherwise, we just have a remainder, which we push to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
+/// serializes it.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
+  if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
+    return SDValue();
+
+  EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+  SDLoc DL(Op);
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+    Operators = {{ISD::FADD, 2}};
+    break;
+  case ISD::VECREDUCE_FMUL:
+    Operators = {{ISD::FMUL, 2}};
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMAXNUM3, 3});
+    Operators.push_back({ISD::FMAXNUM, 2});
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMINNUM3, 3});
+    Operators.push_back({ISD::FMINNUM, 2});
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
+    Operators.push_back({ISD::FMAXIMUM, 2});
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMINIMUM3, 3});
+    Operators.push_back({ISD::FMINIMUM, 2});
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
+                            DAG);
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2935,6 +3056,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
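Note that the tree path above only fires when the reduction intrinsic carries the reassoc fast-math flag; otherwise LowerVECREDUCE returns SDValue() and the default expansion takes over. As a hedged sketch of how a frontend might produce such an intrinsic with IRBuilder (the module and function names here are made up for illustration):

#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/LLVMContext.h"
#include "llvm/IR/Module.h"
#include "llvm/Support/raw_ostream.h"

using namespace llvm;

int main() {
  LLVMContext Ctx;
  Module M("reduce_demo", Ctx);   // module/function names are illustrative

  // float reduce_max(<8 x float> %v)
  Type *VecTy = FixedVectorType::get(Type::getFloatTy(Ctx), 8);
  FunctionType *FTy =
      FunctionType::get(Type::getFloatTy(Ctx), {VecTy}, /*isVarArg=*/false);
  Function *F =
      Function::Create(FTy, Function::ExternalLinkage, "reduce_max", M);

  IRBuilder<> B(BasicBlock::Create(Ctx, "entry", F));

  // The tree lowering requires permission to reassociate, so attach the
  // 'reassoc' fast-math flag to everything the builder creates.
  FastMathFlags FMF;
  FMF.setAllowReassoc();
  B.setFastMathFlags(FMF);

  // Should emit something like:
  //   %r = call reassoc float @llvm.vector.reduce.fmax.v8f32(<8 x float> %v)
  Value *R = B.CreateFPMaxReduce(F->getArg(0));
  B.CreateRet(R);

  M.print(outs(), nullptr);
  return 0;
}

The reassoc flag is what licenses reordering the elements into a tree; without it the intrinsic still lowers, just through the generic split-based expansion in DAGTypeLegalizer.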

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,

@@ -296,6 +301,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 54 additions & 0 deletions
@@ -372,6 +372,46 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
              Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
 
+// 3-input min/max (sm_100+) for f32 only
+multiclass FMINIMUMMAXIMUM3<string OpcStr, SDNode OpNode> {
+  def f32rrr_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rri_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rii_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rrr :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rri :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b, fpimm:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rii :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+}
+
 // Template for instructions which take three FP args. The
 // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
 //
@@ -1140,6 +1180,20 @@ defm FMAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
 defm FMINNAN : FMINIMUMMAXIMUM<"min.NaN", /* NaN */ true, fminimum>;
 defm FMAXNAN : FMINIMUMMAXIMUM<"max.NaN", /* NaN */ true, fmaximum>;
 
+def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+
+defm FMIN3 : FMINIMUMMAXIMUM3<"min", nvptx_fminnum3>;
+defm FMAX3 : FMINIMUMMAXIMUM3<"max", nvptx_fmaxnum3>;
+defm FMINNAN3 : FMINIMUMMAXIMUM3<"min.NaN", nvptx_fminimum3>;
+defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max.NaN", nvptx_fmaximum3>;
+
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
 defm FABS_H: F2_Support_Half<"abs", fabs>;

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   }
   unsigned getMinVectorRegisterBitWidth() const { return 32; }
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const { return false; }
+
   // We don't want to prevent inlining because of target-cpu and -features
   // attributes that were added to newer versions of LLVM/Clang: There are
   // no incompatible functions in PTX, ptxas will throw errors in such cases.
