@@ -85,6 +85,12 @@ static cl::opt<unsigned> FMAContractLevelOpt(
85
85
" 1: do it 2: do it aggressively" ),
86
86
cl::init(2 ));
87
87
88
// Hidden command-line switch, false by default. When it is set,
// NVPTXTargetLowering::LowerVECREDUCE returns an empty SDValue instead of
// building a tree reduction for the floating-point VECREDUCE_* nodes.
+ static cl::opt<bool > DisableFOpTreeReduce (
89
+ " nvptx-disable-fop-tree-reduce" , cl::Hidden,
90
+ cl::desc (" NVPTX Specific: don't emit tree reduction for floating-point "
91
+ " reduction operations" ),
92
+ cl::init(false ));
93
+
88
94
static cl::opt<int > UsePrecDivF32 (
89
95
" nvptx-prec-divf32" , cl::Hidden,
90
96
cl::desc (" NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -828,6 +834,15 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
828
834
if (STI.allowFP16Math () || STI.hasBF16Math ())
829
835
setTargetDAGCombine (ISD::SETCC);
830
836
837
+ // Vector reduction operations. These are transformed into a tree evaluation
838
+ // of nodes which may or may not be legal.
839
+ for (MVT VT : MVT::fixedlen_vector_valuetypes ()) {
840
+ setOperationAction ({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
841
+ ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
842
+ ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
843
+ VT, Custom);
844
+ }
845
+
831
846
// Promote fp16 arithmetic if fp16 hardware isn't available or the
832
847
// user passed --nvptx-no-fp16-math. The flag is useful because,
833
848
// although sm_53+ GPUs have some sort of FP16 support in
@@ -1081,6 +1096,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
1081
1096
MAKE_CASE (NVPTXISD::BFI)
1082
1097
MAKE_CASE (NVPTXISD::PRMT)
1083
1098
MAKE_CASE (NVPTXISD::FCOPYSIGN)
1099
+ MAKE_CASE (NVPTXISD::FMAXNUM3)
1100
+ MAKE_CASE (NVPTXISD::FMINNUM3)
1101
+ MAKE_CASE (NVPTXISD::FMAXIMUM3)
1102
+ MAKE_CASE (NVPTXISD::FMINIMUM3)
1084
1103
MAKE_CASE (NVPTXISD::DYNAMIC_STACKALLOC)
1085
1104
MAKE_CASE (NVPTXISD::STACKRESTORE)
1086
1105
MAKE_CASE (NVPTXISD::STACKSAVE)
@@ -2141,6 +2160,108 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
2141
2160
return DAG.getBuildVector (Node->getValueType (0 ), dl, Ops);
2142
2161
}
2143
2162
2163
+ // / A generic routine for constructing a tree reduction for a vector operand.
2164
+ // / This method differs from iterative splitting in DAGTypeLegalizer by
2165
+ // / first scalarizing the vector and then progressively grouping elements
2166
+ // / bottom-up. This allows easily building the optimal (minimum) number of nodes
2167
+ // / with different numbers of operands (eg. max3 vs max2).
2168
+ static SDValue BuildTreeReduction (
2169
+ const SDValue &VectorOp,
2170
+ ArrayRef<std::pair<unsigned /* NodeType*/ , unsigned /* NumInputs*/ >> Ops,
2171
+ const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2172
+ EVT VectorTy = VectorOp.getValueType ();
2173
+ EVT EltTy = VectorTy.getVectorElementType ();
2174
+ const unsigned NumElts = VectorTy.getVectorNumElements ();
2175
+
2176
+ // scalarize vector
2177
+ SmallVector<SDValue> Elements (NumElts);
2178
+ for (unsigned I = 0 , E = NumElts; I != E; ++I) {
2179
+ Elements[I] = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2180
+ DAG.getConstant (I, DL, MVT::i64 ));
2181
+ }
2182
+
2183
+ // now build the computation graph in place at each level
2184
+ SmallVector<SDValue> Level = Elements;
2185
+ for (unsigned OpIdx = 0 ; Level.size () > 1 && OpIdx < Ops.size ();) {
2186
+ const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
2187
+
2188
+ // partially reduce all elements in level
2189
+ SmallVector<SDValue> ReducedLevel;
2190
+ unsigned I = 0 , E = Level.size ();
2191
+ for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
2192
+ // Reduce elements in groups of [DefaultGroupSize], as much as possible.
2193
+ ReducedLevel.push_back (DAG.getNode (
2194
+ DefaultScalarOp, DL, EltTy,
2195
+ ArrayRef<SDValue>(Level).slice (I, DefaultGroupSize), Flags));
2196
+ }
2197
+
2198
+ if (I < E) {
2199
+ if (ReducedLevel.empty ()) {
2200
+ // The current operator requires more inputs than there are operands at
2201
+ // this level. Pick a smaller operator and retry.
2202
+ ++OpIdx;
2203
+ assert (OpIdx < Ops.size () && " no smaller operators for reduction" );
2204
+ continue ;
2205
+ }
2206
+
2207
+ // Otherwise, we just have a remainder, which we push to the next level.
2208
+ for (; I < E; ++I)
2209
+ ReducedLevel.push_back (Level[I]);
2210
+ }
2211
+ Level = ReducedLevel;
2212
+ }
2213
+
2214
+ return *Level.begin ();
2215
+ }
2216
+
2217
+ // / Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2218
+ // / serializes it.
2219
+ SDValue NVPTXTargetLowering::LowerVECREDUCE (SDValue Op,
2220
+ SelectionDAG &DAG) const {
2221
+ // If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2222
+ if (DisableFOpTreeReduce || !Op->getFlags ().hasAllowReassociation ())
2223
+ return SDValue ();
2224
+
2225
+ EVT EltTy = Op.getOperand (0 ).getValueType ().getVectorElementType ();
2226
+ const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion () >= 100 &&
2227
+ STI.getPTXVersion () >= 88 ;
2228
+ SDLoc DL (Op);
2229
+ SmallVector<std::pair<unsigned /* Op*/ , unsigned /* NumIn*/ >, 2 > Operators;
2230
+ switch (Op->getOpcode ()) {
2231
+ case ISD::VECREDUCE_FADD:
2232
+ Operators = {{ISD::FADD, 2 }};
2233
+ break ;
2234
+ case ISD::VECREDUCE_FMUL:
2235
+ Operators = {{ISD::FMUL, 2 }};
2236
+ break ;
2237
+ case ISD::VECREDUCE_FMAX:
2238
+ if (CanUseMinMax3)
2239
+ Operators.push_back ({NVPTXISD::FMAXNUM3, 3 });
2240
+ Operators.push_back ({ISD::FMAXNUM, 2 });
2241
+ break ;
2242
+ case ISD::VECREDUCE_FMIN:
2243
+ if (CanUseMinMax3)
2244
+ Operators.push_back ({NVPTXISD::FMINNUM3, 3 });
2245
+ Operators.push_back ({ISD::FMINNUM, 2 });
2246
+ break ;
2247
+ case ISD::VECREDUCE_FMAXIMUM:
2248
+ if (CanUseMinMax3)
2249
+ Operators.push_back ({NVPTXISD::FMAXIMUM3, 3 });
2250
+ Operators.push_back ({ISD::FMAXIMUM, 2 });
2251
+ break ;
2252
+ case ISD::VECREDUCE_FMINIMUM:
2253
+ if (CanUseMinMax3)
2254
+ Operators.push_back ({NVPTXISD::FMINIMUM3, 3 });
2255
+ Operators.push_back ({ISD::FMINIMUM, 2 });
2256
+ break ;
2257
+ default :
2258
+ llvm_unreachable (" unhandled vecreduce operation" );
2259
+ }
2260
+
2261
+ return BuildTreeReduction (Op.getOperand (0 ), Operators, DL, Op->getFlags (),
2262
+ DAG);
2263
+ }
2264
+
2144
2265
SDValue NVPTXTargetLowering::LowerBITCAST (SDValue Op, SelectionDAG &DAG) const {
2145
2266
// Handle bitcasting from v2i8 without hitting the default promotion
2146
2267
// strategy which goes through stack memory.
@@ -2929,6 +3050,13 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
2929
3050
return LowerVECTOR_SHUFFLE (Op, DAG);
2930
3051
case ISD::CONCAT_VECTORS:
2931
3052
return LowerCONCAT_VECTORS (Op, DAG);
3053
+ case ISD::VECREDUCE_FADD:
3054
+ case ISD::VECREDUCE_FMUL:
3055
+ case ISD::VECREDUCE_FMAX:
3056
+ case ISD::VECREDUCE_FMIN:
3057
+ case ISD::VECREDUCE_FMAXIMUM:
3058
+ case ISD::VECREDUCE_FMINIMUM:
3059
+ return LowerVECREDUCE (Op, DAG);
2932
3060
case ISD::STORE:
2933
3061
return LowerSTORE (Op, DAG);
2934
3062
case ISD::LOAD:
0 commit comments