Skip to content

Commit 7a4de91

Browse files
committed
[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for sm_100+ fmax3/fmin3 instructions, introduced in PTX 8.8. This method of tree reduction has a few benefits over the default in DAGTypeLegalizer: - The default shuffle reduction progressively halves and partially reduces the vector down until we reach a single element. This produces a sequence of operations that combine disparate elements of the vector. For example, `vecreduce_fadd <4 x f32><a b c d>` will give `(a + c) + (b + d)`, whereas the tree reduction produces (a + b) + (c + d) by grouping nearby elements together first. Both use the same number of registers, but the shuffle reduction has longer live ranges. The same example is graphed below. Note we hold onto 3 registers for 2 cycles in the shuffle reduction and 1 cycle in tree reduction. (shuffle reduction) PTX: %r1 = add.f32 a, c %r2 = add.f32 b, d %r3 = add.f32 %r1, %r2 Pipeline: cycles ----> | 1 | 2 | 3 | 4 | 5 | 6 | | a = load.f32 | b = load.f32 | c = load.f32 | d = load.f32 | | | | | | | %r1 = add.f32 a, c | %r2 = add.f32 b, d | %r3 = add.f32 %r1, %r2 | live regs ----> | a [1R] | a b [2R] | a b c [3R] | b d %r1 [3R] | %r1 %r2 [2R] | %r3 [1R] | (tree reduction) PTX: %r1 = add.f32 a, b %r2 = add.f32 c, d %r3 = add.f32 %r1, %r2 Pipeline: cycles ----> | 1 | 2 | 3 | 4 | 5 | 6 | | a = load.f32 | b = load.f32 | c = load.f32 | d = load.f32 | | | | | | %r1 = add.f32 a, b | | %r2 = add.f32 c, d | %r3 = add.f32 %r1, %r2 | live regs ----> | a [1R] | a b [2R] | c %r1 [2R] | c %r1 d [3R] | %r1 %r2 [2R] | %r3 [1R] | - The shuffle reduction cannot easily support fmax3/fmin3 because it progressively halves the input vector. - Faster compile time. Happens in one pass over the intrinsic, rather than O(N) passes if iteratively splitting the vector operands.
1 parent 35b8003 commit 7a4de91

File tree

5 files changed

+1103
-688
lines changed

5 files changed

+1103
-688
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -850,6 +850,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
850850
if (STI.allowFP16Math() || STI.hasBF16Math())
851851
setTargetDAGCombine(ISD::SETCC);
852852

853+
// Vector reduction operations. These may be turned into sequential, shuffle,
854+
// or tree reductions depending on what instructions are available for each
855+
// type.
856+
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
857+
MVT EltVT = VT.getVectorElementType();
858+
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
859+
EltVT == MVT::f64) {
860+
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
861+
ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
862+
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
863+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
864+
VT, Custom);
865+
} else if (EltVT.isScalarInteger()) {
866+
setOperationAction(
867+
{ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
868+
ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
869+
ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
870+
VT, Custom);
871+
}
872+
}
873+
853874
// Promote fp16 arithmetic if fp16 hardware isn't available or the
854875
// user passed --nvptx-no-fp16-math. The flag is useful because,
855876
// although sm_53+ GPUs have some sort of FP16 support in
@@ -1083,6 +1104,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
10831104
MAKE_CASE(NVPTXISD::BFI)
10841105
MAKE_CASE(NVPTXISD::PRMT)
10851106
MAKE_CASE(NVPTXISD::FCOPYSIGN)
1107+
MAKE_CASE(NVPTXISD::FMAXNUM3)
1108+
MAKE_CASE(NVPTXISD::FMINNUM3)
1109+
MAKE_CASE(NVPTXISD::FMAXIMUM3)
1110+
MAKE_CASE(NVPTXISD::FMINIMUM3)
10861111
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
10871112
MAKE_CASE(NVPTXISD::STACKRESTORE)
10881113
MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2038,6 +2063,259 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
20382063
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
20392064
}
20402065

