Skip to content

Commit 59a0fc2

Browse files
committed
[NVPTX] support rest of VECREDUCE intrinsics and other improvements
- Support all VECREDUCE intrinsics - Clean up FileCheck directives in lit test - Also handle sequential lowering in NVPTX backend, where we can still use larger operations.
1 parent 7c58476 commit 59a0fc2

File tree

2 files changed

+979
-1830
lines changed

2 files changed

+979
-1830
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 135 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -870,12 +870,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
870870
setTargetDAGCombine(ISD::SETCC);
871871

872872
// Vector reduction operations. These are transformed into a tree evaluation
873-
// of nodes which may or may not be legal.
873+
// of nodes which may initially be illegal.
874874
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
875-
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
876-
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
877-
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
878-
VT, Custom);
875+
MVT EltVT = VT.getVectorElementType();
876+
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
877+
EltVT == MVT::f64) {
878+
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
879+
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
880+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
881+
VT, Custom);
882+
} else if (EltVT.isScalarInteger()) {
883+
setOperationAction(
884+
{ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
885+
ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
886+
ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
887+
VT, Custom);
888+
}
879889
}
880890

881891
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2213,29 +2223,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
22132223
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
22142224
}
22152225

2216-
/// A generic routine for constructing a tree reduction for a vector operand.
2226+
/// A generic routine for constructing a tree reduction on a vector operand.
22172227
/// This method differs from iterative splitting in DAGTypeLegalizer by
2218-
/// first scalarizing the vector and then progressively grouping elements
2219-
/// bottom-up. This allows easily building the optimal (minimum) number of nodes
2220-
/// with different numbers of operands (eg. max3 vs max2).
2228+
/// progressively grouping elements bottom-up.
22212229
static SDValue BuildTreeReduction(
2222-
const SDValue &VectorOp,
2230+
const SmallVector<SDValue> &Elements, EVT EltTy,
22232231
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
22242232
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2225-
EVT VectorTy = VectorOp.getValueType();
2226-
EVT EltTy = VectorTy.getVectorElementType();
2227-
const unsigned NumElts = VectorTy.getVectorNumElements();
2228-
2229-
// scalarize vector
2230-
SmallVector<SDValue> Elements(NumElts);
2231-
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2232-
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2233-
DAG.getConstant(I, DL, MVT::i64));
2234-
}
2235-
22362233
// now build the computation graph in place at each level
22372234
SmallVector<SDValue> Level = Elements;
2238-
for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
2235+
unsigned OpIdx = 0;
2236+
while (Level.size() > 1) {
22392237
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
22402238

22412239
// partially reduce all elements in level
@@ -2267,52 +2265,139 @@ static SDValue BuildTreeReduction(
22672265
return *Level.begin();
22682266
}
22692267

2270-
/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2271-
/// serializes it.
2268+
/// Lower reductions to either a sequence of operations or a tree if
2269+
/// reassociations are allowed. This method will use larger operations like
2270+
/// max3/min3 when the target supports them.
22722271
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22732272
SelectionDAG &DAG) const {
2274-
// If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2275-
if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
2273+
if (DisableFOpTreeReduce)
22762274
return SDValue();
22772275

2278-
EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
2276+
SDLoc DL(Op);
2277+
const SDNodeFlags Flags = Op->getFlags();
2278+
const SDValue &Vector = Op.getOperand(0);
2279+
EVT EltTy = Vector.getValueType().getVectorElementType();
22792280
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22802281
STI.getPTXVersion() >= 88;
2281-
SDLoc DL(Op);
2282-
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
2282+
2283+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2284+
// number of inputs they take.
2285+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2286+
bool IsReassociatable;
2287+
22832288
switch (Op->getOpcode()) {
22842289
case ISD::VECREDUCE_FADD:
2285-
Operators = {{ISD::FADD, 2}};
2290+
ScalarOps = {{ISD::FADD, 2}};
2291+
IsReassociatable = false;
22862292
break;
22872293
case ISD::VECREDUCE_FMUL:
2288-
Operators = {{ISD::FMUL, 2}};
2294+
ScalarOps = {{ISD::FMUL, 2}};
2295+
IsReassociatable = false;
22892296
break;
22902297
case ISD::VECREDUCE_FMAX:
22912298
if (CanUseMinMax3)
2292-
Operators.push_back({NVPTXISD::FMAXNUM3, 3});
2293-
Operators.push_back({ISD::FMAXNUM, 2});
2299+
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
2300+
ScalarOps.push_back({ISD::FMAXNUM, 2});
2301+
IsReassociatable = false;
22942302
break;
22952303
case ISD::VECREDUCE_FMIN:
22962304
if (CanUseMinMax3)
2297-
Operators.push_back({NVPTXISD::FMINNUM3, 3});
2298-
Operators.push_back({ISD::FMINNUM, 2});
2305+
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
2306+
ScalarOps.push_back({ISD::FMINNUM, 2});
2307+
IsReassociatable = false;
22992308
break;
23002309
case ISD::VECREDUCE_FMAXIMUM:
23012310
if (CanUseMinMax3)
2302-
Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
2303-
Operators.push_back({ISD::FMAXIMUM, 2});
2311+
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
2312+
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2313+
IsReassociatable = false;
23042314
break;
23052315
case ISD::VECREDUCE_FMINIMUM:
23062316
if (CanUseMinMax3)
2307-
Operators.push_back({NVPTXISD::FMINIMUM3, 3});
2308-
Operators.push_back({ISD::FMINIMUM, 2});
2317+
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
2318+
ScalarOps.push_back({ISD::FMINIMUM, 2});
2319+
IsReassociatable = false;
2320+
break;
2321+
case ISD::VECREDUCE_ADD:
2322+
ScalarOps = {{ISD::ADD, 2}};
2323+
IsReassociatable = true;
2324+
break;
2325+
case ISD::VECREDUCE_MUL:
2326+
ScalarOps = {{ISD::MUL, 2}};
2327+
IsReassociatable = true;
2328+
break;
2329+
case ISD::VECREDUCE_UMAX:
2330+
ScalarOps = {{ISD::UMAX, 2}};
2331+
IsReassociatable = true;
2332+
break;
2333+
case ISD::VECREDUCE_UMIN:
2334+
ScalarOps = {{ISD::UMIN, 2}};
2335+
IsReassociatable = true;
2336+
break;
2337+
case ISD::VECREDUCE_SMAX:
2338+
ScalarOps = {{ISD::SMAX, 2}};
2339+
IsReassociatable = true;
2340+
break;
2341+
case ISD::VECREDUCE_SMIN:
2342+
ScalarOps = {{ISD::SMIN, 2}};
2343+
IsReassociatable = true;
2344+
break;
2345+
case ISD::VECREDUCE_AND:
2346+
ScalarOps = {{ISD::AND, 2}};
2347+
IsReassociatable = true;
2348+
break;
2349+
case ISD::VECREDUCE_OR:
2350+
ScalarOps = {{ISD::OR, 2}};
2351+
IsReassociatable = true;
2352+
break;
2353+
case ISD::VECREDUCE_XOR:
2354+
ScalarOps = {{ISD::XOR, 2}};
2355+
IsReassociatable = true;
23092356
break;
23102357
default:
23112358
llvm_unreachable("unhandled vecreduce operation");
23122359
}
23132360

2314-
return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
2315-
DAG);
2361+
EVT VectorTy = Vector.getValueType();
2362+
const unsigned NumElts = VectorTy.getVectorNumElements();
2363+
2364+
// scalarize vector
2365+
SmallVector<SDValue> Elements(NumElts);
2366+
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2367+
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2368+
DAG.getConstant(I, DL, MVT::i64));
2369+
}
2370+
2371+
// Lower to tree reduction.
2372+
if (IsReassociatable || Flags.hasAllowReassociation())
2373+
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2374+
2375+
// Lower to sequential reduction.
2376+
SDValue Accumulator;
2377+
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2378+
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
2379+
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2380+
2381+
if (!Accumulator) {
2382+
if (I + DefaultGroupSize <= NumElts) {
2383+
Accumulator = DAG.getNode(
2384+
DefaultScalarOp, DL, EltTy,
2385+
ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
2386+
I += DefaultGroupSize;
2387+
}
2388+
}
2389+
2390+
if (Accumulator) {
2391+
for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
2392+
SmallVector<SDValue> Operands = {Accumulator};
2393+
for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
2394+
Operands.push_back(Elements[I + K]);
2395+
Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
2396+
}
2397+
}
2398+
}
2399+
2400+
return Accumulator;
23162401
}
23172402

23182403
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
@@ -3153,6 +3238,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
31533238
case ISD::VECREDUCE_FMIN:
31543239
case ISD::VECREDUCE_FMAXIMUM:
31553240
case ISD::VECREDUCE_FMINIMUM:
3241+
case ISD::VECREDUCE_ADD:
3242+
case ISD::VECREDUCE_MUL:
3243+
case ISD::VECREDUCE_UMAX:
3244+
case ISD::VECREDUCE_UMIN:
3245+
case ISD::VECREDUCE_SMAX:
3246+
case ISD::VECREDUCE_SMIN:
3247+
case ISD::VECREDUCE_AND:
3248+
case ISD::VECREDUCE_OR:
3249+
case ISD::VECREDUCE_XOR:
31563250
return LowerVECREDUCE(Op, DAG);
31573251
case ISD::STORE:
31583252
return LowerSTORE(Op, DAG);

0 commit comments

Comments
 (0)