@@ -853,6 +853,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
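The comment above names three lowering shapes. Here is a minimal standalone C++ sketch (illustrative only, not part of the patch) of how the three evaluation orders differ for a 4-element floating-point add; the sample values are chosen so the orders round differently:

```cpp
#include <cstdio>

int main() {
  // Values chosen so the three evaluation orders produce different results.
  float a[4] = {1e20f, 1.0f, -1e20f, 1.0f};

  // Sequential: one long dependency chain, ((a0 + a1) + a2) + a3.
  float Seq = ((a[0] + a[1]) + a[2]) + a[3];

  // Tree (BuildTreeReduction below): adjacent pairs first,
  // (a0 + a1) + (a2 + a3).
  float Tree = (a[0] + a[1]) + (a[2] + a[3]);

  // Shuffle (DAGTypeLegalizer/ExpandReductions): split by stride,
  // (a0 + a2) + (a1 + a3).
  float Shuf = (a[0] + a[2]) + (a[1] + a[3]);

  // Prints seq=1 tree=0 shuf=2: the orders are not interchangeable in FP,
  // which is why the reassociating paths below are gated.
  std::printf("seq=%g tree=%g shuf=%g\n", Seq, Tree, Shuf);
  return 0;
}
```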
@@ -1110,6 +1131,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
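The four new node names model the 3-input min/max operations that the lowering below gates on sm_100 and PTX 8.8. A hedged sketch of the semantics that lowering relies on, assuming a 3-input maxNum folds like two chained 2-input maxNum ops (the precise PTX-level NaN behavior is defined by the instruction selection patterns, not this hunk); `fmaxnum3` is a hypothetical helper, and `std::fmax` follows IEEE maxNum in ignoring a quiet NaN operand:

```cpp
#include <cassert>
#include <cmath>

// Assumption: a 3-input maxNum node behaves like two chained 2-input
// maxNums. std::fmax implements IEEE-754 maxNum: a quiet NaN is ignored.
static float fmaxnum3(float A, float B, float C) {
  return std::fmax(std::fmax(A, B), C);
}

int main() {
  assert(fmaxnum3(1.0f, 3.0f, 2.0f) == 3.0f);
  // Quiet NaN operands are dropped under maxNum semantics.
  assert(fmaxnum3(std::nanf(""), -1.0f, -2.0f) == -1.0f);
  return 0;
}
```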
@@ -2109,6 +2134,259 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// amount of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      // Handle leftover elements.
+
+      if (ReducedLevel.empty()) {
+        // We didn't reduce anything at this level. We need to pick a smaller
+        // operator.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // We reduced some things but there's still more left, meaning the
+      // operator's number of inputs doesn't evenly divide this level size. Move
+      // these elements to the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociations are allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // Special case with the accumulator as the first argument.
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // Default case.
+    Vector = Op.getOperand(0);
+  }
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
+    // Prefer add.{,b}f16x2 for v2{,b}f16
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
+    // Prefer mul.{,b}f16x2 for v2{,b}f16
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // The definition of maxNum in IEEE 754-2008 is non-associative due to its
+    // handling of sNaN inputs. Allow overriding with fast-math or the
+    // 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // The definition of minNum in IEEE 754-2008 is non-associative due to its
+    // handling of sNaN inputs. Allow overriding with fast-math or the
+    // 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in shuffle reduction
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsAssociative = true;
+    // Prefer add.{s,u}16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsAssociative = true;
+    // Integer multiply doesn't support packed types
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.u16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.s16x2 for v2i16
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsAssociative = true;
+    // Prefer and.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsAssociative = true;
+    // Prefer or.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsAssociative = true;
+    // Prefer xor.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If a shuffle reduction is preferred, leave it to SelectionDAG.
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to a tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to a sequential reduction.
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator = DAG.getNode(
+            DefaultScalarOp, DL, EltTy,
+            ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
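To make BuildTreeReduction's level-building concrete, here is a small host-side C++ simulation (a sketch, not LLVM code) that mirrors the loop above on strings so the grouping is visible: it folds adjacent groups with the widest operator, falls back to a narrower operator only when a whole level is smaller than the group size, and carries leftover elements up to the next level. `treeReduce` and the {3, 2} arity list are assumptions for the demo:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

using OpList = std::vector<std::pair<std::string, unsigned>>;

// Mirrors BuildTreeReduction, but on strings so the grouping is visible.
// Ops holds (name, arity) pairs sorted by descending arity.
static std::string treeReduce(std::vector<std::string> Level,
                              const OpList &Ops) {
  unsigned OpIdx = 0;
  while (Level.size() > 1) {
    const auto &[Name, Arity] = Ops[OpIdx];
    std::vector<std::string> Reduced;
    unsigned I = 0, E = Level.size();
    // Fold complete groups of `Arity` adjacent elements.
    for (; I + Arity <= E; I += Arity) {
      std::string Node = Name + "(";
      for (unsigned K = 0; K < Arity; ++K)
        Node += (K ? "," : "") + Level[I + K];
      Reduced.push_back(Node + ")");
    }
    if (I < E) {
      if (Reduced.empty()) {
        // The level is smaller than the arity: fall back to a narrower op.
        ++OpIdx;
        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
        continue;
      }
      // Carry leftover elements up to the next level unchanged.
      for (; I < E; ++I)
        Reduced.push_back(Level[I]);
    }
    Level = std::move(Reduced);
  }
  return Level.front();
}

int main() {
  // Five elements with arities {3, 2}:
  //   level 0: {a,b,c,d,e} -> {max3(a,b,c), d, e}
  //   level 1:             -> {max3(max3(a,b,c), d, e)}
  assert(treeReduce({"a", "b", "c", "d", "e"}, {{"max3", 3}, {"max", 2}}) ==
         "max3(max3(a,b,c),d,e)");
  return 0;
}
```

Note how the adjacent-first grouping matches the doc comment: elements are combined bottom-up, so live ranges stay short compared with the stride-based shuffle scheme.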
@@ -2941,6 +3219,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
     return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD:
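The dispatch above routes all seventeen reduction opcodes to LowerVECREDUCE. For the non-reassociable path, a matching host-side sketch (again illustrative, with the assumed {3, 2} operator list and a hypothetical `seqReduce` helper) of how the sequential loop consumes elements strictly left to right: the accumulator is seeded from the first full group, each step then folds the accumulator with Arity - 1 further elements, and the 2-input form handles the tail:

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Mirrors LowerVECREDUCE's sequential path on strings: strictly left to
// right, so the ordering semantics of the original vecreduce are kept.
static std::string seqReduce(const std::vector<std::string> &Elts) {
  const std::vector<std::pair<std::string, unsigned>> Ops = {{"max3", 3},
                                                             {"max", 2}};
  std::string Acc;
  const unsigned N = Elts.size();
  for (unsigned OpIdx = 0, I = 0; I < N; ++OpIdx) {
    assert(OpIdx < Ops.size() && "no smaller operators for reduction");
    const auto &[Name, Arity] = Ops[OpIdx];
    if (Acc.empty() && I + Arity <= N) {
      // Seed the accumulator with the first full group, in order.
      Acc = Name + "(" + Elts[I];
      for (unsigned K = 1; K < Arity; ++K)
        Acc += "," + Elts[I + K];
      Acc += ")";
      I += Arity;
    }
    if (!Acc.empty()) {
      // Fold Arity - 1 further elements into the accumulator per step.
      for (; I + (Arity - 1) <= N; I += Arity - 1) {
        std::string Next = Name + "(" + Acc;
        for (unsigned K = 0; K < Arity - 1; ++K)
          Next += "," + Elts[I + K];
        Acc = Next + ")";
      }
    }
  }
  return Acc;
}

int main() {
  // Six elements: seed with max3(a,b,c), fold d,e with max3, then f with max.
  assert(seqReduce({"a", "b", "c", "d", "e", "f"}) ==
         "max(max3(max3(a,b,c),d,e),f)");
  return 0;
}
```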