Skip to content

Commit c895233

Browse files
committed
[NVPTX] expand associativity to fmax / fmin and add comments
1 parent f8d09af commit c895233

File tree

2 files changed

+133
-109
lines changed

2 files changed

+133
-109
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2172,19 +2172,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21722172
}
21732173

21742174
/// A generic routine for constructing a tree reduction on a vector operand.
2175-
/// This method differs from iterative splitting in DAGTypeLegalizer by
2176-
/// progressively grouping elements bottom-up.
2175+
/// This method groups elements bottom-up, progressively building each level.
2176+
/// This approach differs from top-down iterative splitting used in
2177+
/// DAGTypeLegalizer and ExpandReductions.
2178+
///
2179+
/// Also, the flags on the original reduction operation will be propagated to
2180+
/// each scalar operation.
21772181
static SDValue BuildTreeReduction(
21782182
const SmallVector<SDValue> &Elements, EVT EltTy,
21792183
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
21802184
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2181-
// now build the computation graph in place at each level
2185+
// Build the reduction tree at each level, starting with all the elements.
21822186
SmallVector<SDValue> Level = Elements;
2187+
21832188
unsigned OpIdx = 0;
21842189
while (Level.size() > 1) {
2190+
// Try to reduce this level using the current operator.
21852191
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21862192

2187-
// partially reduce all elements in level
2193+
// Build the next level by partially reducing all elements.
21882194
SmallVector<SDValue> ReducedLevel;
21892195
unsigned I = 0, E = Level.size();
21902196
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
@@ -2195,18 +2201,23 @@ static SDValue BuildTreeReduction(
21952201
}
21962202

21972203
if (I < E) {
2204+
// We have leftover elements. Why?
2205+
21982206
if (ReducedLevel.empty()) {
2199-
// The current operator requires more inputs than there are operands at
2200-
// this level. Pick a smaller operator and retry.
2207+
// ...because this level is now so small that the current operator is
2208+
// too big for it. Pick a smaller operator and retry.
22012209
++OpIdx;
22022210
assert(OpIdx < Ops.size() && "no smaller operators for reduction");
22032211
continue;
22042212
}
22052213

2206-
// Otherwise, we just have a remainder, which we push to the next level.
2214+
// ...because the operator's required number of inputs doesn't divide
2215+
// evenly this level. We push this remainder to the next level.
22072216
for (; I < E; ++I)
22082217
ReducedLevel.push_back(Level[I]);
22092218
}
2219+
2220+
// Process the next level.
22102221
Level = ReducedLevel;
22112222
}
22122223

@@ -2222,6 +2233,7 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22222233
const SDNodeFlags Flags = Op->getFlags();
22232234
SDValue Vector;
22242235
SDValue Accumulator;
2236+
22252237
if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
22262238
Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
22272239
// special case with accumulator as first arg
@@ -2231,85 +2243,94 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22312243
// default case
22322244
Vector = Op.getOperand(0);
22332245
}
2246+
22342247
EVT EltTy = Vector.getValueType().getVectorElementType();
22352248
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22362249
STI.getPTXVersion() >= 88;
22372250

22382251
// A list of SDNode opcodes with equivalent semantics, sorted descending by
22392252
// number of inputs they take.
22402253
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2241-
bool IsReassociatable;
2254+
2255+
// Whether we can lower to scalar operations in an arbitrary order.
2256+
bool IsAssociative;
22422257