2066+
/// A generic routine for constructing a tree reduction on a vector operand.
2067+
/// This method groups elements bottom-up, progressively building each level.
2068+
/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
2069+
/// adjacent elements are combined first, leading to shorter live ranges. This
2070+
/// approach makes the most sense if the shuffle reduction would use the same
2071+
/// amount of registers.
2072+
///
2073+
/// The flags on the original reduction operation will be propagated to
2074+
/// each scalar operation.
2075+
static SDValue BuildTreeReduction(
2076+
const SmallVector<SDValue> &Elements, EVT EltTy,
2077+
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
2078+
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2079+
// Build the reduction tree at each level, starting with all the elements.
2080+
SmallVector<SDValue> Level = Elements;
2081+
2082+
unsigned OpIdx = 0;
2083+
while (Level.size() > 1) {
2084+
// Try to reduce this level using the current operator.
2085+
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2086+
2087+
// Build the next level by partially reducing all elements.
2088+
SmallVector<SDValue> ReducedLevel;
2089+
unsigned I = 0, E = Level.size();
2090+
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
2091+
// Reduce elements in groups of [DefaultGroupSize], as much as possible.
2092+
ReducedLevel.push_back(DAG.getNode(
2093+
DefaultScalarOp, DL, EltTy,
2094+
ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
2095+
}
2096+
2097+
if (I < E) {
2098+
// Handle leftover elements.
2099+
2100+
if (ReducedLevel.empty()) {
2101+
// We didn't reduce anything at this level. We need to pick a smaller
2102+
// operator.
2103+
++OpIdx;
2104+
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
2105+
continue;
2106+
}
2107+
2108+
// We reduced some things but there's still more left, meaning the
2109+
// operator's number of inputs doesn't evenly divide this level size. Move
2110+
// these elements to the next level.
2111+
for (; I < E; ++I)
2112+
ReducedLevel.push_back(Level[I]);
2113+
}
2114+
2115+
// Process the next level.
2116+
Level = ReducedLevel;
2117+
}
2118+
2119+
return *Level.begin();
2120+
}
2121+
2122+
/// Lower reductions to either a sequence of operations or a tree if
2123+
/// reassociations are allowed. This method will use larger operations like
2124+
/// max3/min3 when the target supports them.
2125+
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2126+
SelectionDAG &DAG) const {
2127+
SDLoc DL(Op);
2128+
const SDNodeFlags Flags = Op->getFlags();
2129+
SDValue Vector;
2130+
SDValue Accumulator;
2131+
2132+
if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
2133+
Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
2134+
// special case with accumulator as first arg
2135+
Accumulator = Op.getOperand(0);
2136+
Vector = Op.getOperand(1);
2137+
} else {
2138+
// default case
2139+
Vector = Op.getOperand(0);
2140+
}
2141+
2142+
EVT EltTy = Vector.getValueType().getVectorElementType();
2143+
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
2144+
STI.getPTXVersion() >= 88;
2145+
2146+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2147+
// number of inputs they take.
2148+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2149+
2150+
// Whether we can lower to scalar operations in an arbitrary order.
2151+
bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
2152+
2153+
// Whether the data type and operation can be represented with fewer ops and
2154+
// registers in a shuffle reduction.
2155+
bool PrefersShuffle;
2156+
2157+
switch (Op->getOpcode()) {
2158+
case ISD::VECREDUCE_FADD:
2159+
case ISD::VECREDUCE_SEQ_FADD:
2160+
ScalarOps = {{ISD::FADD, 2}};
2161+
IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
2162+
// Prefer add.{,b}f16x2 for v2{,b}f16
2163+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2164+
break;
2165+
case ISD::VECREDUCE_FMUL:
2166+
case ISD::VECREDUCE_SEQ_FMUL:
2167+
ScalarOps = {{ISD::FMUL, 2}};
2168+
IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
2169+
// Prefer mul.{,b}f16x2 for v2{,b}f16
2170+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2171+
break;
2172+
case ISD::VECREDUCE_FMAX:
2173+
if (CanUseMinMax3)
2174+
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
2175+
ScalarOps.push_back({ISD::FMAXNUM, 2});
2176+
// Definition of maxNum in IEEE 754 2008 is non-associative due to handling
2177+
// of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
2178+
IsAssociative |= Flags.hasAllowReassociation();
2179+
PrefersShuffle = false;
2180+
break;
2181+
case ISD::VECREDUCE_FMIN:
2182+
if (CanUseMinMax3)
2183+
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
2184+
ScalarOps.push_back({ISD::FMINNUM, 2});
2185+
// Definition of minNum in IEEE 754 2008 is non-associative due to handling
2186+
// of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
2187+
IsAssociative |= Flags.hasAllowReassociation();
2188+
PrefersShuffle = false;
2189+
break;
2190+
case ISD::VECREDUCE_FMAXIMUM:
2191+
if (CanUseMinMax3) {
2192+
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
2193+
// Can't use fmax3 in shuffle reduction
2194+
PrefersShuffle = false;
2195+
} else {
2196+
// Prefer max.{,b}f16x2 for v2{,b}f16
2197+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2198+
}
2199+
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2200+
IsAssociative = true;
2201+
break;
2202+
case ISD::VECREDUCE_FMINIMUM:
2203+
if (CanUseMinMax3) {
2204+
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
2205+
// Can't use fmin3 in shuffle reduction
2206+
PrefersShuffle = false;
2207+
} else {
2208+
// Prefer min.{,b}f16x2 for v2{,b}f16
2209+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2210+
}
2211+
ScalarOps.push_back({ISD::FMINIMUM, 2});
2212+
IsAssociative = true;
2213+
break;
2214+
case ISD::VECREDUCE_ADD:
2215+
ScalarOps = {{ISD::ADD, 2}};
2216+
IsAssociative = true;
2217+
// Prefer add.{s,u}16x2 for v2i16
2218+
PrefersShuffle = EltTy == MVT::i16;
2219+
break;
2220+
case ISD::VECREDUCE_MUL:
2221+
ScalarOps = {{ISD::MUL, 2}};
2222+
IsAssociative = true;
2223+
// Integer multiply doesn't support packed types
2224+
PrefersShuffle = false;
2225+
break;
2226+
case ISD::VECREDUCE_UMAX:
2227+
ScalarOps = {{ISD::UMAX, 2}};
2228+
IsAssociative = true;
2229+
// Prefer max.u16x2 for v2i16
2230+
PrefersShuffle = EltTy == MVT::i16;
2231+
break;
2232+
case ISD::VECREDUCE_UMIN:
2233+
ScalarOps = {{ISD::UMIN, 2}};
2234+
IsAssociative = true;
2235+
// Prefer min.u16x2 for v2i16
2236+
PrefersShuffle = EltTy == MVT::i16;
2237+
break;
2238+
case ISD::VECREDUCE_SMAX:
2239+
ScalarOps = {{ISD::SMAX, 2}};
2240+
IsAssociative = true;
2241+
// Prefer max.s16x2 for v2i16
2242+
PrefersShuffle = EltTy == MVT::i16;
2243+
break;
2244+
case ISD::VECREDUCE_SMIN:
2245+
ScalarOps = {{ISD::SMIN, 2}};
2246+
IsAssociative = true;
2247+
// Prefer min.s16x2 for v2i16
2248+
PrefersShuffle = EltTy == MVT::i16;
2249+
break;
2250+
case ISD::VECREDUCE_AND:
2251+
ScalarOps = {{ISD::AND, 2}};
2252+
IsAssociative = true;
2253+
// Prefer and.b32 for v2i16.
2254+
PrefersShuffle = EltTy == MVT::i16;
2255+
break;
2256+
case ISD::VECREDUCE_OR:
2257+
ScalarOps = {{ISD::OR, 2}};
2258+
IsAssociative = true;
2259+
// Prefer or.b32 for v2i16.
2260+
PrefersShuffle = EltTy == MVT::i16;
2261+
break;
2262+
case ISD::VECREDUCE_XOR:
2263+
ScalarOps = {{ISD::XOR, 2}};
2264+
IsAssociative = true;
2265+
// Prefer xor.b32 for v2i16.
2266+
PrefersShuffle = EltTy == MVT::i16;
2267+
break;
2268+
default:
2269+
llvm_unreachable("unhandled vecreduce operation");
2270+
}
2271+
2272+
// We don't expect an accumulator for reassociative vector reduction ops.
2273+
assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
2274+
2275+
// If shuffle reduction is preferred, leave it to SelectionDAG.
2276+
if (IsAssociative && PrefersShuffle)
2277+
return SDValue();
2278+
2279+
// Otherwise, handle the reduction here.
2280+
SmallVector<SDValue> Elements;
2281+
DAG.ExtractVectorElements(Vector, Elements);
2282+
2283+
// Lower to tree reduction.
2284+
if (IsAssociative)
2285+
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2286+
2287+
// Lower to sequential reduction.
2288+
EVT VectorTy = Vector.getValueType();
2289+
const unsigned NumElts = VectorTy.getVectorNumElements();
2290+
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2291+
// Try to reduce the remaining sequence as much as possible using the
2292+
// current operator.
2293+
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
2294+
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2295+
2296+
if (!Accumulator) {
2297+
// Try to initialize the accumulator using the current operator.
2298+
if (I + DefaultGroupSize <= NumElts) {
2299+
Accumulator = DAG.getNode(
2300+
DefaultScalarOp, DL, EltTy,
2301+
ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
2302+
I += DefaultGroupSize;
2303+
}
2304+
}
2305+
2306+
if (Accumulator) {
2307+
for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
2308+
SmallVector<SDValue> Operands = {Accumulator};
2309+
for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
2310+
Operands.push_back(Elements[I + K]);
2311+
Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
2312+
}
2313+
}
2314+
}
2315+
2316+
return Accumulator;
2317+
}
2318+
20412319
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
20422320
// Handle bitcasting from v2i8 without hitting the default promotion
20432321
// strategy which goes through stack memory.
@@ -2869,6 +3147,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
28693147
return LowerVECTOR_SHUFFLE(Op, DAG);
28703148
case ISD::CONCAT_VECTORS:
28713149
return LowerCONCAT_VECTORS(Op, DAG);
3150+
case ISD::VECREDUCE_FADD:
3151+
case ISD::VECREDUCE_FMUL:
3152+
case ISD::VECREDUCE_SEQ_FADD:
3153+
case ISD::VECREDUCE_SEQ_FMUL:
3154+
case ISD::VECREDUCE_FMAX:
3155+
case ISD::VECREDUCE_FMIN:
3156+
case ISD::VECREDUCE_FMAXIMUM:
3157+
case ISD::VECREDUCE_FMINIMUM:
3158+
case ISD::VECREDUCE_ADD:
3159+
case ISD::VECREDUCE_MUL:
3160+
case ISD::VECREDUCE_UMAX:
3161+
case ISD::VECREDUCE_UMIN:
3162+
case ISD::VECREDUCE_SMAX:
3163+
case ISD::VECREDUCE_SMIN:
3164+
case ISD::VECREDUCE_AND:
3165+
case ISD::VECREDUCE_OR:
3166+
case ISD::VECREDUCE_XOR:
3167+
return LowerVECREDUCE(Op, DAG);
28723168
case ISD::STORE:
28733169
return LowerSTORE(Op, DAG);
28743170
case ISD::LOAD:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,11 @@ enum NodeType : unsigned {
6060
UNPACK_VECTOR,
6161

6262
FCOPYSIGN,
63+
FMAXNUM3,
64+
FMINNUM3,
65+
FMAXIMUM3,
66+
FMINIMUM3,
67+
6368
DYNAMIC_STACKALLOC,
6469
STACKRESTORE,
6570
STACKSAVE,
@@ -279,6 +284,7 @@ class NVPTXTargetLowering : public TargetLowering {
279284

280285
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
281286
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
287+
SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
282288
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
283289
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
284290
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)