@@ -838,12 +838,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
838
838
setTargetDAGCombine (ISD::SETCC);
839
839
840
840
// Vector reduction operations. These are transformed into a tree evaluation
841
- // of nodes which may or may not be legal .
841
+ // of nodes which may initially be illegal .
842
842
for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
843
- setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
844
- ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
845
- ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
846
- VT, Custom);
843
+ MVT EltVT = VT.getVectorElementType ();
844
+ if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
845
+ EltVT == MVT::f64 ) {
846
+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
847
+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
848
+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
849
+ VT, Custom);
850
+ } else if (EltVT.isScalarInteger ()) {
851
+ setOperationAction (
852
+ {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
853
+ ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
854
+ ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
855
+ VT, Custom);
856
+ }
847
857
}
848
858
849
859
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2155,29 +2165,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2155
2165
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2156
2166
}
2157
2167
2158
- // / A generic routine for constructing a tree reduction for a vector operand.
2168
+ // / A generic routine for constructing a tree reduction on a vector operand.
2159
2169
// / This method differs from iterative splitting in DAGTypeLegalizer by
2160
- // / first scalarizing the vector and then progressively grouping elements
2161
- // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2162
- // / with different numbers of operands (eg. max3 vs max2).
2170
+ // / progressively grouping elements bottom-up.
2163
2171
static SDValue BuildTreeReduction (
2164
- const SDValue &VectorOp ,
2172
+ const SmallVector< SDValue> &Elements, EVT EltTy ,
2165
2173
ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2166
2174
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2167
- EVT VectorTy = VectorOp.getValueType ();
2168
- EVT EltTy = VectorTy.getVectorElementType ();
2169
- const unsigned NumElts = VectorTy.getVectorNumElements ();
2170
-
2171
- // scalarize vector
2172
- SmallVector<SDValue> Elements (NumElts);
2173
- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2174
- Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2175
- DAG.getConstant (I, DL, MVT::i64 ));
2176
- }
2177
-
2178
2175
// now build the computation graph in place at each level
2179
2176
SmallVector<SDValue> Level = Elements;
2180
- for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2177
+ unsigned OpIdx = 0 ;
2178
+ while (Level.size () > 1 ) {
2181
2179
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2182
2180
2183
2181
// partially reduce all elements in level
@@ -2209,52 +2207,139 @@ static SDValue BuildTreeReduction(
2209
2207
return *Level.begin ();
2210
2208
}
2211
2209
2212
- // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2213
- // / serializes it.
2210
+ // / Lower reductions to either a sequence of operations or a tree if
2211
+ // / reassociations are allowed. This method will use larger operations like
2212
+ // / max3/min3 when the target supports them.
2214
2213
SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2215
2214
SelectionDAG &DAG) const {
2216
- // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2217
- if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2215
+ if (DisableFOpTreeReduce)
2218
2216
return SDValue ();
2219
2217
2220
- EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2218
+ SDLoc DL (Op);
2219
+ const SDNodeFlags Flags = Op->getFlags ();
2220
+ const SDValue &Vector = Op.getOperand (0 );
2221
+ EVT EltTy = Vector.getValueType ().getVectorElementType ();
2221
2222
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2222
2223
STI.getPTXVersion () >= 88 ;
2223
- SDLoc DL (Op);
2224
- SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2224
+
2225
+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2226
+ // number of inputs they take.
2227
+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2228
+ bool IsReassociatable;
2229
+
2225
2230
switch (Op->getOpcode ()) {
2226
2231
case ISD::VECREDUCE_FADD:
2227
- Operators = {{ISD::FADD, 2 }};
2232
+ ScalarOps = {{ISD::FADD, 2 }};
2233
+ IsReassociatable = false ;
2228
2234
break ;
2229
2235
case ISD::VECREDUCE_FMUL:
2230
- Operators = {{ISD::FMUL, 2 }};
2236
+ ScalarOps = {{ISD::FMUL, 2 }};
2237
+ IsReassociatable = false ;
2231
2238
break ;
2232
2239
case ISD::VECREDUCE_FMAX:
2233
2240
if (CanUseMinMax3)
2234
- Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2235
- Operators.push_back ({ISD::FMAXNUM, 2 });
2241
+ ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2242
+ ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2243
+ IsReassociatable = false ;
2236
2244
break ;
2237
2245
case ISD::VECREDUCE_FMIN:
2238
2246
if (CanUseMinMax3)
2239
- Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2240
- Operators.push_back ({ISD::FMINNUM, 2 });
2247
+ ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2248
+ ScalarOps.push_back ({ISD::FMINNUM, 2 });
2249
+ IsReassociatable = false ;
2241
2250
break ;
2242
2251
case ISD::VECREDUCE_FMAXIMUM:
2243
2252
if (CanUseMinMax3)
2244
- Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2245
- Operators.push_back ({ISD::FMAXIMUM, 2 });
2253
+ ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2254
+ ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2255
+ IsReassociatable = false ;
2246
2256
break ;
2247
2257
case ISD::VECREDUCE_FMINIMUM:
2248
2258
if (CanUseMinMax3)
2249
- Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2250
- Operators.push_back ({ISD::FMINIMUM, 2 });
2259
+ ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2260
+ ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2261
+ IsReassociatable = false ;
2262
+ break ;
2263
+ case ISD::VECREDUCE_ADD:
2264
+ ScalarOps = {{ISD::ADD, 2 }};
2265
+ IsReassociatable = true ;
2266
+ break ;
2267
+ case ISD::VECREDUCE_MUL:
2268
+ ScalarOps = {{ISD::MUL, 2 }};
2269
+ IsReassociatable = true ;
2270
+ break ;
2271
+ case ISD::VECREDUCE_UMAX:
2272
+ ScalarOps = {{ISD::UMAX, 2 }};
2273
+ IsReassociatable = true ;
2274
+ break ;
2275
+ case ISD::VECREDUCE_UMIN:
2276
+ ScalarOps = {{ISD::UMIN, 2 }};
2277
+ IsReassociatable = true ;
2278
+ break ;
2279
+ case ISD::VECREDUCE_SMAX:
2280
+ ScalarOps = {{ISD::SMAX, 2 }};
2281
+ IsReassociatable = true ;
2282
+ break ;
2283
+ case ISD::VECREDUCE_SMIN:
2284
+ ScalarOps = {{ISD::SMIN, 2 }};
2285
+ IsReassociatable = true ;
2286
+ break ;
2287
+ case ISD::VECREDUCE_AND:
2288
+ ScalarOps = {{ISD::AND, 2 }};
2289
+ IsReassociatable = true ;
2290
+ break ;
2291
+ case ISD::VECREDUCE_OR:
2292
+ ScalarOps = {{ISD::OR, 2 }};
2293
+ IsReassociatable = true ;
2294
+ break ;
2295
+ case ISD::VECREDUCE_XOR:
2296
+ ScalarOps = {{ISD::XOR, 2 }};
2297
+ IsReassociatable = true ;
2251
2298
break ;
2252
2299
default :
2253
2300
llvm_unreachable (" unhandled vecreduce operation" );
2254
2301
}
2255
2302
2256
- return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2257
- DAG);
2303
+ EVT VectorTy = Vector.getValueType ();
2304
+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2305
+
2306
+ // scalarize vector
2307
+ SmallVector<SDValue> Elements (NumElts);
2308
+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2309
+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2310
+ DAG.getConstant (I, DL, MVT::i64 ));
2311
+ }
2312
+
2313
+ // Lower to tree reduction.
2314
+ if (IsReassociatable || Flags.hasAllowReassociation ())
2315
+ return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2316
+
2317
+ // Lower to sequential reduction.
2318
+ SDValue Accumulator;
2319
+ for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2320
+ assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2321
+ const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2322
+
2323
+ if (!Accumulator) {
2324
+ if (I + DefaultGroupSize <= NumElts) {
2325
+ Accumulator = DAG.getNode (
2326
+ DefaultScalarOp, DL, EltTy,
2327
+ ArrayRef (Elements).slice (I, I + DefaultGroupSize), Flags);
2328
+ I += DefaultGroupSize;
2329
+ }
2330
+ }
2331
+
2332
+ if (Accumulator) {
2333
+ for (; I + (DefaultGroupSize - 1 ) <= NumElts; I += DefaultGroupSize - 1 ) {
2334
+ SmallVector<SDValue> Operands = {Accumulator};
2335
+ for (unsigned K = 0 ; K < DefaultGroupSize - 1 ; ++K)
2336
+ Operands.push_back (Elements[I + K]);
2337
+ Accumulator = DAG.getNode (DefaultScalarOp, DL, EltTy, Operands, Flags);
2338
+ }
2339
+ }
2340
+ }
2341
+
2342
+ return Accumulator;
2258
2343
}
2259
2344
2260
2345
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
@@ -3006,6 +3091,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3006
3091
case ISD::VECREDUCE_FMIN:
3007
3092
case ISD::VECREDUCE_FMAXIMUM:
3008
3093
case ISD::VECREDUCE_FMINIMUM:
3094
+ case ISD::VECREDUCE_ADD:
3095
+ case ISD::VECREDUCE_MUL:
3096
+ case ISD::VECREDUCE_UMAX:
3097
+ case ISD::VECREDUCE_UMIN:
3098
+ case ISD::VECREDUCE_SMAX:
3099
+ case ISD::VECREDUCE_SMIN:
3100
+ case ISD::VECREDUCE_AND:
3101
+ case ISD::VECREDUCE_OR:
3102
+ case ISD::VECREDUCE_XOR:
3009
3103
return LowerVECREDUCE (Op, DAG);
3010
3104
case ISD::STORE:
3011
3105
return LowerSTORE (Op, DAG);
0 commit comments