@@ -2172,19 +2172,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2172
2172
}
2173
2173
2174
2174
// / A generic routine for constructing a tree reduction on a vector operand.
2175
- // / This method differs from iterative splitting in DAGTypeLegalizer by
2176
- // / progressively grouping elements bottom-up.
2175
+ // / This method groups elements bottom-up, progressively building each level.
2176
+ // / This approach differs from top-down iterative splitting used in
2177
+ // / DAGTypeLegalizer and ExpandReductions.
2178
+ // /
2179
+ // / Also, the flags on the original reduction operation will be propagated to
2180
+ // / each scalar operation.
2177
2181
static SDValue BuildTreeReduction (
2178
2182
const SmallVector<SDValue> &Elements, EVT EltTy,
2179
2183
ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2180
2184
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2181
- // now build the computation graph in place at each level
2185
+ // Build the reduction tree at each level, starting with all the elements.
2182
2186
SmallVector<SDValue> Level = Elements;
2187
+
2183
2188
unsigned OpIdx = 0 ;
2184
2189
while (Level.size () > 1 ) {
2190
+ // Try to reduce this level using the current operator.
2185
2191
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2186
2192
2187
- // partially reduce all elements in level
2193
+ // Build the next level by partially reducing all elements.
2188
2194
SmallVector<SDValue> ReducedLevel;
2189
2195
unsigned I = 0 , E = Level.size ();
2190
2196
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
@@ -2195,18 +2201,23 @@ static SDValue BuildTreeReduction(
2195
2201
}
2196
2202
2197
2203
if (I < E) {
2204
+ // We have leftover elements. Why?
2205
+
2198
2206
if (ReducedLevel.empty ()) {
2199
- // The current operator requires more inputs than there are operands at
2200
- // this level . Pick a smaller operator and retry.
2207
+ // ...because this level is now so small that the current operator is
2208
+ // too big for it . Pick a smaller operator and retry.
2201
2209
++OpIdx;
2202
2210
assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
2203
2211
continue ;
2204
2212
}
2205
2213
2206
- // Otherwise, we just have a remainder, which we push to the next level.
2214
+ // ...because the operator's required number of inputs doesn't evenly
2215
+ // divide the size of this level. We push this remainder to the next level.
2207
2216
for (; I < E; ++I)
2208
2217
ReducedLevel.push_back (Level[I]);
2209
2218
}
2219
+
2220
+ // Process the next level.
2210
2221
Level = ReducedLevel;
2211
2222
}
2212
2223
@@ -2222,6 +2233,7 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2222
2233
const SDNodeFlags Flags = Op->getFlags ();
2223
2234
SDValue Vector;
2224
2235
SDValue Accumulator;
2236
+
2225
2237
if (Op->getOpcode () == ISD::VECREDUCE_SEQ_FADD ||
2226
2238
Op->getOpcode () == ISD::VECREDUCE_SEQ_FMUL) {
2227
2239
// special case with accumulator as first arg
@@ -2231,85 +2243,94 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2231
2243
// default case
2232
2244
Vector = Op.getOperand (0 );
2233
2245
}
2246
+
2234
2247
EVT EltTy = Vector.getValueType ().getVectorElementType ();
2235
2248
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2236
2249
STI.getPTXVersion () >= 88 ;
2237
2250
2238
2251
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2239
2252
// number of inputs they take.
2240
2253
SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2241
- bool IsReassociatable;
2254
+
2255
+ // Whether we can lower to scalar operations in an arbitrary order.
2256
+ bool IsAssociative;
2242
2257
2243
2258
switch (Op->getOpcode ()) {
2244
2259
case ISD::VECREDUCE_FADD:
2245
2260
case ISD::VECREDUCE_SEQ_FADD:
2246
2261
ScalarOps = {{ISD::FADD, 2 }};
2247
- IsReassociatable = false ;
2262
+ IsAssociative = Op-> getOpcode () == ISD::VECREDUCE_FADD ;
2248
2263
break ;
2249
2264
case ISD::VECREDUCE_FMUL:
2250
2265
case ISD::VECREDUCE_SEQ_FMUL:
2251
2266
ScalarOps = {{ISD::FMUL, 2 }};
2252
- IsReassociatable = false ;
2267
+ IsAssociative = Op-> getOpcode () == ISD::VECREDUCE_FMUL ;
2253
2268
break ;
2254
2269
case ISD::VECREDUCE_FMAX:
2255
2270
if (CanUseMinMax3)
2256
2271
ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2257
2272
ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2258
- IsReassociatable = false ;
2273
+ // Definition of maxNum in IEEE 754 2008 is non-associative, but only
2274
+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2275
+ // sNaNs.
2276
+ IsAssociative = true ;
2259
2277
break ;
2260
2278
case ISD::VECREDUCE_FMIN:
2261
2279
if (CanUseMinMax3)
2262
2280
ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2263
2281
ScalarOps.push_back ({ISD::FMINNUM, 2 });
2264
- IsReassociatable = false ;
2282
+ // Definition of minNum in IEEE 754 2008 is non-associative, but only
2283
+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2284
+ // sNaNs.
2285
+ IsAssociative = true ;
2265
2286
break ;
2266
2287
case ISD::VECREDUCE_FMAXIMUM:
2267
2288
if (CanUseMinMax3)
2268
2289
ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2269
2290
ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2270
- IsReassociatable = false ;
2291
+ IsAssociative = true ;
2271
2292
break ;
2272
2293
case ISD::VECREDUCE_FMINIMUM:
2273
2294
if (CanUseMinMax3)
2274
2295
ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2275
2296
ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2276
- IsReassociatable = false ;
2297
+ IsAssociative = true ;
2277
2298
break ;
2278
2299
case ISD::VECREDUCE_ADD:
2279
2300
ScalarOps = {{ISD::ADD, 2 }};
2280
- IsReassociatable = true ;
2301
+ IsAssociative = true ;
2281
2302
break ;
2282
2303
case ISD::VECREDUCE_MUL:
2283
2304
ScalarOps = {{ISD::MUL, 2 }};
2284
- IsReassociatable = true ;
2305
+ IsAssociative = true ;
2285
2306
break ;
2286
2307
case ISD::VECREDUCE_UMAX:
2287
2308
ScalarOps = {{ISD::UMAX, 2 }};
2288
- IsReassociatable = true ;
2309
+ IsAssociative = true ;
2289
2310
break ;
2290
2311
case ISD::VECREDUCE_UMIN:
2291
2312
ScalarOps = {{ISD::UMIN, 2 }};
2292
- IsReassociatable = true ;
2313
+ IsAssociative = true ;
2293
2314
break ;
2294
2315
case ISD::VECREDUCE_SMAX:
2295
2316
ScalarOps = {{ISD::SMAX, 2 }};
2296
- IsReassociatable = true ;
2317
+ IsAssociative = true ;
2297
2318
break ;
2298
2319
case ISD::VECREDUCE_SMIN:
2299
2320
ScalarOps = {{ISD::SMIN, 2 }};
2300
- IsReassociatable = true ;
2321
+ IsAssociative = true ;
2301
2322
break ;
2302
2323
case ISD::VECREDUCE_AND:
2303
2324
ScalarOps = {{ISD::AND, 2 }};
2304
- IsReassociatable = true ;
2325
+ IsAssociative = true ;
2305
2326
break ;
2306
2327
case ISD::VECREDUCE_OR:
2307
2328
ScalarOps = {{ISD::OR, 2 }};
2308
- IsReassociatable = true ;
2329
+ IsAssociative = true ;
2309
2330
break ;
2310
2331
case ISD::VECREDUCE_XOR:
2311
2332
ScalarOps = {{ISD::XOR, 2 }};
2312
- IsReassociatable = true ;
2333
+ IsAssociative = true ;
2313
2334
break ;
2314
2335
default :
2315
2336
llvm_unreachable (" unhandled vecreduce operation" );
@@ -2326,18 +2347,21 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2326
2347
}
2327
2348
2328
2349
// Lower to tree reduction.
2329
- if (IsReassociatable || Flags. hasAllowReassociation ( )) {
2330
- // we don't expect an accumulator for reassociatable vector reduction ops
2350
+ if (IsAssociative || allowUnsafeFPMath (DAG. getMachineFunction () )) {
2351
+ // We don't expect an accumulator for reassociable vector reduction ops.
2331
2352
assert (!Accumulator && " unexpected accumulator" );
2332
2353
return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2333
2354
}
2334
2355
2335
2356
// Lower to sequential reduction.
2336
2357
for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2358
+ // Try to reduce the remaining sequence as much as possible using the
2359
+ // current operator.
2337
2360
assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2338
2361
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2339
2362
2340
2363
if (!Accumulator) {
2364
+ // Try to initialize the accumulator using the current operator.
2341
2365
if (I + DefaultGroupSize <= NumElts) {
2342
2366
Accumulator = DAG.getNode (
2343
2367
DefaultScalarOp, DL, EltTy,
0 commit comments