Skip to content

Commit 5fbe8bc

Browse files
committed
[NVPTX] lower VECREDUCE intrinsics to tree reduction
Also adds support for sm_100+ fmax3/fmin3 instructions, introduced in PTX 8.8. This method of tree reduction has a few benefits over the default in DAGTypeLegalizer: - The default shuffle reduction progressively halves and partially reduces the vector down until we reach a single element. This produces a sequence of operations that combine disparate elements of the vector. For example, `vecreduce_fadd <4 x f32><a b c d>` will give `(a + c) + (b + d)`, whereas the tree reduction produces (a + b) + (c + d) by grouping nearby elements together first. Both use the same number of registers, but the shuffle reduction has longer live ranges. The same example is graphed below. Note we hold onto 3 registers for 2 cycles in the shuffle reduction and 1 cycle in tree reduction. (shuffle reduction) PTX: %r1 = add.f32 a, c %r2 = add.f32 b, d %r3 = add.f32 %r1, %r2 Pipeline: cycles ----> | 1 | 2 | 3 | 4 | 5 | 6 | | a = load.f32 | b = load.f32 | c = load.f32 | d = load.f32 | | | | | | | %r1 = add.f32 a, c | %r2 = add.f32 b, d | %r3 = add.f32 %r1, %r2 | live regs ----> | a [1R] | a b [2R] | a b c [3R] | b d %r1 [3R] | %r1 %r2 [2R] | %r3 [1R] | (tree reduction) PTX: %r1 = add.f32 a, b %r2 = add.f32 c, d %r3 = add.f32 %r1, %r2 Pipeline: cycles ----> | 1 | 2 | 3 | 4 | 5 | 6 | | a = load.f32 | b = load.f32 | c = load.f32 | d = load.f32 | | | | | | %r1 = add.f32 a, b | | %r2 = add.f32 c, d | %r3 = add.f32 %r1, %r2 | live regs ----> | a [1R] | a b [2R] | c %r1 [2R] | c %r1 d [3R] | %r1 %r2 [2R] | %r3 [1R] | - The shuffle reduction cannot easily support fmax3/fmin3 because it progressively halves the input vector. - Faster compile time. Happens in one pass over the intrinsic, rather than O(N) passes if iteratively splitting the vector operands.
1 parent c04e804 commit 5fbe8bc

File tree

5 files changed

+1103
-688
lines changed

5 files changed

+1103
-688
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 296 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -853,6 +853,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
853853
if (STI.allowFP16Math() || STI.hasBF16Math())
854854
setTargetDAGCombine(ISD::SETCC);
855855

856+
// Vector reduction operations. These may be turned into sequential, shuffle,
857+
// or tree reductions depending on what instructions are available for each
858+
// type.
859+
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
860+
MVT EltVT = VT.getVectorElementType();
861+
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
862+
EltVT == MVT::f64) {
863+
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
864+
ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
865+
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
866+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
867+
VT, Custom);
868+
} else if (EltVT.isScalarInteger()) {
869+
setOperationAction(
870+
{ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
871+
ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
872+
ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
873+
VT, Custom);
874+
}
875+
}
876+
856877
// Promote fp16 arithmetic if fp16 hardware isn't available or the
857878
// user passed --nvptx-no-fp16-math. The flag is useful because,
858879
// although sm_53+ GPUs have some sort of FP16 support in
@@ -1110,6 +1131,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
11101131
MAKE_CASE(NVPTXISD::BFI)
11111132
MAKE_CASE(NVPTXISD::PRMT)
11121133
MAKE_CASE(NVPTXISD::FCOPYSIGN)
1134+
MAKE_CASE(NVPTXISD::FMAXNUM3)
1135+
MAKE_CASE(NVPTXISD::FMINNUM3)
1136+
MAKE_CASE(NVPTXISD::FMAXIMUM3)
1137+
MAKE_CASE(NVPTXISD::FMINIMUM3)
11131138
MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
11141139
MAKE_CASE(NVPTXISD::STACKRESTORE)
11151140
MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2109,6 +2134,259 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21092134
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
21102135
}
21112136

