Skip to content

Commit ec897d0

Browse files
committed
[NVPTX] support rest of VECREDUCE intrinsics and other improvements
- Support all VECREDUCE intrinsics (integer as well as floating-point).
- Clean up the FileCheck directives in the lit test.
- Also handle sequential (non-reassociable) lowering in the NVPTX backend, where larger operations such as 3-input min/max can still be used.
1 parent 3900a90 commit ec897d0

File tree

2 files changed

+1665
-184
lines changed

2 files changed

+1665
-184
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 135 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -835,12 +835,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
835835
setTargetDAGCombine(ISD::SETCC);
836836

837837
// Vector reduction operations. These are transformed into a tree evaluation
838-
// of nodes which may or may not be legal.
838+
// of nodes which may initially be illegal.
839839
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
840-
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
841-
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
842-
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
843-
VT, Custom);
840+
MVT EltVT = VT.getVectorElementType();
841+
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
842+
EltVT == MVT::f64) {
843+
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
844+
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
845+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
846+
VT, Custom);
847+
} else if (EltVT.isScalarInteger()) {
848+
setOperationAction(
849+
{ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
850+
ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
851+
ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
852+
VT, Custom);
853+
}
844854
}
845855

846856
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2160,29 +2170,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21602170
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
21612171
}
21622172

