Skip to content

Commit 1ea7992

Browse files
committed
[NVPTX] expand associativity to fmax / fmin and add comments
1 parent c51df14 commit 1ea7992

File tree

2 files changed

+133
-109
lines changed

2 files changed

+133
-109
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2166,19 +2166,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21662166
}
21672167

21682168
/// A generic routine for constructing a tree reduction on a vector operand.
2169-
/// This method differs from iterative splitting in DAGTypeLegalizer by
2170-
/// progressively grouping elements bottom-up.
2169+
/// This method groups elements bottom-up, progressively building each level.
2170+
/// This approach differs from top-down iterative splitting used in
2171+
/// DAGTypeLegalizer and ExpandReductions.
2172+
///
2173+
/// Also, the flags on the original reduction operation will be propagated to
2174+
/// each scalar operation.
21712175
static SDValue BuildTreeReduction(
21722176
const SmallVector<SDValue> &Elements, EVT EltTy,
21732177
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
21742178
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2175-
// now build the computation graph in place at each level
2179+
// Build the reduction tree at each level, starting with all the elements.
21762180
SmallVector<SDValue> Level = Elements;
2181+
21772182
unsigned OpIdx = 0;
21782183
while (Level.size() > 1) {
2184+
// Try to reduce this level using the current operator.
21792185
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21802186

2181-
// partially reduce all elements in level
2187+
// Build the next level by partially reducing all elements.
21822188
SmallVector<SDValue> ReducedLevel;
21832189
unsigned I = 0, E = Level.size();
21842190
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
@@ -2189,18 +2195,23 @@ static SDValue BuildTreeReduction(
21892195
}
21902196

21912197
if (I < E) {
2198+
// We have leftover elements. Why?
2199+
21922200
if (ReducedLevel.empty()) {
2193-
// The current operator requires more inputs than there are operands at
2194-
// this level. Pick a smaller operator and retry.
2201+
// ...because this level is now so small that the current operator is
2202+
// too big for it. Pick a smaller operator and retry.
21952203
++OpIdx;
21962204
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
21972205
continue;
21982206
}
21992207

2200-
// Otherwise, we just have a remainder, which we push to the next level.
2208+
// ...because the operator's required number of inputs doesn't evenly
2209+
// divide the size of this level. We push this remainder to the next level.
22012210
for (; I < E; ++I)
22022211
ReducedLevel.push_back(Level[I]);
22032212
}
2213+
2214+
// Process the next level.
22042215
Level = ReducedLevel;
22052216
}
22062217

@@ -2216,6 +2227,7 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22162227
const SDNodeFlags Flags = Op->getFlags();
22172228
SDValue Vector;
22182229
SDValue Accumulator;
2230+
22192231
if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
22202232
Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
22212233
// special case with accumulator as first arg
@@ -2225,85 +2237,94 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22252237
// default case
22262238
Vector = Op.getOperand(0);
22272239
}
2240+
22282241
EVT EltTy = Vector.getValueType().getVectorElementType();
22292242
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22302243
STI.getPTXVersion() >= 88;
22312244

22322245
// A list of SDNode opcodes with equivalent semantics, sorted descending by
22332246
// number of inputs they take.
22342247
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2235-
bool IsReassociatable;
2248+
2249+
// Whether we can lower to scalar operations in an arbitrary order.
2250+
bool IsAssociative;
22362251