22432258
switch (Op->getOpcode()) {
22442259
case ISD::VECREDUCE_FADD:
22452260
case ISD::VECREDUCE_SEQ_FADD:
22462261
ScalarOps = {{ISD::FADD, 2}};
2247-
IsReassociatable = false;
2262+
IsAssociative = Op->getOpcode() == ISD::VECREDUCE_FADD;
22482263
break;
22492264
case ISD::VECREDUCE_FMUL:
22502265
case ISD::VECREDUCE_SEQ_FMUL:
22512266
ScalarOps = {{ISD::FMUL, 2}};
2252-
IsReassociatable = false;
2267+
IsAssociative = Op->getOpcode() == ISD::VECREDUCE_FMUL;
22532268
break;
22542269
case ISD::VECREDUCE_FMAX:
22552270
if (CanUseMinMax3)
22562271
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
22572272
ScalarOps.push_back({ISD::FMAXNUM, 2});
2258-
IsReassociatable = false;
2273+
// Definition of maxNum in IEEE 754 2008 is non-associative, but only
2274+
// because of how sNaNs are treated. However, NVIDIA GPUs don't support
2275+
// sNaNs.
2276+
IsAssociative = true;
22592277
break;
22602278
case ISD::VECREDUCE_FMIN:
22612279
if (CanUseMinMax3)
22622280
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
22632281
ScalarOps.push_back({ISD::FMINNUM, 2});
2264-
IsReassociatable = false;
2282+
// Definition of minNum in IEEE 754 2008 is non-associative, but only
2283+
// because of how sNaNs are treated. However, NVIDIA GPUs don't support
2284+
// sNaNs.
2285+
IsAssociative = true;
22652286
break;
22662287
case ISD::VECREDUCE_FMAXIMUM:
22672288
if (CanUseMinMax3)
22682289
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
22692290
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2270-
IsReassociatable = false;
2291+
IsAssociative = true;
22712292
break;
22722293
case ISD::VECREDUCE_FMINIMUM:
22732294
if (CanUseMinMax3)
22742295
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
22752296
ScalarOps.push_back({ISD::FMINIMUM, 2});
2276-
IsReassociatable = false;
2297+
IsAssociative = true;
22772298
break;
22782299
case ISD::VECREDUCE_ADD:
22792300
ScalarOps = {{ISD::ADD, 2}};
2280-
IsReassociatable = true;
2301+
IsAssociative = true;
22812302
break;
22822303
case ISD::VECREDUCE_MUL:
22832304
ScalarOps = {{ISD::MUL, 2}};
2284-
IsReassociatable = true;
2305+
IsAssociative = true;
22852306
break;
22862307
case ISD::VECREDUCE_UMAX:
22872308
ScalarOps = {{ISD::UMAX, 2}};
2288-
IsReassociatable = true;
2309+
IsAssociative = true;
22892310
break;
22902311
case ISD::VECREDUCE_UMIN:
22912312
ScalarOps = {{ISD::UMIN, 2}};
2292-
IsReassociatable = true;
2313+
IsAssociative = true;
22932314
break;
22942315
case ISD::VECREDUCE_SMAX:
22952316
ScalarOps = {{ISD::SMAX, 2}};
2296-
IsReassociatable = true;
2317+
IsAssociative = true;
22972318
break;
22982319
case ISD::VECREDUCE_SMIN:
22992320
ScalarOps = {{ISD::SMIN, 2}};
2300-
IsReassociatable = true;
2321+
IsAssociative = true;
23012322
break;
23022323
case ISD::VECREDUCE_AND:
23032324
ScalarOps = {{ISD::AND, 2}};
2304-
IsReassociatable = true;
2325+
IsAssociative = true;
23052326
break;
23062327
case ISD::VECREDUCE_OR:
23072328
ScalarOps = {{ISD::OR, 2}};
2308-
IsReassociatable = true;
2329+
IsAssociative = true;
23092330
break;
23102331
case ISD::VECREDUCE_XOR:
23112332
ScalarOps = {{ISD::XOR, 2}};
2312-
IsReassociatable = true;
2333+
IsAssociative = true;
23132334
break;
23142335
default:
23152336
llvm_unreachable("unhandled vecreduce operation");
@@ -2326,18 +2347,21 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
23262347
}
23272348

23282349
// Lower to tree reduction.
2329-
if (IsReassociatable || Flags.hasAllowReassociation()) {
2330-
// we don't expect an accumulator for reassociatable vector reduction ops
2350+
if (IsAssociative || allowUnsafeFPMath(DAG.getMachineFunction())) {
2351+
// we don't expect an accumulator for reassociative vector reduction ops
23312352
assert(!Accumulator && "unexpected accumulator");
23322353
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
23332354
}
23342355

23352356
// Lower to sequential reduction.
23362357
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2358+
// Try to reduce the remaining sequence as much as possible using the
2359+
// current operator.
23372360
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
23382361
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
23392362

23402363
if (!Accumulator) {
2364+
// Try to initialize the accumulator using the current operator.
23412365
if (I + DefaultGroupSize <= NumElts) {
23422366
Accumulator = DAG.getNode(
23432367
DefaultScalarOp, DL, EltTy,

0 commit comments

Comments
 (0)