Skip to content

Commit 6ac46e5

Browse files
committed
[NVPTX] support rest of VECREDUCE intrinsics and other improvements
- Support all VECREDUCE intrinsics
- Clean up FileCheck directives in lit test
- Also handle sequential lowering in the NVPTX backend, where we can still use larger operations
1 parent 92f6f6d commit 6ac46e5

File tree

2 files changed

+1665
-184
lines changed

2 files changed

+1665
-184
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 135 additions & 41 deletions
Original file line number | Diff line number | Diff line change
@@ -841,12 +841,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
841841
setTargetDAGCombine(ISD::SETCC);
842842

843843
// Vector reduction operations. These are transformed into a tree evaluation
844-
// of nodes which may or may not be legal.
844+
// of nodes which may initially be illegal.
845845
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
846-
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
847-
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
848-
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
849-
VT, Custom);
846+
MVT EltVT = VT.getVectorElementType();
847+
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
848+
EltVT == MVT::f64) {
849+
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
850+
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
851+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
852+
VT, Custom);
853+
} else if (EltVT.isScalarInteger()) {
854+
setOperationAction(
855+
{ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
856+
ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
857+
ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
858+
VT, Custom);
859+
}
850860
}
851861

852862
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2166,29 +2176,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21662176
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
21672177
}
21682178