2163-
/// A generic routine for constructing a tree reduction for a vector operand.
2173+
/// A generic routine for constructing a tree reduction on a vector operand.
21642174
/// This method differs from iterative splitting in DAGTypeLegalizer by
2165-
/// first scalarizing the vector and then progressively grouping elements
2166-
/// bottom-up. This allows easily building the optimal (minimum) number of nodes
2167-
/// with different numbers of operands (eg. max3 vs max2).
2175+
/// progressively grouping elements bottom-up.
21682176
static SDValue BuildTreeReduction(
2169-
const SDValue &VectorOp,
2177+
const SmallVector<SDValue> &Elements, EVT EltTy,
21702178
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
21712179
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2172-
EVT VectorTy = VectorOp.getValueType();
2173-
EVT EltTy = VectorTy.getVectorElementType();
2174-
const unsigned NumElts = VectorTy.getVectorNumElements();
2175-
2176-
// scalarize vector
2177-
SmallVector<SDValue> Elements(NumElts);
2178-
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2179-
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2180-
DAG.getConstant(I, DL, MVT::i64));
2181-
}
2182-
21832180
// now build the computation graph in place at each level
21842181
SmallVector<SDValue> Level = Elements;
2185-
for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
2182+
unsigned OpIdx = 0;
2183+
while (Level.size() > 1) {
21862184
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21872185

21882186
// partially reduce all elements in level
@@ -2214,52 +2212,139 @@ static SDValue BuildTreeReduction(
22142212
return *Level.begin();
22152213
}
22162214

2217-
/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2218-
/// serializes it.
2215+
/// Lower reductions to either a sequence of operations or a tree if
2216+
/// reassociations are allowed. This method will use larger operations like
2217+
/// max3/min3 when the target supports them.
22192218
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22202219
SelectionDAG &DAG) const {
2221-
// If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2222-
if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
2220+
if (DisableFOpTreeReduce)
22232221
return SDValue();
22242222

2225-
EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
2223+
SDLoc DL(Op);
2224+
const SDNodeFlags Flags = Op->getFlags();
2225+
const SDValue &Vector = Op.getOperand(0);
2226+
EVT EltTy = Vector.getValueType().getVectorElementType();
22262227
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22272228
STI.getPTXVersion() >= 88;
2228-
SDLoc DL(Op);
2229-
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
2229+
2230+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2231+
// number of inputs they take.
2232+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2233+
bool IsReassociatable;
2234+
22302235
switch (Op->getOpcode()) {
22312236
case ISD::VECREDUCE_FADD:
2232-
Operators = {{ISD::FADD, 2}};
2237+
ScalarOps = {{ISD::FADD, 2}};
2238+
IsReassociatable = false;
22332239
break;
22342240
case ISD::VECREDUCE_FMUL:
2235-
Operators = {{ISD::FMUL, 2}};
2241+
ScalarOps = {{ISD::FMUL, 2}};
2242+
IsReassociatable = false;
22362243
break;
22372244
case ISD::VECREDUCE_FMAX:
22382245
if (CanUseMinMax3)
2239-
Operators.push_back({NVPTXISD::FMAXNUM3, 3});
2240-
Operators.push_back({ISD::FMAXNUM, 2});
2246+
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
2247+
ScalarOps.push_back({ISD::FMAXNUM, 2});
2248+
IsReassociatable = false;
22412249
break;
22422250
case ISD::VECREDUCE_FMIN:
22432251
if (CanUseMinMax3)
2244-
Operators.push_back({NVPTXISD::FMINNUM3, 3});
2245-
Operators.push_back({ISD::FMINNUM, 2});
2252+
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
2253+
ScalarOps.push_back({ISD::FMINNUM, 2});
2254+
IsReassociatable = false;
22462255
break;
22472256
case ISD::VECREDUCE_FMAXIMUM:
22482257
if (CanUseMinMax3)
2249-
Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
2250-
Operators.push_back({ISD::FMAXIMUM, 2});
2258+
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
2259+
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2260+
IsReassociatable = false;
22512261
break;
22522262
case ISD::VECREDUCE_FMINIMUM:
22532263
if (CanUseMinMax3)
2254-
Operators.push_back({NVPTXISD::FMINIMUM3, 3});
2255-
Operators.push_back({ISD::FMINIMUM, 2});
2264+
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
2265+
ScalarOps.push_back({ISD::FMINIMUM, 2});
2266+
IsReassociatable = false;
2267+
break;
2268+
case ISD::VECREDUCE_ADD:
2269+
ScalarOps = {{ISD::ADD, 2}};
2270+
IsReassociatable = true;
2271+
break;
2272+
case ISD::VECREDUCE_MUL:
2273+
ScalarOps = {{ISD::MUL, 2}};
2274+
IsReassociatable = true;
2275+
break;
2276+
case ISD::VECREDUCE_UMAX:
2277+
ScalarOps = {{ISD::UMAX, 2}};
2278+
IsReassociatable = true;
2279+
break;
2280+
case ISD::VECREDUCE_UMIN:
2281+
ScalarOps = {{ISD::UMIN, 2}};
2282+
IsReassociatable = true;
2283+
break;
2284+
case ISD::VECREDUCE_SMAX:
2285+
ScalarOps = {{ISD::SMAX, 2}};
2286+
IsReassociatable = true;
2287+
break;
2288+
case ISD::VECREDUCE_SMIN:
2289+
ScalarOps = {{ISD::SMIN, 2}};
2290+
IsReassociatable = true;
2291+
break;
2292+
case ISD::VECREDUCE_AND:
2293+
ScalarOps = {{ISD::AND, 2}};
2294+
IsReassociatable = true;
2295+
break;
2296+
case ISD::VECREDUCE_OR:
2297+
ScalarOps = {{ISD::OR, 2}};
2298+
IsReassociatable = true;
2299+
break;
2300+
case ISD::VECREDUCE_XOR:
2301+
ScalarOps = {{ISD::XOR, 2}};
2302+
IsReassociatable = true;
22562303
break;
22572304
default:
22582305
llvm_unreachable("unhandled vecreduce operation");
22592306
}
22602307

2261-
return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
2262-
DAG);
2308+
EVT VectorTy = Vector.getValueType();
2309+
const unsigned NumElts = VectorTy.getVectorNumElements();
2310+
2311+
// scalarize vector
2312+
SmallVector<SDValue> Elements(NumElts);
2313+
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2314+
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2315+
DAG.getConstant(I, DL, MVT::i64));
2316+
}
2317+
2318+
// Lower to tree reduction.
2319+
if (IsReassociatable || Flags.hasAllowReassociation())
2320+
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2321+
2322+
// Lower to sequential reduction.
2323+
SDValue Accumulator;
2324+
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2325+
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
2326+
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2327+
2328+
if (!Accumulator) {
2329+
if (I + DefaultGroupSize <= NumElts) {
2330+
Accumulator = DAG.getNode(
2331+
DefaultScalarOp, DL, EltTy,
2332+
ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
2333+
I += DefaultGroupSize;
2334+
}
2335+
}
2336+
2337+
if (Accumulator) {
2338+
for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
2339+
SmallVector<SDValue> Operands = {Accumulator};
2340+
for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
2341+
Operands.push_back(Elements[I + K]);
2342+
Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
2343+
}
2344+
}
2345+
}
2346+
2347+
return Accumulator;
22632348
}
22642349

22652350
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
@@ -3056,6 +3141,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30563141
case ISD::VECREDUCE_FMIN:
30573142
case ISD::VECREDUCE_FMAXIMUM:
30583143
case ISD::VECREDUCE_FMINIMUM:
3144+
case ISD::VECREDUCE_ADD:
3145+
case ISD::VECREDUCE_MUL:
3146+
case ISD::VECREDUCE_UMAX:
3147+
case ISD::VECREDUCE_UMIN:
3148+
case ISD::VECREDUCE_SMAX:
3149+
case ISD::VECREDUCE_SMIN:
3150+
case ISD::VECREDUCE_AND:
3151+
case ISD::VECREDUCE_OR:
3152+
case ISD::VECREDUCE_XOR:
30593153
return LowerVECREDUCE(Op, DAG);
30603154
case ISD::STORE:
30613155
return LowerSTORE(Op, DAG);

0 commit comments

Comments
 (0)