@@ -850,6 +850,27 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   if (STI.allowFP16Math() || STI.hasBF16Math())
     setTargetDAGCombine(ISD::SETCC);
 
+  // Vector reduction operations. These may be turned into sequential, shuffle,
+  // or tree reductions depending on what instructions are available for each
+  // type.
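+  //
+  // For example (illustrative shapes, not the exact DAGs): an fadd reduction
+  // of <4 x float> may become the tree (e0 + e1) + (e2 + e3) when
+  // reassociation is allowed, or the sequential chain ((e0 + e1) + e2) + e3
+  // when it is not.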
+  for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
+    MVT EltVT = VT.getVectorElementType();
+    if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
+        EltVT == MVT::f64) {
+      setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
+                          ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
+                          ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
+                         VT, Custom);
+    } else if (EltVT.isScalarInteger()) {
+      setOperationAction(
+          {ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
+           ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
+           ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
+          VT, Custom);
+    }
+  }
+
   // Promote fp16 arithmetic if fp16 hardware isn't available or the
   // user passed --nvptx-no-fp16-math. The flag is useful because,
   // although sm_53+ GPUs have some sort of FP16 support in
@@ -1083,6 +1104,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
     MAKE_CASE(NVPTXISD::BFI)
     MAKE_CASE(NVPTXISD::PRMT)
     MAKE_CASE(NVPTXISD::FCOPYSIGN)
+    MAKE_CASE(NVPTXISD::FMAXNUM3)
+    MAKE_CASE(NVPTXISD::FMINNUM3)
+    MAKE_CASE(NVPTXISD::FMAXIMUM3)
+    MAKE_CASE(NVPTXISD::FMINIMUM3)
     MAKE_CASE(NVPTXISD::DYNAMIC_STACKALLOC)
     MAKE_CASE(NVPTXISD::STACKRESTORE)
     MAKE_CASE(NVPTXISD::STACKSAVE)
@@ -2038,6 +2063,259 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
   return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
 }
 
+/// A generic routine for constructing a tree reduction on a vector operand.
+/// This method groups elements bottom-up, progressively building each level.
+/// Unlike the shuffle reduction used in DAGTypeLegalizer and ExpandReductions,
+/// adjacent elements are combined first, leading to shorter live ranges. This
+/// approach makes the most sense if the shuffle reduction would use the same
+/// number of registers.
+///
+/// The flags on the original reduction operation will be propagated to
+/// each scalar operation.
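+///
+/// For example (a sketch of the shapes, not the exact DAGs): reducing eight
+/// elements with a two-input operator builds
+///   ((e0 + e1) + (e2 + e3)) + ((e4 + e5) + (e6 + e7)),
+/// while a shuffle reduction would first combine the two vector halves,
+/// pairing e0 with e4, e1 with e5, and so on.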
+static SDValue BuildTreeReduction(
+    const SmallVector<SDValue> &Elements, EVT EltTy,
+    ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
+    const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
+  // Build the reduction tree at each level, starting with all the elements.
+  SmallVector<SDValue> Level = Elements;
+
+  unsigned OpIdx = 0;
+  while (Level.size() > 1) {
+    // Try to reduce this level using the current operator.
+    const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
+
+    // Build the next level by partially reducing all elements.
+    SmallVector<SDValue> ReducedLevel;
+    unsigned I = 0, E = Level.size();
+    for (; I + DefaultGroupSize <= E; I += DefaultGroupSize) {
+      // Reduce elements in groups of [DefaultGroupSize], as much as possible.
+      ReducedLevel.push_back(DAG.getNode(
+          DefaultScalarOp, DL, EltTy,
+          ArrayRef<SDValue>(Level).slice(I, DefaultGroupSize), Flags));
+    }
+
+    if (I < E) {
+      // Handle leftover elements.
+
+      if (ReducedLevel.empty()) {
+        // We didn't reduce anything at this level. We need to pick a smaller
+        // operator.
+        ++OpIdx;
+        assert(OpIdx < Ops.size() && "no smaller operators for reduction");
+        continue;
+      }
+
+      // We reduced some elements but not all of them; the operator's number of
+      // inputs doesn't evenly divide this level's size. Move the leftovers to
+      // the next level.
+      for (; I < E; ++I)
+        ReducedLevel.push_back(Level[I]);
+    }
+
+    // Process the next level.
+    Level = ReducedLevel;
+  }
+
+  return *Level.begin();
+}
+
+/// Lower reductions to either a sequence of operations or a tree if
+/// reassociation is allowed. This method will use larger operations like
+/// max3/min3 when the target supports them.
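+///
+/// For example, assuming reassociation is permitted and f32 max3 is
+/// available, a VECREDUCE_FMAX of <5 x f32> can be lowered as
+/// max3(max3(e0, e1, e2), e3, e4): two operations instead of the four
+/// two-input maxes a pairwise lowering would need (illustrative only).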
+SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
+                                            SelectionDAG &DAG) const {
+  SDLoc DL(Op);
+  const SDNodeFlags Flags = Op->getFlags();
+  SDValue Vector;
+  SDValue Accumulator;
+
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // Special case: the accumulator is the first operand.
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // Default case: the only operand is the vector.
+    Vector = Op.getOperand(0);
+  }
+
+  EVT EltTy = Vector.getValueType().getVectorElementType();
+  const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
+                             STI.getPTXVersion() >= 88;
+
+  // A list of SDNode opcodes with equivalent semantics, sorted descending by
+  // number of inputs they take.
+  SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
+
+  // Whether we can lower to scalar operations in an arbitrary order.
+  bool IsAssociative = allowUnsafeFPMath(DAG.getMachineFunction());
+
+  // Whether the data type and operation can be represented with fewer ops and
+  // registers in a shuffle reduction.
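+  // For example, a VECREDUCE_FADD of v4f16 can be done with one add.f16x2 on
+  // the vector halves plus a final scalar add, rather than three scalar adds
+  // (a sketch of the motivation, not the exact selection).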
+  bool PrefersShuffle;
+
+  switch (Op->getOpcode()) {
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
+    ScalarOps = {{ISD::FADD, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FADD;
+    // Prefer add.{,b}f16x2 for v2{,b}f16.
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
+    ScalarOps = {{ISD::FMUL, 2}};
+    IsAssociative |= Op->getOpcode() == ISD::VECREDUCE_FMUL;
+    // Prefer mul.{,b}f16x2 for v2{,b}f16.
+    PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    break;
+  case ISD::VECREDUCE_FMAX:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
+    ScalarOps.push_back({ISD::FMAXNUM, 2});
+    // The definition of maxNum in IEEE 754-2008 is non-associative due to the
+    // handling of sNaN inputs. Allow overriding with fast-math or the
+    // 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMIN:
+    if (CanUseMinMax3)
+      ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
+    ScalarOps.push_back({ISD::FMINNUM, 2});
+    // The definition of minNum in IEEE 754-2008 is non-associative due to the
+    // handling of sNaN inputs. Allow overriding with fast-math or the
+    // 'reassoc' attribute.
+    IsAssociative |= Flags.hasAllowReassociation();
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_FMAXIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
+      // Can't use fmax3 in a shuffle reduction.
+      PrefersShuffle = false;
+    } else {
+      // Prefer max.{,b}f16x2 for v2{,b}f16.
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMAXIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_FMINIMUM:
+    if (CanUseMinMax3) {
+      ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
+      // Can't use fmin3 in a shuffle reduction.
+      PrefersShuffle = false;
+    } else {
+      // Prefer min.{,b}f16x2 for v2{,b}f16.
+      PrefersShuffle = EltTy == MVT::f16 || EltTy == MVT::bf16;
+    }
+    ScalarOps.push_back({ISD::FMINIMUM, 2});
+    IsAssociative = true;
+    break;
+  case ISD::VECREDUCE_ADD:
+    ScalarOps = {{ISD::ADD, 2}};
+    IsAssociative = true;
+    // Prefer add.{s,u}16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_MUL:
+    ScalarOps = {{ISD::MUL, 2}};
+    IsAssociative = true;
+    // Integer multiply doesn't support packed types.
+    PrefersShuffle = false;
+    break;
+  case ISD::VECREDUCE_UMAX:
+    ScalarOps = {{ISD::UMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.u16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_UMIN:
+    ScalarOps = {{ISD::UMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.u16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMAX:
+    ScalarOps = {{ISD::SMAX, 2}};
+    IsAssociative = true;
+    // Prefer max.s16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_SMIN:
+    ScalarOps = {{ISD::SMIN, 2}};
+    IsAssociative = true;
+    // Prefer min.s16x2 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_AND:
+    ScalarOps = {{ISD::AND, 2}};
+    IsAssociative = true;
+    // Prefer and.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_OR:
+    ScalarOps = {{ISD::OR, 2}};
+    IsAssociative = true;
+    // Prefer or.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  case ISD::VECREDUCE_XOR:
+    ScalarOps = {{ISD::XOR, 2}};
+    IsAssociative = true;
+    // Prefer xor.b32 for v2i16.
+    PrefersShuffle = EltTy == MVT::i16;
+    break;
+  default:
+    llvm_unreachable("unhandled vecreduce operation");
+  }
+
+  // We don't expect an accumulator for reassociative vector reduction ops.
+  assert((!IsAssociative || !Accumulator) && "unexpected accumulator");
+
+  // If a shuffle reduction is preferred, leave it to SelectionDAG.
+  if (IsAssociative && PrefersShuffle)
+    return SDValue();
+
+  // Otherwise, handle the reduction here.
+  SmallVector<SDValue> Elements;
+  DAG.ExtractVectorElements(Vector, Elements);
+
+  // Lower to tree reduction.
+  if (IsAssociative)
+    return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+
+  // Lower to sequential reduction.
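+  // For example, a VECREDUCE_SEQ_FADD of <3 x float> with start value s
+  // produces ((s + e0) + e1) + e2, preserving the strict left-to-right
+  // evaluation order (illustrative shape).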
+  EVT VectorTy = Vector.getValueType();
+  const unsigned NumElts = VectorTy.getVectorNumElements();
+  for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
+    // Try to reduce the remaining sequence as much as possible using the
+    // current operator.
+    assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
+    const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
+
+    if (!Accumulator) {
+      // Try to initialize the accumulator using the current operator.
+      if (I + DefaultGroupSize <= NumElts) {
+        Accumulator =
+            DAG.getNode(DefaultScalarOp, DL, EltTy,
+                        ArrayRef(Elements).slice(I, DefaultGroupSize), Flags);
+        I += DefaultGroupSize;
+      }
+    }
+
+    if (Accumulator) {
+      for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
+        SmallVector<SDValue> Operands = {Accumulator};
+        for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
+          Operands.push_back(Elements[I + K]);
+        Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
+      }
+    }
+  }
+
+  return Accumulator;
+}
+
 SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
   // Handle bitcasting from v2i8 without hitting the default promotion
   // strategy which goes through stack memory.
@@ -2869,6 +3147,24 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerVECTOR_SHUFFLE(Op, DAG);
   case ISD::CONCAT_VECTORS:
    return LowerCONCAT_VECTORS(Op, DAG);
+  case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
+  case ISD::VECREDUCE_FMAX:
+  case ISD::VECREDUCE_FMIN:
+  case ISD::VECREDUCE_FMAXIMUM:
+  case ISD::VECREDUCE_FMINIMUM:
+  case ISD::VECREDUCE_ADD:
+  case ISD::VECREDUCE_MUL:
+  case ISD::VECREDUCE_UMAX:
+  case ISD::VECREDUCE_UMIN:
+  case ISD::VECREDUCE_SMAX:
+  case ISD::VECREDUCE_SMIN:
+  case ISD::VECREDUCE_AND:
+  case ISD::VECREDUCE_OR:
+  case ISD::VECREDUCE_XOR:
+    return LowerVECREDUCE(Op, DAG);
   case ISD::STORE:
     return LowerSTORE(Op, DAG);
   case ISD::LOAD: