@@ -841,12 +841,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
841
841
setTargetDAGCombine (ISD::SETCC);
842
842
843
843
// Vector reduction operations. These are transformed into a tree evaluation
844
- // of nodes which may or may not be legal .
844
+ // of nodes which may initially be illegal .
845
845
for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
846
- setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
847
- ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
848
- ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
849
- VT, Custom);
846
+ MVT EltVT = VT.getVectorElementType ();
847
+ if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
848
+ EltVT == MVT::f64 ) {
849
+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
850
+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
851
+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
852
+ VT, Custom);
853
+ } else if (EltVT.isScalarInteger ()) {
854
+ setOperationAction (
855
+ {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
856
+ ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
857
+ ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
858
+ VT, Custom);
859
+ }
850
860
}
851
861
852
862
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2166,29 +2176,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2166
2176
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2167
2177
}
2168
2178
2169
- // / A generic routine for constructing a tree reduction for a vector operand.
2179
+ // / A generic routine for constructing a tree reduction on a vector operand.
2170
2180
// / This method differs from iterative splitting in DAGTypeLegalizer by
2171
- // / first scalarizing the vector and then progressively grouping elements
2172
- // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2173
- // / with different numbers of operands (eg. max3 vs max2).
2181
+ // / progressively grouping elements bottom-up.
2174
2182
static SDValue BuildTreeReduction (
2175
- const SDValue &VectorOp ,
2183
+ const SmallVector< SDValue> &Elements, EVT EltTy ,
2176
2184
ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2177
2185
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2178
- EVT VectorTy = VectorOp.getValueType ();
2179
- EVT EltTy = VectorTy.getVectorElementType ();
2180
- const unsigned NumElts = VectorTy.getVectorNumElements ();
2181
-
2182
- // scalarize vector
2183
- SmallVector<SDValue> Elements (NumElts);
2184
- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2185
- Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2186
- DAG.getConstant (I, DL, MVT::i64 ));
2187
- }
2188
-
2189
2186
// now build the computation graph in place at each level
2190
2187
SmallVector<SDValue> Level = Elements;
2191
- for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2188
+ unsigned OpIdx = 0 ;
2189
+ while (Level.size () > 1 ) {
2192
2190
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2193
2191
2194
2192
// partially reduce all elements in level
@@ -2220,52 +2218,139 @@ static SDValue BuildTreeReduction(
2220
2218
return *Level.begin ();
2221
2219
}
2222
2220
2223
- // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2224
- // / serializes it.
2221
+ // / Lower reductions to either a sequence of operations or a tree if
2222
+ // / reassociations are allowed. This method will use larger operations like
2223
+ // / max3/min3 when the target supports them.
2225
2224
SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2226
2225
SelectionDAG &DAG) const {
2227
- // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2228
- if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2226
+ if (DisableFOpTreeReduce)
2229
2227
return SDValue ();
2230
2228
2231
- EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2229
+ SDLoc DL (Op);
2230
+ const SDNodeFlags Flags = Op->getFlags ();
2231
+ const SDValue &Vector = Op.getOperand (0 );
2232
+ EVT EltTy = Vector.getValueType ().getVectorElementType ();
2232
2233
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2233
2234
STI.getPTXVersion () >= 88 ;
2234
- SDLoc DL (Op);
2235
- SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2235
+
2236
+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2237
+ // number of inputs they take.
2238
+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2239
+ bool IsReassociatable;
2240
+
2236
2241
switch (Op->getOpcode ()) {
2237
2242
case ISD::VECREDUCE_FADD:
2238
- Operators = {{ISD::FADD, 2 }};
2243
+ ScalarOps = {{ISD::FADD, 2 }};
2244
+ IsReassociatable = false ;
2239
2245
break ;
2240
2246
case ISD::VECREDUCE_FMUL:
2241
- Operators = {{ISD::FMUL, 2 }};
2247
+ ScalarOps = {{ISD::FMUL, 2 }};
2248
+ IsReassociatable = false ;
2242
2249
break ;
2243
2250
case ISD::VECREDUCE_FMAX:
2244
2251
if (CanUseMinMax3)
2245
- Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2246
- Operators.push_back ({ISD::FMAXNUM, 2 });
2252
+ ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2253
+ ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2254
+ IsReassociatable = false ;
2247
2255
break ;
2248
2256
case ISD::VECREDUCE_FMIN:
2249
2257
if (CanUseMinMax3)
2250
- Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2251
- Operators.push_back ({ISD::FMINNUM, 2 });
2258
+ ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2259
+ ScalarOps.push_back ({ISD::FMINNUM, 2 });
2260
+ IsReassociatable = false ;
2252
2261
break ;
2253
2262
case ISD::VECREDUCE_FMAXIMUM:
2254
2263
if (CanUseMinMax3)
2255
- Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2256
- Operators.push_back ({ISD::FMAXIMUM, 2 });
2264
+ ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2265
+ ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2266
+ IsReassociatable = false ;
2257
2267
break ;
2258
2268
case ISD::VECREDUCE_FMINIMUM:
2259
2269
if (CanUseMinMax3)
2260
- Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2261
- Operators.push_back ({ISD::FMINIMUM, 2 });
2270
+ ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2271
+ ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2272
+ IsReassociatable = false ;
2273
+ break ;
2274
+ case ISD::VECREDUCE_ADD:
2275
+ ScalarOps = {{ISD::ADD, 2 }};
2276
+ IsReassociatable = true ;
2277
+ break ;
2278
+ case ISD::VECREDUCE_MUL:
2279
+ ScalarOps = {{ISD::MUL, 2 }};
2280
+ IsReassociatable = true ;
2281
+ break ;
2282
+ case ISD::VECREDUCE_UMAX:
2283
+ ScalarOps = {{ISD::UMAX, 2 }};
2284
+ IsReassociatable = true ;
2285
+ break ;
2286
+ case ISD::VECREDUCE_UMIN:
2287
+ ScalarOps = {{ISD::UMIN, 2 }};
2288
+ IsReassociatable = true ;
2289
+ break ;
2290
+ case ISD::VECREDUCE_SMAX:
2291
+ ScalarOps = {{ISD::SMAX, 2 }};
2292
+ IsReassociatable = true ;
2293
+ break ;
2294
+ case ISD::VECREDUCE_SMIN:
2295
+ ScalarOps = {{ISD::SMIN, 2 }};
2296
+ IsReassociatable = true ;
2297
+ break ;
2298
+ case ISD::VECREDUCE_AND:
2299
+ ScalarOps = {{ISD::AND, 2 }};
2300
+ IsReassociatable = true ;
2301
+ break ;
2302
+ case ISD::VECREDUCE_OR:
2303
+ ScalarOps = {{ISD::OR, 2 }};
2304
+ IsReassociatable = true ;
2305
+ break ;
2306
+ case ISD::VECREDUCE_XOR:
2307
+ ScalarOps = {{ISD::XOR, 2 }};
2308
+ IsReassociatable = true ;
2262
2309
break ;
2263
2310
default :
2264
2311
llvm_unreachable (" unhandled vecreduce operation" );
2265
2312
}
2266
2313
2267
- return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2268
- DAG);
2314
+ EVT VectorTy = Vector.getValueType ();
2315
+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2316
+
2317
+ // scalarize vector
2318
+ SmallVector<SDValue> Elements (NumElts);
2319
+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2320
+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2321
+ DAG.getConstant (I, DL, MVT::i64 ));
2322
+ }
2323
+
2324
+ // Lower to tree reduction.
2325
+ if (IsReassociatable || Flags.hasAllowReassociation ())
2326
+ return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2327
+
2328
+ // Lower to sequential reduction.
2329
+ SDValue Accumulator;
2330
+ for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2331
+ assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2332
+ const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2333
+
2334
+ if (!Accumulator) {
2335
+ if (I + DefaultGroupSize <= NumElts) {
2336
+ Accumulator = DAG.getNode (
2337
+ DefaultScalarOp, DL, EltTy,
2338
+ ArrayRef (Elements).slice (I, I + DefaultGroupSize), Flags);
2339
+ I += DefaultGroupSize;
2340
+ }
2341
+ }
2342
+
2343
+ if (Accumulator) {
2344
+ for (; I + (DefaultGroupSize - 1 ) <= NumElts; I += DefaultGroupSize - 1 ) {
2345
+ SmallVector<SDValue> Operands = {Accumulator};
2346
+ for (unsigned K = 0 ; K < DefaultGroupSize - 1 ; ++K)
2347
+ Operands.push_back (Elements[I + K]);
2348
+ Accumulator = DAG.getNode (DefaultScalarOp, DL, EltTy, Operands, Flags);
2349
+ }
2350
+ }
2351
+ }
2352
+
2353
+ return Accumulator;
2269
2354
}
2270
2355
2271
2356
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
@@ -3062,6 +3147,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3062
3147
case ISD::VECREDUCE_FMIN:
3063
3148
case ISD::VECREDUCE_FMAXIMUM:
3064
3149
case ISD::VECREDUCE_FMINIMUM:
3150
+ case ISD::VECREDUCE_ADD:
3151
+ case ISD::VECREDUCE_MUL:
3152
+ case ISD::VECREDUCE_UMAX:
3153
+ case ISD::VECREDUCE_UMIN:
3154
+ case ISD::VECREDUCE_SMAX:
3155
+ case ISD::VECREDUCE_SMIN:
3156
+ case ISD::VECREDUCE_AND:
3157
+ case ISD::VECREDUCE_OR:
3158
+ case ISD::VECREDUCE_XOR:
3065
3159
return LowerVECREDUCE (Op, DAG);
3066
3160
case ISD::STORE:
3067
3161
return LowerSTORE (Op, DAG);
0 commit comments