2169-
/// A generic routine for constructing a tree reduction for a vector operand.
2179+
/// A generic routine for constructing a tree reduction on a vector operand.
21702180
/// This method differs from iterative splitting in DAGTypeLegalizer by
2171-
/// first scalarizing the vector and then progressively grouping elements
2172-
/// bottom-up. This allows easily building the optimal (minimum) number of nodes
2173-
/// with different numbers of operands (eg. max3 vs max2).
2181+
/// progressively grouping elements bottom-up.
21742182
static SDValue BuildTreeReduction(
2175-
const SDValue &VectorOp,
2183+
const SmallVector<SDValue> &Elements, EVT EltTy,
21762184
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
21772185
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2178-
EVT VectorTy = VectorOp.getValueType();
2179-
EVT EltTy = VectorTy.getVectorElementType();
2180-
const unsigned NumElts = VectorTy.getVectorNumElements();
2181-
2182-
// scalarize vector
2183-
SmallVector<SDValue> Elements(NumElts);
2184-
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2185-
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2186-
DAG.getConstant(I, DL, MVT::i64));
2187-
}
2188-
21892186
// now build the computation graph in place at each level
21902187
SmallVector<SDValue> Level = Elements;
2191-
for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
2188+
unsigned OpIdx = 0;
2189+
while (Level.size() > 1) {
21922190
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21932191

21942192
// partially reduce all elements in level
@@ -2220,52 +2218,139 @@ static SDValue BuildTreeReduction(
22202218
return *Level.begin();
22212219
}
22222220

2223-
/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2224-
/// serializes it.
2221+
/// Lower reductions to either a sequence of operations or a tree if
2222+
/// reassociations are allowed. This method will use larger operations like
2223+
/// max3/min3 when the target supports them.
22252224
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22262225
SelectionDAG &DAG) const {
2227-
// If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2228-
if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
2226+
if (DisableFOpTreeReduce)
22292227
return SDValue();
22302228

2231-
EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
2229+
SDLoc DL(Op);
2230+
const SDNodeFlags Flags = Op->getFlags();
2231+
const SDValue &Vector = Op.getOperand(0);
2232+
EVT EltTy = Vector.getValueType().getVectorElementType();
22322233
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22332234
STI.getPTXVersion() >= 88;
2234-
SDLoc DL(Op);
2235-
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
2235+
2236+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2237+
// number of inputs they take.
2238+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2239+
bool IsReassociatable;
2240+
22362241
switch (Op->getOpcode()) {
22372242
case ISD::VECREDUCE_FADD:
2238-
Operators = {{ISD::FADD, 2}};
2243+
ScalarOps = {{ISD::FADD, 2}};
2244+
IsReassociatable = false;
22392245
break;
22402246
case ISD::VECREDUCE_FMUL:
2241-
Operators = {{ISD::FMUL, 2}};
2247+
ScalarOps = {{ISD::FMUL, 2}};
2248+
IsReassociatable = false;
22422249
break;
22432250
case ISD::VECREDUCE_FMAX:
22442251
if (CanUseMinMax3)
2245-
Operators.push_back({NVPTXISD::FMAXNUM3, 3});
2246-
Operators.push_back({ISD::FMAXNUM, 2});
2252+
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
2253+
ScalarOps.push_back({ISD::FMAXNUM, 2});
2254+
IsReassociatable = false;
22472255
break;
22482256
case ISD::VECREDUCE_FMIN:
22492257
if (CanUseMinMax3)
2250-
Operators.push_back({NVPTXISD::FMINNUM3, 3});
2251-
Operators.push_back({ISD::FMINNUM, 2});
2258+
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
2259+
ScalarOps.push_back({ISD::FMINNUM, 2});
2260+
IsReassociatable = false;
22522261
break;
22532262
case ISD::VECREDUCE_FMAXIMUM:
22542263
if (CanUseMinMax3)
2255-
Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
2256-
Operators.push_back({ISD::FMAXIMUM, 2});
2264+
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
2265+
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2266+
IsReassociatable = false;
22572267
break;
22582268
case ISD::VECREDUCE_FMINIMUM:
22592269
if (CanUseMinMax3)
2260-
Operators.push_back({NVPTXISD::FMINIMUM3, 3});
2261-
Operators.push_back({ISD::FMINIMUM, 2});
2270+
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
2271+
ScalarOps.push_back({ISD::FMINIMUM, 2});
2272+
IsReassociatable = false;
2273+
break;
2274+
case ISD::VECREDUCE_ADD:
2275+
ScalarOps = {{ISD::ADD, 2}};
2276+
IsReassociatable = true;
2277+
break;
2278+
case ISD::VECREDUCE_MUL:
2279+
ScalarOps = {{ISD::MUL, 2}};
2280+
IsReassociatable = true;
2281+
break;
2282+
case ISD::VECREDUCE_UMAX:
2283+
ScalarOps = {{ISD::UMAX, 2}};
2284+
IsReassociatable = true;
2285+
break;
2286+
case ISD::VECREDUCE_UMIN:
2287+
ScalarOps = {{ISD::UMIN, 2}};
2288+
IsReassociatable = true;
2289+
break;
2290+
case ISD::VECREDUCE_SMAX:
2291+
ScalarOps = {{ISD::SMAX, 2}};
2292+
IsReassociatable = true;
2293+
break;
2294+
case ISD::VECREDUCE_SMIN:
2295+
ScalarOps = {{ISD::SMIN, 2}};
2296+
IsReassociatable = true;
2297+
break;
2298+
case ISD::VECREDUCE_AND:
2299+
ScalarOps = {{ISD::AND, 2}};
2300+
IsReassociatable = true;
2301+
break;
2302+
case ISD::VECREDUCE_OR:
2303+
ScalarOps = {{ISD::OR, 2}};
2304+
IsReassociatable = true;
2305+
break;
2306+
case ISD::VECREDUCE_XOR:
2307+
ScalarOps = {{ISD::XOR, 2}};
2308+
IsReassociatable = true;
22622309
break;
22632310
default:
22642311
llvm_unreachable("unhandled vecreduce operation");
22652312
}
22662313

2267-
return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
2268-
DAG);
2314+
EVT VectorTy = Vector.getValueType();
2315+
const unsigned NumElts = VectorTy.getVectorNumElements();
2316+
2317+
// scalarize vector
2318+
SmallVector<SDValue> Elements(NumElts);
2319+
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2320+
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2321+
DAG.getConstant(I, DL, MVT::i64));
2322+
}
2323+
2324+
// Lower to tree reduction.
2325+
if (IsReassociatable || Flags.hasAllowReassociation())
2326+
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2327+
2328+
// Lower to sequential reduction.
2329+
SDValue Accumulator;
2330+
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2331+
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
2332+
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2333+
2334+
if (!Accumulator) {
2335+
if (I + DefaultGroupSize <= NumElts) {
2336+
Accumulator = DAG.getNode(
2337+
DefaultScalarOp, DL, EltTy,
2338+
ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
2339+
I += DefaultGroupSize;
2340+
}
2341+
}
2342+
2343+
if (Accumulator) {
2344+
for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
2345+
SmallVector<SDValue> Operands = {Accumulator};
2346+
for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
2347+
Operands.push_back(Elements[I + K]);
2348+
Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
2349+
}
2350+
}
2351+
}
2352+
2353+
return Accumulator;
22692354
}
22702355

22712356
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
@@ -3062,6 +3147,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30623147
case ISD::VECREDUCE_FMIN:
30633148
case ISD::VECREDUCE_FMAXIMUM:
30643149
case ISD::VECREDUCE_FMINIMUM:
3150+
case ISD::VECREDUCE_ADD:
3151+
case ISD::VECREDUCE_MUL:
3152+
case ISD::VECREDUCE_UMAX:
3153+
case ISD::VECREDUCE_UMIN:
3154+
case ISD::VECREDUCE_SMAX:
3155+
case ISD::VECREDUCE_SMIN:
3156+
case ISD::VECREDUCE_AND:
3157+
case ISD::VECREDUCE_OR:
3158+
case ISD::VECREDUCE_XOR:
30653159
return LowerVECREDUCE(Op, DAG);
30663160
case ISD::STORE:
30673161
return LowerSTORE(Op, DAG);

0 commit comments

Comments
 (0)