@@ -835,12 +835,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
835
835
setTargetDAGCombine (ISD::SETCC);
836
836
837
837
// Vector reduction operations. These are transformed into a tree evaluation
838
- // of nodes which may or may not be legal .
838
+ // of nodes which may initially be illegal .
839
839
for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
840
- setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
841
- ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
842
- ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
843
- VT, Custom);
840
+ MVT EltVT = VT.getVectorElementType ();
841
+ if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
842
+ EltVT == MVT::f64 ) {
843
+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
844
+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
845
+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
846
+ VT, Custom);
847
+ } else if (EltVT.isScalarInteger ()) {
848
+ setOperationAction (
849
+ {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
850
+ ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
851
+ ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
852
+ VT, Custom);
853
+ }
844
854
}
845
855
846
856
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2160,29 +2170,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2160
2170
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2161
2171
}
2162
2172
2163
- // / A generic routine for constructing a tree reduction for a vector operand.
2173
+ // / A generic routine for constructing a tree reduction on a vector operand.
2164
2174
// / This method differs from iterative splitting in DAGTypeLegalizer by
2165
- // / first scalarizing the vector and then progressively grouping elements
2166
- // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2167
- // / with different numbers of operands (eg. max3 vs max2).
2175
+ // / progressively grouping elements bottom-up.
2168
2176
static SDValue BuildTreeReduction (
2169
- const SDValue &VectorOp ,
2177
+ const SmallVector< SDValue> &Elements, EVT EltTy ,
2170
2178
ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2171
2179
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2172
- EVT VectorTy = VectorOp.getValueType ();
2173
- EVT EltTy = VectorTy.getVectorElementType ();
2174
- const unsigned NumElts = VectorTy.getVectorNumElements ();
2175
-
2176
- // scalarize vector
2177
- SmallVector<SDValue> Elements (NumElts);
2178
- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2179
- Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2180
- DAG.getConstant (I, DL, MVT::i64 ));
2181
- }
2182
-
2183
2180
// now build the computation graph in place at each level
2184
2181
SmallVector<SDValue> Level = Elements;
2185
- for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2182
+ unsigned OpIdx = 0 ;
2183
+ while (Level.size () > 1 ) {
2186
2184
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2187
2185
2188
2186
// partially reduce all elements in level
@@ -2214,52 +2212,139 @@ static SDValue BuildTreeReduction(
2214
2212
return *Level.begin ();
2215
2213
}
2216
2214
2217
- // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2218
- // / serializes it.
2215
+ // / Lower reductions to either a sequence of operations or a tree if
2216
+ // / reassociations are allowed. This method will use larger operations like
2217
+ // / max3/min3 when the target supports them.
2219
2218
SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2220
2219
SelectionDAG &DAG) const {
2221
- // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2222
- if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2220
+ if (DisableFOpTreeReduce)
2223
2221
return SDValue ();
2224
2222
2225
- EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2223
+ SDLoc DL (Op);
2224
+ const SDNodeFlags Flags = Op->getFlags ();
2225
+ const SDValue &Vector = Op.getOperand (0 );
2226
+ EVT EltTy = Vector.getValueType ().getVectorElementType ();
2226
2227
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2227
2228
STI.getPTXVersion () >= 88 ;
2228
- SDLoc DL (Op);
2229
- SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2229
+
2230
+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2231
+ // number of inputs they take.
2232
+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2233
+ bool IsReassociatable;
2234
+
2230
2235
switch (Op->getOpcode ()) {
2231
2236
case ISD::VECREDUCE_FADD:
2232
- Operators = {{ISD::FADD, 2 }};
2237
+ ScalarOps = {{ISD::FADD, 2 }};
2238
+ IsReassociatable = false ;
2233
2239
break ;
2234
2240
case ISD::VECREDUCE_FMUL:
2235
- Operators = {{ISD::FMUL, 2 }};
2241
+ ScalarOps = {{ISD::FMUL, 2 }};
2242
+ IsReassociatable = false ;
2236
2243
break ;
2237
2244
case ISD::VECREDUCE_FMAX:
2238
2245
if (CanUseMinMax3)
2239
- Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2240
- Operators.push_back ({ISD::FMAXNUM, 2 });
2246
+ ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2247
+ ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2248
+ IsReassociatable = false ;
2241
2249
break ;
2242
2250
case ISD::VECREDUCE_FMIN:
2243
2251
if (CanUseMinMax3)
2244
- Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2245
- Operators.push_back ({ISD::FMINNUM, 2 });
2252
+ ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2253
+ ScalarOps.push_back ({ISD::FMINNUM, 2 });
2254
+ IsReassociatable = false ;
2246
2255
break ;
2247
2256
case ISD::VECREDUCE_FMAXIMUM:
2248
2257
if (CanUseMinMax3)
2249
- Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2250
- Operators.push_back ({ISD::FMAXIMUM, 2 });
2258
+ ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2259
+ ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2260
+ IsReassociatable = false ;
2251
2261
break ;
2252
2262
case ISD::VECREDUCE_FMINIMUM:
2253
2263
if (CanUseMinMax3)
2254
- Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2255
- Operators.push_back ({ISD::FMINIMUM, 2 });
2264
+ ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2265
+ ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2266
+ IsReassociatable = false ;
2267
+ break ;
2268
+ case ISD::VECREDUCE_ADD:
2269
+ ScalarOps = {{ISD::ADD, 2 }};
2270
+ IsReassociatable = true ;
2271
+ break ;
2272
+ case ISD::VECREDUCE_MUL:
2273
+ ScalarOps = {{ISD::MUL, 2 }};
2274
+ IsReassociatable = true ;
2275
+ break ;
2276
+ case ISD::VECREDUCE_UMAX:
2277
+ ScalarOps = {{ISD::UMAX, 2 }};
2278
+ IsReassociatable = true ;
2279
+ break ;
2280
+ case ISD::VECREDUCE_UMIN:
2281
+ ScalarOps = {{ISD::UMIN, 2 }};
2282
+ IsReassociatable = true ;
2283
+ break ;
2284
+ case ISD::VECREDUCE_SMAX:
2285
+ ScalarOps = {{ISD::SMAX, 2 }};
2286
+ IsReassociatable = true ;
2287
+ break ;
2288
+ case ISD::VECREDUCE_SMIN:
2289
+ ScalarOps = {{ISD::SMIN, 2 }};
2290
+ IsReassociatable = true ;
2291
+ break ;
2292
+ case ISD::VECREDUCE_AND:
2293
+ ScalarOps = {{ISD::AND, 2 }};
2294
+ IsReassociatable = true ;
2295
+ break ;
2296
+ case ISD::VECREDUCE_OR:
2297
+ ScalarOps = {{ISD::OR, 2 }};
2298
+ IsReassociatable = true ;
2299
+ break ;
2300
+ case ISD::VECREDUCE_XOR:
2301
+ ScalarOps = {{ISD::XOR, 2 }};
2302
+ IsReassociatable = true ;
2256
2303
break ;
2257
2304
default :
2258
2305
llvm_unreachable (" unhandled vecreduce operation" );
2259
2306
}
2260
2307
2261
- return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2262
- DAG);
2308
+ EVT VectorTy = Vector.getValueType ();
2309
+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2310
+
2311
+ // scalarize vector
2312
+ SmallVector<SDValue> Elements (NumElts);
2313
+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2314
+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2315
+ DAG.getConstant (I, DL, MVT::i64 ));
2316
+ }
2317
+
2318
+ // Lower to tree reduction.
2319
+ if (IsReassociatable || Flags.hasAllowReassociation ())
2320
+ return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2321
+
2322
+ // Lower to sequential reduction.
2323
+ SDValue Accumulator;
2324
+ for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2325
+ assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2326
+ const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2327
+
2328
+ if (!Accumulator) {
2329
+ if (I + DefaultGroupSize <= NumElts) {
2330
+ Accumulator = DAG.getNode (
2331
+ DefaultScalarOp, DL, EltTy,
2332
+ ArrayRef (Elements).slice (I, I + DefaultGroupSize), Flags);
2333
+ I += DefaultGroupSize;
2334
+ }
2335
+ }
2336
+
2337
+ if (Accumulator) {
2338
+ for (; I + (DefaultGroupSize - 1 ) <= NumElts; I += DefaultGroupSize - 1 ) {
2339
+ SmallVector<SDValue> Operands = {Accumulator};
2340
+ for (unsigned K = 0 ; K < DefaultGroupSize - 1 ; ++K)
2341
+ Operands.push_back (Elements[I + K]);
2342
+ Accumulator = DAG.getNode (DefaultScalarOp, DL, EltTy, Operands, Flags);
2343
+ }
2344
+ }
2345
+ }
2346
+
2347
+ return Accumulator;
2263
2348
}
2264
2349
2265
2350
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
@@ -3056,6 +3141,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3056
3141
case ISD::VECREDUCE_FMIN:
3057
3142
case ISD::VECREDUCE_FMAXIMUM:
3058
3143
case ISD::VECREDUCE_FMINIMUM:
3144
+ case ISD::VECREDUCE_ADD:
3145
+ case ISD::VECREDUCE_MUL:
3146
+ case ISD::VECREDUCE_UMAX:
3147
+ case ISD::VECREDUCE_UMIN:
3148
+ case ISD::VECREDUCE_SMAX:
3149
+ case ISD::VECREDUCE_SMIN:
3150
+ case ISD::VECREDUCE_AND:
3151
+ case ISD::VECREDUCE_OR:
3152
+ case ISD::VECREDUCE_XOR:
3059
3153
return LowerVECREDUCE (Op, DAG);
3060
3154
case ISD::STORE:
3061
3155
return LowerSTORE (Op, DAG);
0 commit comments