2137+
/// A generic routine for constructing a tree reduction on a vector operand.
2138+
/// This method groups elements bottom-up, progressively building each level.
2139+
/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
2140+
/// adjacent elements are combined first, leading to shorter live ranges. This
2141+
/// approach makes the most sense if the shuffle reduction would use the same
2142+
/// amount of registers.
2143+
///
2144+
/// The flags on the original reduction operation will be propagated to
2145+
/// each scalar operation.
2146+
static SDValue BuildTreeReduction(
2147+
const SmallVector<SDValue> &Elements, EVT EltTy,
2148+
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
2149+
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2150+
// Build the reduction tree at each level, starting with all the elements.
2151+
SmallVector<SDValue> Level = Elements;
2152+
2153+
unsigned OpIdx = 0;
2154+
while (Level.size() > 1) {
2155+
// Try to reduce this level using the current operator.
2156+
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2157+
2158+
// Build the next level by partially reducing all elements.
2159+
SmallVector<SDValue> ReducedLevel;
2160+
unsigned I = 0, E = Level.size();
2161+
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
2162+
// Reduce elements in groups of [DefaultGroupSize], as much as possible.
2163+
ReducedLevel.push_back(DAG.getNode(
2164+
DefaultScalarOp, DL, EltTy,
2165+
ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
2166+
}
2167+
2168+
if (I < E) {
2169+
// Handle leftover elements.
2170+
2171+
if (ReducedLevel.empty()) {
2172+
// We didn't reduce anything at this level. We need to pick a smaller
2173+
// operator.
2174+
++OpIdx;
2175+
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
2176+
continue;
2177+
}
2178+
2179+
// We reduced some things but there's still more left, meaning the
2180+
// operator's number of inputs doesn't evenly divide this level size. Move
2181+
// these elements to the next level.
2182+
for (; I < E; ++I)
2183+
ReducedLevel.push_back(Level[I]);
2184+
}
2185+
2186+
// Process the next level.
2187+
Level = ReducedLevel;
2188+
}
2189+
2190+
return *Level.begin();
2191+
}
2192+
2193+
/// Lower reductions to either a sequence of operations or a tree if
2194+
/// reassociations are allowed. This method will use larger operations like
2195+
/// max3/min3 when the target supports them.
2196+
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2197+
SelectionDAG &DAG) const {
2198+
SDLoc DL(Op);
2199+
const SDNodeFlags Flags = Op->getFlags();
2200+
SDValue Vector;
2201+
SDValue Accumulator;
2202+
2203+
if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
2204+
Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
2205+
// special case with accumulator as first arg
2206+
Accumulator = Op.getOperand(0);
2207+
Vector = Op.getOperand(1);
2208+
} else {
2209+
// default case
2210+
Vector = Op.getOperand(0);
2211+
}
2212+
2213+
EVT EltTy = Vector.getValueType().getVectorElementType();
2214+
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
2215+
STI.getPTXVersion() >= 88;
2216+
2217+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2218+
// number of inputs they take.
2219+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2220+
2221+
// Whether we can lower to scalar operations in an arbitrary order.
2222+
bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
2223+
2224+
// Whether the data type and operation can be represented with fewer ops and
2225+
// registers in a shuffle reduction.
2226+
bool PrefersShuffle;
2227+
2228+
switch (Op->getOpcode()) {
2229+
case ISD::VECREDUCE_FADD:
2230+
case ISD::VECREDUCE_SEQ_FADD:
2231+
ScalarOps = {{ISD::FADD, 2}};
2232+
IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
2233+
// Prefer add.{,b}f16x2 for v2{,b}f16
2234+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2235+
break;
2236+
case ISD::VECREDUCE_FMUL:
2237+
case ISD::VECREDUCE_SEQ_FMUL:
2238+
ScalarOps = {{ISD::FMUL, 2}};
2239+
IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
2240+
// Prefer mul.{,b}f16x2 for v2{,b}f16
2241+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2242+
break;
2243+
case ISD::VECREDUCE_FMAX:
2244+
if (CanUseMinMax3)
2245+
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
2246+
ScalarOps.push_back({ISD::FMAXNUM, 2});
2247+
// Definition of maxNum in IEEE 754 2008 is non-associative due to handling
2248+
// of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
2249+
IsAssociative |= Flags.hasAllowReassociation();
2250+
PrefersShuffle = false;
2251+
break;
2252+
case ISD::VECREDUCE_FMIN:
2253+
if (CanUseMinMax3)
2254+
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
2255+
ScalarOps.push_back({ISD::FMINNUM, 2});
2256+
// Definition of minNum in IEEE 754 2008 is non-associative due to handling
2257+
// of sNaN inputs. Allow overriding with fast-math or 'reassoc' attribute.
2258+
IsAssociative |= Flags.hasAllowReassociation();
2259+
PrefersShuffle = false;
2260+
break;
2261+
case ISD::VECREDUCE_FMAXIMUM:
2262+
if (CanUseMinMax3) {
2263+
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
2264+
// Can't use fmax3 in shuffle reduction
2265+
PrefersShuffle = false;
2266+
} else {
2267+
// Prefer max.{,b}f16x2 for v2{,b}f16
2268+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2269+
}
2270+
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2271+
IsAssociative = true;
2272+
break;
2273+
case ISD::VECREDUCE_FMINIMUM:
2274+
if (CanUseMinMax3) {
2275+
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
2276+
// Can't use fmin3 in shuffle reduction
2277+
PrefersShuffle = false;
2278+
} else {
2279+
// Prefer min.{,b}f16x2 for v2{,b}f16
2280+
PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
2281+
}
2282+
ScalarOps.push_back({ISD::FMINIMUM, 2});
2283+
IsAssociative = true;
2284+
break;
2285+
case ISD::VECREDUCE_ADD:
2286+
ScalarOps = {{ISD::ADD, 2}};
2287+
IsAssociative = true;
2288+
// Prefer add.{s,u}16x2 for v2i16
2289+
PrefersShuffle = EltTy == MVT::i16;
2290+
break;
2291+
case ISD::VECREDUCE_MUL:
2292+
ScalarOps = {{ISD::MUL, 2}};
2293+
IsAssociative = true;
2294+
// Integer multiply doesn't support packed types
2295+
PrefersShuffle = false;
2296+
break;
2297+
case ISD::VECREDUCE_UMAX:
2298+
ScalarOps = {{ISD::UMAX, 2}};
2299+
IsAssociative = true;
2300+
// Prefer max.u16x2 for v2i16
2301+
PrefersShuffle = EltTy == MVT::i16;
2302+
break;
2303+
case ISD::VECREDUCE_UMIN:
2304+
ScalarOps = {{ISD::UMIN, 2}};
2305+
IsAssociative = true;
2306+
// Prefer min.u16x2 for v2i16
2307+
PrefersShuffle = EltTy == MVT::i16;
2308+
break;
2309+
case ISD::VECREDUCE_SMAX:
2310+
ScalarOps = {{ISD::SMAX, 2}};
2311+
IsAssociative = true;
2312+
// Prefer max.s16x2 for v2i16
2313+
PrefersShuffle = EltTy == MVT::i16;
2314+
break;
2315+
case ISD::VECREDUCE_SMIN:
2316+
ScalarOps = {{ISD::SMIN, 2}};
2317+
IsAssociative = true;
2318+
// Prefer min.s16x2 for v2i16
2319+
PrefersShuffle = EltTy == MVT::i16;
2320+
break;
2321+
case ISD::VECREDUCE_AND:
2322+
ScalarOps = {{ISD::AND, 2}};
2323+
IsAssociative = true;
2324+
// Prefer and.b32 for v2i16.
2325+
PrefersShuffle = EltTy == MVT::i16;
2326+
break;
2327+
case ISD::VECREDUCE_OR:
2328+
ScalarOps = {{ISD::OR, 2}};
2329+
IsAssociative = true;
2330+
// Prefer or.b32 for v2i16.
2331+
PrefersShuffle = EltTy == MVT::i16;
2332+
break;
2333+
case ISD::VECREDUCE_XOR:
2334+
ScalarOps = {{ISD::XOR, 2}};
2335+
IsAssociative = true;
2336+
// Prefer xor.b32 for v2i16.
2337+
PrefersShuffle = EltTy == MVT::i16;
2338+
break;
2339+
default:
2340+
llvm_unreachable("unhandled vecreduce operation");
2341+
}
2342+
2343+
// We don't expect an accumulator for reassociative vector reduction ops.
2344+
assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
2345+
2346+
// If shuffle reduction is preferred, leave it to SelectionDAG.
2347+
if (IsAssociative && PrefersShuffle)
2348+
return SDValue();
2349+
2350+
// Otherwise, handle the reduction here.
2351+
SmallVector<SDValue> Elements;
2352+
DAG.ExtractVectorElements(Vector, Elements);
2353+
2354+
// Lower to tree reduction.
2355+
if (IsAssociative)
2356+
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2357+
2358+
// Lower to sequential reduction.
2359+
EVT VectorTy = Vector.getValueType();
2360+
const unsigned NumElts = VectorTy.getVectorNumElements();
2361+
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2362+
// Try to reduce the remaining sequence as much as possible using the
2363+
// current operator.
2364+
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
2365+
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2366+
2367+
if (!Accumulator) {
2368+
// Try to initialize the accumulator using the current operator.
2369+
if (I + DefaultGroupSize <= NumElts) {
2370+
Accumulator = DAG.getNode(
2371+
DefaultScalarOp, DL, EltTy,
2372+
ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
2373+
I += DefaultGroupSize;
2374+
}
2375+
}
2376+
2377+
if (Accumulator) {
2378+
for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
2379+
SmallVector<SDValue> Operands = {Accumulator};
2380+
for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
2381+
Operands.push_back(Elements[I + K]);
2382+
Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
2383+
}
2384+
}
2385+
}
2386+
2387+
return Accumulator;
2388+
}
2389+
21122390
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
21132391
// Handle bitcasting from v2i8 without hitting the default promotion
21142392
// strategy which goes through stack memory.
@@ -2941,6 +3219,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
29413219
return LowerVECTOR_SHUFFLE(Op, DAG);
29423220
case ISD::CONCAT_VECTORS:
29433221
return LowerCONCAT_VECTORS(Op, DAG);
3222+
case ISD::VECREDUCE_FADD:
3223+
case ISD::VECREDUCE_FMUL:
3224+
case ISD::VECREDUCE_SEQ_FADD:
3225+
case ISD::VECREDUCE_SEQ_FMUL:
3226+
case ISD::VECREDUCE_FMAX:
3227+
case ISD::VECREDUCE_FMIN:
3228+
case ISD::VECREDUCE_FMAXIMUM:
3229+
case ISD::VECREDUCE_FMINIMUM:
3230+
case ISD::VECREDUCE_ADD:
3231+
case ISD::VECREDUCE_MUL:
3232+
case ISD::VECREDUCE_UMAX:
3233+
case ISD::VECREDUCE_UMIN:
3234+
case ISD::VECREDUCE_SMAX:
3235+
case ISD::VECREDUCE_SMIN:
3236+
case ISD::VECREDUCE_AND:
3237+
case ISD::VECREDUCE_OR:
3238+
case ISD::VECREDUCE_XOR:
3239+
return LowerVECREDUCE(Op, DAG);
29443240
case ISD::STORE:
29453241
return LowerSTORE(Op, DAG);
29463242
case ISD::LOAD:

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,11 @@ enum NodeType : unsigned {
7373
UNPACK_VECTOR,
7474

7575
FCOPYSIGN,
76+
FMAXNUM3,
77+
FMINNUM3,
78+
FMAXIMUM3,
79+
FMINIMUM3,
80+
7681
DYNAMIC_STACKALLOC,
7782
STACKRESTORE,
7883
STACKSAVE,
@@ -300,6 +305,7 @@ class NVPTXTargetLowering : public TargetLowering {
300305

301306
SDValue LowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG) const;
302307
SDValue LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const;
308+
SDValue LowerVECREDUCE(SDValue Op, SelectionDAG &DAG) const;
303309
SDValue LowerEXTRACT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
304310
SDValue LowerINSERT_VECTOR_ELT(SDValue Op, SelectionDAG &DAG) const;
305311
SDValue LowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG) const;

0 commit comments

Comments
 (0)