Commit 1c98e5d
[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for sm_100+ fmax3/fmin3 instructions, introduced in PTX 8.8.

This method of tree reduction has a few benefits over the default in DAGTypeLegalizer:

- Produces the optimal number of operations supported by the target. Instead of progressively splitting the vector operand top-down, first scalarize it and then build the tree bottom-up. This uses larger operations when available and leaves smaller ones for the remaining elements.
- Faster compile time. Happens in one pass over the intrinsic, rather than O(N) passes if iteratively splitting the vector operands.
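For illustration, consider a reassociable 8-element fmax reduction (the IR below is hand-written for this description, not taken from the commit's tests). On sm_100+ with PTX 8.8 the lowering picks the operator list {fmax3, fmax}, so the eight lanes are covered by four operations instead of the seven binary ones a pairwise split would need:

; Hypothetical input; e0..e7 below denote the extracted lanes of %v.
define float @reduce_fmax_v8(<8 x float> %v) {
  %r = call reassoc float @llvm.vector.reduce.fmax.v8f32(<8 x float> %v)
  ret float %r
}

; Bottom-up grouping with {max3 (3 inputs), max (2 inputs)}:
;   t0 = max3(e0, e1, e2)
;   t1 = max3(e3, e4, e5)
;   t2 = max3(t0, t1, e6)   ; the remainder lanes carry over to the next level
;   r  = max(t2, e7)

Note the reassoc flag: without allow-reassociation the new lowering bails out and the default legalization is used.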
1 parent: 65d16a8 · commit: 1c98e5d

6 files changed: +892 −2 lines changed

llvm/lib/Target/NVPTX/NVPTX.td

Lines changed: 8 additions & 2 deletions
@@ -36,17 +36,19 @@ class FeaturePTX<int version>:
 
 foreach sm = [20, 21, 30, 32, 35, 37, 50, 52, 53,
               60, 61, 62, 70, 72, 75, 80, 86, 87,
-              89, 90, 100, 101, 120] in
+              89, 90, 100, 101, 103, 120, 121] in
   def SM#sm: FeatureSM<""#sm, !mul(sm, 10)>;
 
 def SM90a: FeatureSM<"90a", 901>;
 def SM100a: FeatureSM<"100a", 1001>;
 def SM101a: FeatureSM<"101a", 1011>;
+def SM103a: FeatureSM<"103a", 1031>;
 def SM120a: FeatureSM<"120a", 1201>;
+def SM121a: FeatureSM<"121a", 1211>;
 
 foreach version = [32, 40, 41, 42, 43, 50, 60, 61, 62, 63, 64, 65,
                    70, 71, 72, 73, 74, 75, 76, 77, 78,
-                   80, 81, 82, 83, 84, 85, 86, 87] in
+                   80, 81, 82, 83, 84, 85, 86, 87, 88] in
   def PTX#version: FeaturePTX<version>;
 
 //===----------------------------------------------------------------------===//
@@ -81,8 +83,12 @@ def : Proc<"sm_100", [SM100, PTX86]>;
 def : Proc<"sm_100a", [SM100a, PTX86]>;
 def : Proc<"sm_101", [SM101, PTX86]>;
 def : Proc<"sm_101a", [SM101a, PTX86]>;
+def : Proc<"sm_103", [SM103, PTX88]>;
+def : Proc<"sm_103a", [SM103a, PTX88]>;
 def : Proc<"sm_120", [SM120, PTX87]>;
 def : Proc<"sm_120a", [SM120a, PTX87]>;
+def : Proc<"sm_121", [SM121, PTX88]>;
+def : Proc<"sm_121a", [SM121a, PTX88]>;
 
 def NVPTXInstrInfo : InstrInfo {
 }

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 128 additions & 0 deletions
@@ -85,6 +85,12 @@ static cl::opt<unsigned> FMAContractLevelOpt(
              " 1: do it 2: do it aggressively"),
     cl::init(2));
 
+static cl::opt<bool> DisableFOpTreeReduce(
+    "nvptx-disable-fop-tree-reduce", cl::Hidden,
+    cl::desc("NVPTX Specific: don't emit tree reduction for floating-point "
+             "reduction operations"),
+    cl::init(false));
+
 static cl::opt<int> UsePrecDivF32(
     "nvptx-prec-divf32", cl::Hidden,
     cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -828,6 +834,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These are transformed into a tree evaluation
+  // of nodes which may or may not be legal.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                        ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                        ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                       VT, Custom);
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1079,6 +1094,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
    MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
    MAKE_CASE(NVPTXISD::STACKRESTORE)
    MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2128,6 +2147,108 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction for a vector operand.
+/// This method differs from iterative splitting in DAGTypeLegalizer by
+/// first scalarizing the vector and then progressively grouping elements
+/// bottom-up. This allows easily building the optimal (minimum) number of nodes
+/// with different numbers of operands (eg. max3 vs max2).
+static SDValue BuildTreeReduction(
+    const SDValue &VectorOp,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  EVT VectorTy = VectorOp.getValueType();
+  EVT EltTy = VectorTy.getVectorElementType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+
+  // scalarize vector
+  SmallVector<SDValue> Elements(NumElts);
+  for (unsigned I = 0, E = NumElts; I != E; ++I) {
+    Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
+                              DAG.getConstant(I, DL, MVT::i64));
+  }
+
+  // now build the computation graph in place at each level
+  SmallVector<SDValue> Level = Elements;
+  for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // partially reduce all elements in level
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      if (ReducedLevel.empty()) {
+        // The current operator requires more inputs than there are operands at
+        // this level. Pick a smaller operator and retry.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // Otherwise, we just have a remainder, which we push to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
+/// serializes it.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
+  if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
+    return SDValue();
+
+  EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+  SDLoc DL(Op);
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+    Operators = {{ISD::FADD, 2}};
+    break;
+  case ISD::VECREDUCE_FMUL:
+    Operators = {{ISD::FMUL, 2}};
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMAXNUM3, 3});
+    Operators.push_back({ISD::FMAXNUM, 2});
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMINNUM3, 3});
+    Operators.push_back({ISD::FMINNUM, 2});
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
+    Operators.push_back({ISD::FMAXIMUM, 2});
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3)
+      Operators.push_back({NVPTXISD::FMINIMUM3, 3});
+    Operators.push_back({ISD::FMINIMUM, 2});
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
+                            DAG);
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2905,6 +3026,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
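As a usage sketch (file name and exact invocation are assumed, not part of this commit), an llc run on the earlier reduce_fmax_v8 example should take the new path on an sm_100+ target; sm_103 pulls in PTX88 by default, so the three-input max.f32 patterns are available, while -nvptx-disable-fop-tree-reduce falls back to the stock DAGTypeLegalizer expansion:

llc -mtriple=nvptx64-nvidia-cuda -mcpu=sm_103 -o - reduce.ll
llc -mtriple=nvptx64-nvidia-cuda -mcpu=sm_103 -nvptx-disable-fop-tree-reduce -o - reduce.ll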

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
   UNPACK_VECTOR,
 
   FCOPYSIGN,
+  FMAXNUM3,
+  FMINNUM3,
+  FMAXIMUM3,
+  FMINIMUM3,
+
   DYNAMIC_STACKALLOC,
   STACKRESTORE,
   STACKSAVE,
@@ -296,6 +301,7 @@ class NVPTXTargetLowering : public TargetLowering {
 
   SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
+  SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
   SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 54 additions & 0 deletions
@@ -371,6 +371,46 @@ multiclass FMINIMUMMAXIMUM<string OpcStr, bit NaN, SDNode OpNode> {
     Requires<[hasBF16Math, hasSM<80>, hasPTX<70>]>;
 }
 
+// 3-input min/max (sm_100+) for f32 only
+multiclass FMINIMUMMAXIMUM3<string OpcStr, SDNode OpNode> {
+  def f32rrr_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rri_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, fpimm:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rii_ftz :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+              !strconcat(OpcStr, ".ftz.f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+              Requires<[doF32FTZ, hasPTX<88>, hasSM<100>]>;
+  def f32rrr :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, Float32Regs:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, f32:$b, f32:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rri :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, Float32Regs:$b, f32imm:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, Float32Regs:$b, fpimm:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+  def f32rii :
+    NVPTXInst<(outs Float32Regs:$dst),
+              (ins Float32Regs:$a, f32imm:$b, f32imm:$c),
+              !strconcat(OpcStr, ".f32 \t$dst, $a, $b, $c;"),
+              [(set f32:$dst, (OpNode f32:$a, fpimm:$b, fpimm:$c))]>,
+              Requires<[hasPTX<88>, hasSM<100>]>;
+}
+
 // Template for instructions which take three FP args. The
 // instructions are named "<OpcStr>.f<Width>" (e.g. "add.f64").
 //
@@ -1139,6 +1179,20 @@ defm FMAX : FMINIMUMMAXIMUM<"max", /* NaN */ false, fmaxnum>;
 defm FMINNAN : FMINIMUMMAXIMUM<"min.NaN", /* NaN */ true, fminimum>;
 defm FMAXNAN : FMINIMUMMAXIMUM<"max.NaN", /* NaN */ true, fmaximum>;
 
+def nvptx_fminnum3 : SDNode<"NVPTXISD::FMINNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaxnum3 : SDNode<"NVPTXISD::FMAXNUM3", SDTFPTernaryOp,
+                            [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fminimum3 : SDNode<"NVPTXISD::FMINIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+def nvptx_fmaximum3 : SDNode<"NVPTXISD::FMAXIMUM3", SDTFPTernaryOp,
+                             [SDNPCommutative, SDNPAssociative]>;
+
+defm FMIN3 : FMINIMUMMAXIMUM3<"min", nvptx_fminnum3>;
+defm FMAX3 : FMINIMUMMAXIMUM3<"max", nvptx_fmaxnum3>;
+defm FMINNAN3 : FMINIMUMMAXIMUM3<"min.NaN", nvptx_fminimum3>;
+defm FMAXNAN3 : FMINIMUMMAXIMUM3<"max.NaN", nvptx_fmaximum3>;
+
 defm FABS : F2<"abs", fabs>;
 defm FNEG : F2<"neg", fneg>;
 defm FABS_H: F2_Support_Half<"abs", fabs>;

llvm/lib/Target/NVPTX/NVPTXTargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
@@ -83,6 +83,8 @@ class NVPTXTTIImpl : public BasicTTIImplBase<NVPTXTTIImpl> {
   }
   unsigned getMinVectorRegisterBitWidth() const { return 32; }
 
+  bool shouldExpandReduction(const IntrinsicInst *II) const { return false; }
+
   // We don't want to prevent inlining because of target-cpu and -features
   // attributes that were added to newer versions of LLVM/Clang: There are
   // no incompatible functions in PTX, ptxas will throw errors in such cases.