22372252
switch (Op->getOpcode()) {
22382253
case ISD::VECREDUCE_FADD:
22392254
case ISD::VECREDUCE_SEQ_FADD:
22402255
ScalarOps = {{ISD::FADD, 2}};
2241-
IsReassociatable = false;
2256+
IsAssociative = Op->getOpcode() == ISD::VECREDUCE_FADD;
22422257
break;
22432258
case ISD::VECREDUCE_FMUL:
22442259
case ISD::VECREDUCE_SEQ_FMUL:
22452260
ScalarOps = {{ISD::FMUL, 2}};
2246-
IsReassociatable = false;
2261+
IsAssociative = Op->getOpcode() == ISD::VECREDUCE_FMUL;
22472262
break;
22482263
case ISD::VECREDUCE_FMAX:
22492264
if (CanUseMinMax3)
22502265
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
22512266
ScalarOps.push_back({ISD::FMAXNUM, 2});
2252-
IsReassociatable = false;
2267+
// The definition of maxNum in IEEE 754-2008 is non-associative, but only
2268+
// because of how sNaNs are treated. However, NVIDIA GPUs don't support
2269+
// sNaNs.
2270+
IsAssociative = true;
22532271
break;
22542272
case ISD::VECREDUCE_FMIN:
22552273
if (CanUseMinMax3)
22562274
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
22572275
ScalarOps.push_back({ISD::FMINNUM, 2});
2258-
IsReassociatable = false;
2276+
// The definition of minNum in IEEE 754-2008 is non-associative, but only
2277+
// because of how sNaNs are treated. However, NVIDIA GPUs don't support
2278+
// sNaNs.
2279+
IsAssociative = true;
22592280
break;
22602281
case ISD::VECREDUCE_FMAXIMUM:
22612282
if (CanUseMinMax3)
22622283
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
22632284
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2264-
IsReassociatable = false;
2285+
IsAssociative = true;
22652286
break;
22662287
case ISD::VECREDUCE_FMINIMUM:
22672288
if (CanUseMinMax3)
22682289
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
22692290
ScalarOps.push_back({ISD::FMINIMUM, 2});
2270-
IsReassociatable = false;
2291+
IsAssociative = true;
22712292
break;
22722293
case ISD::VECREDUCE_ADD:
22732294
ScalarOps = {{ISD::ADD, 2}};
2274-
IsReassociatable = true;
2295+
IsAssociative = true;
22752296
break;
22762297
case ISD::VECREDUCE_MUL:
22772298
ScalarOps = {{ISD::MUL, 2}};
2278-
IsReassociatable = true;
2299+
IsAssociative = true;
22792300
break;
22802301
case ISD::VECREDUCE_UMAX:
22812302
ScalarOps = {{ISD::UMAX, 2}};
2282-
IsReassociatable = true;
2303+
IsAssociative = true;
22832304
break;
22842305
case ISD::VECREDUCE_UMIN:
22852306
ScalarOps = {{ISD::UMIN, 2}};
2286-
IsReassociatable = true;
2307+
IsAssociative = true;
22872308
break;
22882309
case ISD::VECREDUCE_SMAX:
22892310
ScalarOps = {{ISD::SMAX, 2}};
2290-
IsReassociatable = true;
2311+
IsAssociative = true;
22912312
break;
22922313
case ISD::VECREDUCE_SMIN:
22932314
ScalarOps = {{ISD::SMIN, 2}};
2294-
IsReassociatable = true;
2315+
IsAssociative = true;
22952316
break;
22962317
case ISD::VECREDUCE_AND:
22972318
ScalarOps = {{ISD::AND, 2}};
2298-
IsReassociatable = true;
2319+
IsAssociative = true;
22992320
break;
23002321
case ISD::VECREDUCE_OR:
23012322
ScalarOps = {{ISD::OR, 2}};
2302-
IsReassociatable = true;
2323+
IsAssociative = true;
23032324
break;
23042325
case ISD::VECREDUCE_XOR:
23052326
ScalarOps = {{ISD::XOR, 2}};
2306-
IsReassociatable = true;
2327+
IsAssociative = true;
23072328
break;
23082329
default:
23092330
llvm_unreachable("unhandled vecreduce operation");
@@ -2320,18 +2341,21 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
23202341
}
23212342

23222343
// Lower to tree reduction.
2323-
if (IsReassociatable || Flags.hasAllowReassociation()) {
2324-
// we don't expect an accumulator for reassociatable vector reduction ops
2344+
if (IsAssociative || allowUnsafeFPMath(DAG.getMachineFunction())) {
2345+
// We don't expect an accumulator for reassociable vector reduction ops.
23252346
assert(!Accumulator && "unexpected accumulator");
23262347
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
23272348
}
23282349

23292350
// Lower to sequential reduction.
23302351
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2352+
// Try to reduce the remaining sequence as much as possible using the
2353+
// current operator.
23312354
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
23322355
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
23332356

23342357
if (!Accumulator) {
2358+
// Try to initialize the accumulator using the current operator.
23352359
if (I + DefaultGroupSize <= NumElts) {
23362360
Accumulator = DAG.getNode(
23372361
DefaultScalarOp, DL, EltTy,

0 commit comments

Comments
 (0)