Skip to content

Commit 7f5440b

Browse files
committed
[NVPTX] support rest of VECREDUCE intrinsics and other improvements
- Support all VECREDUCE intrinsics - Clean up FileCheck directives in lit test - Also handle sequential lowering in NVPTX backend, where we can still use larger operations.
1 parent 0b53293 commit 7f5440b

File tree

2 files changed

+1718
-410
lines changed

2 files changed

+1718
-410
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 135 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -838,12 +838,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
838838
setTargetDAGCombine(ISD::SETCC);
839839

840840
// Vector reduction operations. These are transformed into a tree evaluation
841-
// of nodes which may or may not be legal.
841+
// of nodes which may initially be illegal.
842842
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
843-
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
844-
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
845-
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
846-
VT, Custom);
843+
MVT EltVT = VT.getVectorElementType();
844+
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
845+
EltVT == MVT::f64) {
846+
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
847+
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
848+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
849+
VT, Custom);
850+
} else if (EltVT.isScalarInteger()) {
851+
setOperationAction(
852+
{ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
853+
ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
854+
ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
855+
VT, Custom);
856+
}
847857
}
848858

849859
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2155,29 +2165,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21552165
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
21562166
}
21572167

2158-
/// A generic routine for constructing a tree reduction for a vector operand.
2168+
/// A generic routine for constructing a tree reduction on a vector operand.
21592169
/// This method differs from iterative splitting in DAGTypeLegalizer by
2160-
/// first scalarizing the vector and then progressively grouping elements
2161-
/// bottom-up. This allows easily building the optimal (minimum) number of nodes
2162-
/// with different numbers of operands (eg. max3 vs max2).
2170+
/// progressively grouping elements bottom-up.
21632171
static SDValue BuildTreeReduction(
2164-
const SDValue &VectorOp,
2172+
const SmallVector<SDValue> &Elements, EVT EltTy,
21652173
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
21662174
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2167-
EVT VectorTy = VectorOp.getValueType();
2168-
EVT EltTy = VectorTy.getVectorElementType();
2169-
const unsigned NumElts = VectorTy.getVectorNumElements();
2170-
2171-
// scalarize vector
2172-
SmallVector<SDValue> Elements(NumElts);
2173-
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2174-
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2175-
DAG.getConstant(I, DL, MVT::i64));
2176-
}
2177-
21782175
// now build the computation graph in place at each level
21792176
SmallVector<SDValue> Level = Elements;
2180-
for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
2177+
unsigned OpIdx = 0;
2178+
while (Level.size() > 1) {
21812179
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21822180

21832181
// partially reduce all elements in level
@@ -2209,52 +2207,139 @@ static SDValue BuildTreeReduction(
22092207
return *Level.begin();
22102208
}
22112209

2212-
/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2213-
/// serializes it.
2210+
/// Lower reductions to either a sequence of operations or a tree if
2211+
/// reassociations are allowed. This method will use larger operations like
2212+
/// max3/min3 when the target supports them.
22142213
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22152214
SelectionDAG &DAG) const {
2216-
// If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2217-
if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
2215+
if (DisableFOpTreeReduce)
22182216
return SDValue();
22192217

2220-
EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
2218+
SDLoc DL(Op);
2219+
const SDNodeFlags Flags = Op->getFlags();
2220+
const SDValue &Vector = Op.getOperand(0);
2221+
EVT EltTy = Vector.getValueType().getVectorElementType();
22212222
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22222223
STI.getPTXVersion() >= 88;
2223-
SDLoc DL(Op);
2224-
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
2224+
2225+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2226+
// number of inputs they take.
2227+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2228+
bool IsReassociatable;
2229+
22252230
switch (Op->getOpcode()) {
22262231
case ISD::VECREDUCE_FADD:
2227-
Operators = {{ISD::FADD, 2}};
2232+
ScalarOps = {{ISD::FADD, 2}};
2233+
IsReassociatable = false;
22282234
break;
22292235
case ISD::VECREDUCE_FMUL:
2230-
Operators = {{ISD::FMUL, 2}};
2236+
ScalarOps = {{ISD::FMUL, 2}};
2237+
IsReassociatable = false;
22312238
break;
22322239
case ISD::VECREDUCE_FMAX:
22332240
if (CanUseMinMax3)
2234-
Operators.push_back({NVPTXISD::FMAXNUM3, 3});
2235-
Operators.push_back({ISD::FMAXNUM, 2});
2241+
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
2242+
ScalarOps.push_back({ISD::FMAXNUM, 2});
2243+
IsReassociatable = false;
22362244
break;
22372245
case ISD::VECREDUCE_FMIN:
22382246
if (CanUseMinMax3)
2239-
Operators.push_back({NVPTXISD::FMINNUM3, 3});
2240-
Operators.push_back({ISD::FMINNUM, 2});
2247+
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
2248+
ScalarOps.push_back({ISD::FMINNUM, 2});
2249+
IsReassociatable = false;
22412250
break;
22422251
case ISD::VECREDUCE_FMAXIMUM:
22432252
if (CanUseMinMax3)
2244-
Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
2245-
Operators.push_back({ISD::FMAXIMUM, 2});
2253+
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
2254+
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2255+
IsReassociatable = false;
22462256
break;
22472257
case ISD::VECREDUCE_FMINIMUM:
22482258
if (CanUseMinMax3)
2249-
Operators.push_back({NVPTXISD::FMINIMUM3, 3});
2250-
Operators.push_back({ISD::FMINIMUM, 2});
2259+
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
2260+
ScalarOps.push_back({ISD::FMINIMUM, 2});
2261+
IsReassociatable = false;
2262+
break;
2263+
case ISD::VECREDUCE_ADD:
2264+
ScalarOps = {{ISD::ADD, 2}};
2265+
IsReassociatable = true;
2266+
break;
2267+
case ISD::VECREDUCE_MUL:
2268+
ScalarOps = {{ISD::MUL, 2}};
2269+
IsReassociatable = true;
2270+
break;
2271+
case ISD::VECREDUCE_UMAX:
2272+
ScalarOps = {{ISD::UMAX, 2}};
2273+
IsReassociatable = true;
2274+
break;
2275+
case ISD::VECREDUCE_UMIN:
2276+
ScalarOps = {{ISD::UMIN, 2}};
2277+
IsReassociatable = true;
2278+
break;
2279+
case ISD::VECREDUCE_SMAX:
2280+
ScalarOps = {{ISD::SMAX, 2}};
2281+
IsReassociatable = true;
2282+
break;
2283+
case ISD::VECREDUCE_SMIN:
2284+
ScalarOps = {{ISD::SMIN, 2}};
2285+
IsReassociatable = true;
2286+
break;
2287+
case ISD::VECREDUCE_AND:
2288+
ScalarOps = {{ISD::AND, 2}};
2289+
IsReassociatable = true;
2290+
break;
2291+
case ISD::VECREDUCE_OR:
2292+
ScalarOps = {{ISD::OR, 2}};
2293+
IsReassociatable = true;
2294+
break;
2295+
case ISD::VECREDUCE_XOR:
2296+
ScalarOps = {{ISD::XOR, 2}};
2297+
IsReassociatable = true;
22512298
break;
22522299
default:
22532300
llvm_unreachable("unhandled vecreduce operation");
22542301
}
22552302

2256-
return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
2257-
DAG);
2303+
EVT VectorTy = Vector.getValueType();
2304+
const unsigned NumElts = VectorTy.getVectorNumElements();
2305+
2306+
// scalarize vector
2307+
SmallVector<SDValue> Elements(NumElts);
2308+
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2309+
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2310+
DAG.getConstant(I, DL, MVT::i64));
2311+
}
2312+
2313+
// Lower to tree reduction.
2314+
if (IsReassociatable || Flags.hasAllowReassociation())
2315+
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2316+
2317+
// Lower to sequential reduction.
2318+
SDValue Accumulator;
2319+
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2320+
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
2321+
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2322+
2323+
if (!Accumulator) {
2324+
if (I + DefaultGroupSize <= NumElts) {
2325+
Accumulator = DAG.getNode(
2326+
DefaultScalarOp, DL, EltTy,
2327+
ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
2328+
I += DefaultGroupSize;
2329+
}
2330+
}
2331+
2332+
if (Accumulator) {
2333+
for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
2334+
SmallVector<SDValue> Operands = {Accumulator};
2335+
for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
2336+
Operands.push_back(Elements[I + K]);
2337+
Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
2338+
}
2339+
}
2340+
}
2341+
2342+
return Accumulator;
22582343
}
22592344

22602345
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
@@ -3006,6 +3091,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30063091
case ISD::VECREDUCE_FMIN:
30073092
case ISD::VECREDUCE_FMAXIMUM:
30083093
case ISD::VECREDUCE_FMINIMUM:
3094+
case ISD::VECREDUCE_ADD:
3095+
case ISD::VECREDUCE_MUL:
3096+
case ISD::VECREDUCE_UMAX:
3097+
case ISD::VECREDUCE_UMIN:
3098+
case ISD::VECREDUCE_SMAX:
3099+
case ISD::VECREDUCE_SMIN:
3100+
case ISD::VECREDUCE_AND:
3101+
case ISD::VECREDUCE_OR:
3102+
case ISD::VECREDUCE_XOR:
30093103
return LowerVECREDUCE(Op, DAG);
30103104
case ISD::STORE:
30113105
return LowerSTORE(Op, DAG);

0 commit comments

Comments
 (0)