@@ -2166,19 +2166,25 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2166
2166
}
2167
2167
2168
2168
// / A generic routine for constructing a tree reduction on a vector operand.
2169
- // / This method differs from iterative splitting in DAGTypeLegalizer by
2170
- // / progressively grouping elements bottom-up.
2169
+ // / This method groups elements bottom-up, progressively building each level.
2170
+ // / This approach differs from top-down iterative splitting used in
2171
+ // / DAGTypeLegalizer and ExpandReductions.
2172
+ // /
2173
+ // / Also, the flags on the original reduction operation will be propagated to
2174
+ // / each scalar operation.
2171
2175
static SDValue BuildTreeReduction (
2172
2176
const SmallVector<SDValue> &Elements, EVT EltTy,
2173
2177
ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2174
2178
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2175
- // now build the computation graph in place at each level
2179
+ // Build the reduction tree at each level, starting with all the elements.
2176
2180
SmallVector<SDValue> Level = Elements;
2181
+
2177
2182
unsigned OpIdx = 0 ;
2178
2183
while (Level.size () > 1 ) {
2184
+ // Try to reduce this level using the current operator.
2179
2185
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2180
2186
2181
- // partially reduce all elements in level
2187
+ // Build the next level by partially reducing all elements.
2182
2188
SmallVector<SDValue> ReducedLevel;
2183
2189
unsigned I = 0 , E = Level.size ();
2184
2190
for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
@@ -2189,18 +2195,23 @@ static SDValue BuildTreeReduction(
2189
2195
}
2190
2196
2191
2197
if (I < E) {
2198
+ // We have leftover elements. Why?
2199
+
2192
2200
if (ReducedLevel.empty ()) {
2193
- // The current operator requires more inputs than there are operands at
2194
- // this level . Pick a smaller operator and retry.
2201
+ // ...because this level is now so small that the current operator is
2202
+ // too big for it. Pick a smaller operator and retry.
2195
2203
++OpIdx;
2196
2204
assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
2197
2205
continue ;
2198
2206
}
2199
2207
2200
- // Otherwise, we just have a remainder, which we push to the next level.
2208
+ // ...because the operator's required number of inputs doesn't divide
2209
+ // evenly into this level. We push this remainder to the next level.
2201
2210
for (; I < E; ++I)
2202
2211
ReducedLevel.push_back (Level[I]);
2203
2212
}
2213
+
2214
+ // Process the next level.
2204
2215
Level = ReducedLevel;
2205
2216
}
2206
2217
@@ -2216,6 +2227,7 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2216
2227
const SDNodeFlags Flags = Op->getFlags ();
2217
2228
SDValue Vector;
2218
2229
SDValue Accumulator;
2230
+
2219
2231
if (Op->getOpcode () == ISD::VECREDUCE_SEQ_FADD ||
2220
2232
Op->getOpcode () == ISD::VECREDUCE_SEQ_FMUL) {
2221
2233
// special case with accumulator as first arg
@@ -2225,85 +2237,94 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2225
2237
// default case
2226
2238
Vector = Op.getOperand (0 );
2227
2239
}
2240
+
2228
2241
EVT EltTy = Vector.getValueType ().getVectorElementType ();
2229
2242
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2230
2243
STI.getPTXVersion () >= 88 ;
2231
2244
2232
2245
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2233
2246
// number of inputs they take.
2234
2247
SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2235
- bool IsReassociatable;
2248
+
2249
+ // Whether we can lower to scalar operations in an arbitrary order.
2250
+ bool IsAssociative;
2236
2251
2237
2252
switch (Op->getOpcode ()) {
2238
2253
case ISD::VECREDUCE_FADD:
2239
2254
case ISD::VECREDUCE_SEQ_FADD:
2240
2255
ScalarOps = {{ISD::FADD, 2 }};
2241
- IsReassociatable = false ;
2256
+ IsAssociative = Op-> getOpcode () == ISD::VECREDUCE_FADD ;
2242
2257
break ;
2243
2258
case ISD::VECREDUCE_FMUL:
2244
2259
case ISD::VECREDUCE_SEQ_FMUL:
2245
2260
ScalarOps = {{ISD::FMUL, 2 }};
2246
- IsReassociatable = false ;
2261
+ IsAssociative = Op-> getOpcode () == ISD::VECREDUCE_FMUL ;
2247
2262
break ;
2248
2263
case ISD::VECREDUCE_FMAX:
2249
2264
if (CanUseMinMax3)
2250
2265
ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2251
2266
ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2252
- IsReassociatable = false ;
2267
+ // Definition of maxNum in IEEE 754 2008 is non-associative, but only
2268
+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2269
+ // sNaNs.
2270
+ IsAssociative = true ;
2253
2271
break ;
2254
2272
case ISD::VECREDUCE_FMIN:
2255
2273
if (CanUseMinMax3)
2256
2274
ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2257
2275
ScalarOps.push_back ({ISD::FMINNUM, 2 });
2258
- IsReassociatable = false ;
2276
+ // Definition of minNum in IEEE 754 2008 is non-associative, but only
2277
+ // because of how sNaNs are treated. However, NVIDIA GPUs don't support
2278
+ // sNaNs.
2279
+ IsAssociative = true ;
2259
2280
break ;
2260
2281
case ISD::VECREDUCE_FMAXIMUM:
2261
2282
if (CanUseMinMax3)
2262
2283
ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2263
2284
ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2264
- IsReassociatable = false ;
2285
+ IsAssociative = true ;
2265
2286
break ;
2266
2287
case ISD::VECREDUCE_FMINIMUM:
2267
2288
if (CanUseMinMax3)
2268
2289
ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2269
2290
ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2270
- IsReassociatable = false ;
2291
+ IsAssociative = true ;
2271
2292
break ;
2272
2293
case ISD::VECREDUCE_ADD:
2273
2294
ScalarOps = {{ISD::ADD, 2 }};
2274
- IsReassociatable = true ;
2295
+ IsAssociative = true ;
2275
2296
break ;
2276
2297
case ISD::VECREDUCE_MUL:
2277
2298
ScalarOps = {{ISD::MUL, 2 }};
2278
- IsReassociatable = true ;
2299
+ IsAssociative = true ;
2279
2300
break ;
2280
2301
case ISD::VECREDUCE_UMAX:
2281
2302
ScalarOps = {{ISD::UMAX, 2 }};
2282
- IsReassociatable = true ;
2303
+ IsAssociative = true ;
2283
2304
break ;
2284
2305
case ISD::VECREDUCE_UMIN:
2285
2306
ScalarOps = {{ISD::UMIN, 2 }};
2286
- IsReassociatable = true ;
2307
+ IsAssociative = true ;
2287
2308
break ;
2288
2309
case ISD::VECREDUCE_SMAX:
2289
2310
ScalarOps = {{ISD::SMAX, 2 }};
2290
- IsReassociatable = true ;
2311
+ IsAssociative = true ;
2291
2312
break ;
2292
2313
case ISD::VECREDUCE_SMIN:
2293
2314
ScalarOps = {{ISD::SMIN, 2 }};
2294
- IsReassociatable = true ;
2315
+ IsAssociative = true ;
2295
2316
break ;
2296
2317
case ISD::VECREDUCE_AND:
2297
2318
ScalarOps = {{ISD::AND, 2 }};
2298
- IsReassociatable = true ;
2319
+ IsAssociative = true ;
2299
2320
break ;
2300
2321
case ISD::VECREDUCE_OR:
2301
2322
ScalarOps = {{ISD::OR, 2 }};
2302
- IsReassociatable = true ;
2323
+ IsAssociative = true ;
2303
2324
break ;
2304
2325
case ISD::VECREDUCE_XOR:
2305
2326
ScalarOps = {{ISD::XOR, 2 }};
2306
- IsReassociatable = true ;
2327
+ IsAssociative = true ;
2307
2328
break ;
2308
2329
default :
2309
2330
llvm_unreachable (" unhandled vecreduce operation" );
@@ -2320,18 +2341,21 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
2320
2341
}
2321
2342
2322
2343
// Lower to tree reduction.
2323
- if (IsReassociatable || Flags. hasAllowReassociation ( )) {
2324
- // we don't expect an accumulator for reassociatable vector reduction ops
2344
+ if (IsAssociative || allowUnsafeFPMath (DAG. getMachineFunction () )) {
2345
+ // We don't expect an accumulator for reassociable vector reduction ops.
2325
2346
assert (!Accumulator && " unexpected accumulator" );
2326
2347
return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2327
2348
}
2328
2349
2329
2350
// Lower to sequential reduction.
2330
2351
for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2352
+ // Try to reduce the remaining sequence as much as possible using the
2353
+ // current operator.
2331
2354
assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2332
2355
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2333
2356
2334
2357
if (!Accumulator) {
2358
+ // Try to initialize the accumulator using the current operator.
2335
2359
if (I + DefaultGroupSize <= NumElts) {
2336
2360
Accumulator = DAG.getNode (
2337
2361
DefaultScalarOp, DL, EltTy,
0 commit comments