@@ -870,12 +870,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
870
870
setTargetDAGCombine (ISD::SETCC);
871
871
872
872
// Vector reduction operations. These are transformed into a tree evaluation
873
- // of nodes which may or may not be legal .
873
+ // of nodes which may initially be illegal .
874
874
for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
875
- setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
876
- ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
877
- ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
878
- VT, Custom);
875
+ MVT EltVT = VT.getVectorElementType ();
876
+ if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
877
+ EltVT == MVT::f64 ) {
878
+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
879
+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
880
+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
881
+ VT, Custom);
882
+ } else if (EltVT.isScalarInteger ()) {
883
+ setOperationAction (
884
+ {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
885
+ ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
886
+ ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
887
+ VT, Custom);
888
+ }
879
889
}
880
890
881
891
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2213,29 +2223,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2213
2223
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2214
2224
}
2215
2225
2216
- // / A generic routine for constructing a tree reduction for a vector operand.
2226
+ // / A generic routine for constructing a tree reduction on a vector operand.
2217
2227
// / This method differs from iterative splitting in DAGTypeLegalizer by
2218
- // / first scalarizing the vector and then progressively grouping elements
2219
- // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2220
- // / with different numbers of operands (eg. max3 vs max2).
2228
+ // / progressively grouping elements bottom-up.
2221
2229
static SDValue BuildTreeReduction (
2222
- const SDValue &VectorOp ,
2230
+ const SmallVector< SDValue> &Elements, EVT EltTy ,
2223
2231
ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2224
2232
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2225
- EVT VectorTy = VectorOp.getValueType ();
2226
- EVT EltTy = VectorTy.getVectorElementType ();
2227
- const unsigned NumElts = VectorTy.getVectorNumElements ();
2228
-
2229
- // scalarize vector
2230
- SmallVector<SDValue> Elements (NumElts);
2231
- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2232
- Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2233
- DAG.getConstant (I, DL, MVT::i64 ));
2234
- }
2235
-
2236
2233
// now build the computation graph in place at each level
2237
2234
SmallVector<SDValue> Level = Elements;
2238
- for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2235
+ unsigned OpIdx = 0 ;
2236
+ while (Level.size () > 1 ) {
2239
2237
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2240
2238
2241
2239
// partially reduce all elements in level
@@ -2267,52 +2265,139 @@ static SDValue BuildTreeReduction(
2267
2265
return *Level.begin ();
2268
2266
}
2269
2267
2270
- // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2271
- // / serializes it.
2268
+ // / Lower reductions to either a sequence of operations or a tree if
2269
+ // / reassociations are allowed. This method will use larger operations like
2270
+ // / max3/min3 when the target supports them.
2272
2271
SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2273
2272
SelectionDAG &DAG) const {
2274
- // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2275
- if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2273
+ if (DisableFOpTreeReduce)
2276
2274
return SDValue ();
2277
2275
2278
- EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2276
+ SDLoc DL (Op);
2277
+ const SDNodeFlags Flags = Op->getFlags ();
2278
+ const SDValue &Vector = Op.getOperand (0 );
2279
+ EVT EltTy = Vector.getValueType ().getVectorElementType ();
2279
2280
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2280
2281
STI.getPTXVersion () >= 88 ;
2281
- SDLoc DL (Op);
2282
- SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2282
+
2283
+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2284
+ // number of inputs they take.
2285
+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2286
+ bool IsReassociatable;
2287
+
2283
2288
switch (Op->getOpcode ()) {
2284
2289
case ISD::VECREDUCE_FADD:
2285
- Operators = {{ISD::FADD, 2 }};
2290
+ ScalarOps = {{ISD::FADD, 2 }};
2291
+ IsReassociatable = false ;
2286
2292
break ;
2287
2293
case ISD::VECREDUCE_FMUL:
2288
- Operators = {{ISD::FMUL, 2 }};
2294
+ ScalarOps = {{ISD::FMUL, 2 }};
2295
+ IsReassociatable = false ;
2289
2296
break ;
2290
2297
case ISD::VECREDUCE_FMAX:
2291
2298
if (CanUseMinMax3)
2292
- Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2293
- Operators.push_back ({ISD::FMAXNUM, 2 });
2299
+ ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2300
+ ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2301
+ IsReassociatable = false ;
2294
2302
break ;
2295
2303
case ISD::VECREDUCE_FMIN:
2296
2304
if (CanUseMinMax3)
2297
- Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2298
- Operators.push_back ({ISD::FMINNUM, 2 });
2305
+ ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2306
+ ScalarOps.push_back ({ISD::FMINNUM, 2 });
2307
+ IsReassociatable = false ;
2299
2308
break ;
2300
2309
case ISD::VECREDUCE_FMAXIMUM:
2301
2310
if (CanUseMinMax3)
2302
- Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2303
- Operators.push_back ({ISD::FMAXIMUM, 2 });
2311
+ ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2312
+ ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2313
+ IsReassociatable = false ;
2304
2314
break ;
2305
2315
case ISD::VECREDUCE_FMINIMUM:
2306
2316
if (CanUseMinMax3)
2307
- Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2308
- Operators.push_back ({ISD::FMINIMUM, 2 });
2317
+ ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2318
+ ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2319
+ IsReassociatable = false ;
2320
+ break ;
2321
+ case ISD::VECREDUCE_ADD:
2322
+ ScalarOps = {{ISD::ADD, 2 }};
2323
+ IsReassociatable = true ;
2324
+ break ;
2325
+ case ISD::VECREDUCE_MUL:
2326
+ ScalarOps = {{ISD::MUL, 2 }};
2327
+ IsReassociatable = true ;
2328
+ break ;
2329
+ case ISD::VECREDUCE_UMAX:
2330
+ ScalarOps = {{ISD::UMAX, 2 }};
2331
+ IsReassociatable = true ;
2332
+ break ;
2333
+ case ISD::VECREDUCE_UMIN:
2334
+ ScalarOps = {{ISD::UMIN, 2 }};
2335
+ IsReassociatable = true ;
2336
+ break ;
2337
+ case ISD::VECREDUCE_SMAX:
2338
+ ScalarOps = {{ISD::SMAX, 2 }};
2339
+ IsReassociatable = true ;
2340
+ break ;
2341
+ case ISD::VECREDUCE_SMIN:
2342
+ ScalarOps = {{ISD::SMIN, 2 }};
2343
+ IsReassociatable = true ;
2344
+ break ;
2345
+ case ISD::VECREDUCE_AND:
2346
+ ScalarOps = {{ISD::AND, 2 }};
2347
+ IsReassociatable = true ;
2348
+ break ;
2349
+ case ISD::VECREDUCE_OR:
2350
+ ScalarOps = {{ISD::OR, 2 }};
2351
+ IsReassociatable = true ;
2352
+ break ;
2353
+ case ISD::VECREDUCE_XOR:
2354
+ ScalarOps = {{ISD::XOR, 2 }};
2355
+ IsReassociatable = true ;
2309
2356
break ;
2310
2357
default :
2311
2358
llvm_unreachable (" unhandled vecreduce operation" );
2312
2359
}
2313
2360
2314
- return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2315
- DAG);
2361
+ EVT VectorTy = Vector.getValueType ();
2362
+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2363
+
2364
+ // scalarize vector
2365
+ SmallVector<SDValue> Elements (NumElts);
2366
+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2367
+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2368
+ DAG.getConstant (I, DL, MVT::i64 ));
2369
+ }
2370
+
2371
+ // Lower to tree reduction.
2372
+ if (IsReassociatable || Flags.hasAllowReassociation ())
2373
+ return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2374
+
2375
+ // Lower to sequential reduction.
2376
+ SDValue Accumulator;
2377
+ for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2378
+ assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2379
+ const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2380
+
2381
+ if (!Accumulator) {
2382
+ if (I + DefaultGroupSize <= NumElts) {
2383
+ Accumulator = DAG.getNode (
2384
+ DefaultScalarOp, DL, EltTy,
2385
+ ArrayRef (Elements).slice (I, I + DefaultGroupSize), Flags);
2386
+ I += DefaultGroupSize;
2387
+ }
2388
+ }
2389
+
2390
+ if (Accumulator) {
2391
+ for (; I + (DefaultGroupSize - 1 ) <= NumElts; I += DefaultGroupSize - 1 ) {
2392
+ SmallVector<SDValue> Operands = {Accumulator};
2393
+ for (unsigned K = 0 ; K < DefaultGroupSize - 1 ; ++K)
2394
+ Operands.push_back (Elements[I + K]);
2395
+ Accumulator = DAG.getNode (DefaultScalarOp, DL, EltTy, Operands, Flags);
2396
+ }
2397
+ }
2398
+ }
2399
+
2400
+ return Accumulator;
2316
2401
}
2317
2402
2318
2403
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
@@ -3153,6 +3238,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3153
3238
case ISD::VECREDUCE_FMIN:
3154
3239
case ISD::VECREDUCE_FMAXIMUM:
3155
3240
case ISD::VECREDUCE_FMINIMUM:
3241
+ case ISD::VECREDUCE_ADD:
3242
+ case ISD::VECREDUCE_MUL:
3243
+ case ISD::VECREDUCE_UMAX:
3244
+ case ISD::VECREDUCE_UMIN:
3245
+ case ISD::VECREDUCE_SMAX:
3246
+ case ISD::VECREDUCE_SMIN:
3247
+ case ISD::VECREDUCE_AND:
3248
+ case ISD::VECREDUCE_OR:
3249
+ case ISD::VECREDUCE_XOR:
3156
3250
return LowerVECREDUCE (Op, DAG);
3157
3251
case ISD::STORE:
3158
3252
return LowerSTORE (Op, DAG);
0 commit comments