@@ -835,12 +835,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
835
835
setTargetDAGCombine (ISD::SETCC);
836
836
837
837
// Vector reduction operations. These are transformed into a tree evaluation
838
- // of nodes which may or may not be legal .
838
+ // of nodes which may initially be illegal .
839
839
for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
840
- setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
841
- ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
842
- ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
843
- VT, Custom);
840
+ MVT EltVT = VT.getVectorElementType ();
841
+ if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
842
+ EltVT == MVT::f64 ) {
843
+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
844
+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
845
+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
846
+ VT, Custom);
847
+ } else if (EltVT.isScalarInteger ()) {
848
+ setOperationAction (
849
+ {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
850
+ ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
851
+ ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
852
+ VT, Custom);
853
+ }
844
854
}
845
855
846
856
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2147,29 +2157,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2147
2157
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2148
2158
}
2149
2159
2150
- // / A generic routine for constructing a tree reduction for a vector operand.
2160
+ // / A generic routine for constructing a tree reduction on a vector operand.
2151
2161
// / This method differs from iterative splitting in DAGTypeLegalizer by
2152
- // / first scalarizing the vector and then progressively grouping elements
2153
- // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2154
- // / with different numbers of operands (eg. max3 vs max2).
2162
+ // / progressively grouping elements bottom-up.
2155
2163
static SDValue BuildTreeReduction (
2156
- const SDValue &VectorOp ,
2164
+ const SmallVector< SDValue> &Elements, EVT EltTy ,
2157
2165
ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2158
2166
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2159
- EVT VectorTy = VectorOp.getValueType ();
2160
- EVT EltTy = VectorTy.getVectorElementType ();
2161
- const unsigned NumElts = VectorTy.getVectorNumElements ();
2162
-
2163
- // scalarize vector
2164
- SmallVector<SDValue> Elements (NumElts);
2165
- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2166
- Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2167
- DAG.getConstant (I, DL, MVT::i64 ));
2168
- }
2169
-
2170
2167
// now build the computation graph in place at each level
2171
2168
SmallVector<SDValue> Level = Elements;
2172
- for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2169
+ unsigned OpIdx = 0 ;
2170
+ while (Level.size () > 1 ) {
2173
2171
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2174
2172
2175
2173
// partially reduce all elements in level
@@ -2201,52 +2199,139 @@ static SDValue BuildTreeReduction(
2201
2199
return *Level.begin ();
2202
2200
}
2203
2201
2204
- // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2205
- // / serializes it.
2202
+ // / Lower reductions to either a sequence of operations or a tree if
2203
+ // / reassociations are allowed. This method will use larger operations like
2204
+ // / max3/min3 when the target supports them.
2206
2205
SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2207
2206
SelectionDAG &DAG) const {
2208
- // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2209
- if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2207
+ if (DisableFOpTreeReduce)
2210
2208
return SDValue ();
2211
2209
2212
- EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2210
+ SDLoc DL (Op);
2211
+ const SDNodeFlags Flags = Op->getFlags ();
2212
+ const SDValue &Vector = Op.getOperand (0 );
2213
+ EVT EltTy = Vector.getValueType ().getVectorElementType ();
2213
2214
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2214
2215
STI.getPTXVersion () >= 88 ;
2215
- SDLoc DL (Op);
2216
- SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2216
+
2217
+ // A list of SDNode opcodes with equivalent semantics, sorted descending by
2218
+ // number of inputs they take.
2219
+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > ScalarOps;
2220
+ bool IsReassociatable;
2221
+
2217
2222
switch (Op->getOpcode ()) {
2218
2223
case ISD::VECREDUCE_FADD:
2219
- Operators = {{ISD::FADD, 2 }};
2224
+ ScalarOps = {{ISD::FADD, 2 }};
2225
+ IsReassociatable = false ;
2220
2226
break ;
2221
2227
case ISD::VECREDUCE_FMUL:
2222
- Operators = {{ISD::FMUL, 2 }};
2228
+ ScalarOps = {{ISD::FMUL, 2 }};
2229
+ IsReassociatable = false ;
2223
2230
break ;
2224
2231
case ISD::VECREDUCE_FMAX:
2225
2232
if (CanUseMinMax3)
2226
- Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2227
- Operators.push_back ({ISD::FMAXNUM, 2 });
2233
+ ScalarOps.push_back ({NVPTXISD::FMAXNUM3, 3 });
2234
+ ScalarOps.push_back ({ISD::FMAXNUM, 2 });
2235
+ IsReassociatable = false ;
2228
2236
break ;
2229
2237
case ISD::VECREDUCE_FMIN:
2230
2238
if (CanUseMinMax3)
2231
- Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2232
- Operators.push_back ({ISD::FMINNUM, 2 });
2239
+ ScalarOps.push_back ({NVPTXISD::FMINNUM3, 3 });
2240
+ ScalarOps.push_back ({ISD::FMINNUM, 2 });
2241
+ IsReassociatable = false ;
2233
2242
break ;
2234
2243
case ISD::VECREDUCE_FMAXIMUM:
2235
2244
if (CanUseMinMax3)
2236
- Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2237
- Operators.push_back ({ISD::FMAXIMUM, 2 });
2245
+ ScalarOps.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2246
+ ScalarOps.push_back ({ISD::FMAXIMUM, 2 });
2247
+ IsReassociatable = false ;
2238
2248
break ;
2239
2249
case ISD::VECREDUCE_FMINIMUM:
2240
2250
if (CanUseMinMax3)
2241
- Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2242
- Operators.push_back ({ISD::FMINIMUM, 2 });
2251
+ ScalarOps.push_back ({NVPTXISD::FMINIMUM3, 3 });
2252
+ ScalarOps.push_back ({ISD::FMINIMUM, 2 });
2253
+ IsReassociatable = false ;
2254
+ break ;
2255
+ case ISD::VECREDUCE_ADD:
2256
+ ScalarOps = {{ISD::ADD, 2 }};
2257
+ IsReassociatable = true ;
2258
+ break ;
2259
+ case ISD::VECREDUCE_MUL:
2260
+ ScalarOps = {{ISD::MUL, 2 }};
2261
+ IsReassociatable = true ;
2262
+ break ;
2263
+ case ISD::VECREDUCE_UMAX:
2264
+ ScalarOps = {{ISD::UMAX, 2 }};
2265
+ IsReassociatable = true ;
2266
+ break ;
2267
+ case ISD::VECREDUCE_UMIN:
2268
+ ScalarOps = {{ISD::UMIN, 2 }};
2269
+ IsReassociatable = true ;
2270
+ break ;
2271
+ case ISD::VECREDUCE_SMAX:
2272
+ ScalarOps = {{ISD::SMAX, 2 }};
2273
+ IsReassociatable = true ;
2274
+ break ;
2275
+ case ISD::VECREDUCE_SMIN:
2276
+ ScalarOps = {{ISD::SMIN, 2 }};
2277
+ IsReassociatable = true ;
2278
+ break ;
2279
+ case ISD::VECREDUCE_AND:
2280
+ ScalarOps = {{ISD::AND, 2 }};
2281
+ IsReassociatable = true ;
2282
+ break ;
2283
+ case ISD::VECREDUCE_OR:
2284
+ ScalarOps = {{ISD::OR, 2 }};
2285
+ IsReassociatable = true ;
2286
+ break ;
2287
+ case ISD::VECREDUCE_XOR:
2288
+ ScalarOps = {{ISD::XOR, 2 }};
2289
+ IsReassociatable = true ;
2243
2290
break ;
2244
2291
default :
2245
2292
llvm_unreachable (" unhandled vecreduce operation" );
2246
2293
}
2247
2294
2248
- return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2249
- DAG);
2295
+ EVT VectorTy = Vector.getValueType ();
2296
+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2297
+
2298
+ // scalarize vector
2299
+ SmallVector<SDValue> Elements (NumElts);
2300
+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2301
+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2302
+ DAG.getConstant (I, DL, MVT::i64 ));
2303
+ }
2304
+
2305
+ // Lower to tree reduction.
2306
+ if (IsReassociatable || Flags.hasAllowReassociation ())
2307
+ return BuildTreeReduction (Elements, EltTy, ScalarOps, DL, Flags, DAG);
2308
+
2309
+ // Lower to sequential reduction.
2310
+ SDValue Accumulator;
2311
+ for (unsigned OpIdx = 0 , I = 0 ; I < NumElts; ++OpIdx) {
2312
+ assert (OpIdx < ScalarOps.size () && " no smaller operators for reduction" );
2313
+ const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2314
+
2315
+ if (!Accumulator) {
2316
+ if (I + DefaultGroupSize <= NumElts) {
2317
+ Accumulator = DAG.getNode (
2318
+ DefaultScalarOp, DL, EltTy,
2319
+ ArrayRef (Elements).slice (I, I + DefaultGroupSize), Flags);
2320
+ I += DefaultGroupSize;
2321
+ }
2322
+ }
2323
+
2324
+ if (Accumulator) {
2325
+ for (; I + (DefaultGroupSize - 1 ) <= NumElts; I += DefaultGroupSize - 1 ) {
2326
+ SmallVector<SDValue> Operands = {Accumulator};
2327
+ for (unsigned K = 0 ; K < DefaultGroupSize - 1 ; ++K)
2328
+ Operands.push_back (Elements[I + K]);
2329
+ Accumulator = DAG.getNode (DefaultScalarOp, DL, EltTy, Operands, Flags);
2330
+ }
2331
+ }
2332
+ }
2333
+
2334
+ return Accumulator;
2250
2335
}
2251
2336
2252
2337
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
@@ -3032,6 +3117,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
3032
3117
case ISD::VECREDUCE_FMIN:
3033
3118
case ISD::VECREDUCE_FMAXIMUM:
3034
3119
case ISD::VECREDUCE_FMINIMUM:
3120
+ case ISD::VECREDUCE_ADD:
3121
+ case ISD::VECREDUCE_MUL:
3122
+ case ISD::VECREDUCE_UMAX:
3123
+ case ISD::VECREDUCE_UMIN:
3124
+ case ISD::VECREDUCE_SMAX:
3125
+ case ISD::VECREDUCE_SMIN:
3126
+ case ISD::VECREDUCE_AND:
3127
+ case ISD::VECREDUCE_OR:
3128
+ case ISD::VECREDUCE_XOR:
3035
3129
return LowerVECREDUCE (Op, DAG);
3036
3130
case ISD::STORE:
3037
3131
return LowerSTORE (Op, DAG);
0 commit comments