Skip to content

Commit aed0f4a

Browse files
committed
[NVPTX] support rest of VECREDUCE intrinsics and other improvements
- Support all VECREDUCE intrinsics - Clean up FileCheck directives in lit test - Also handle sequential lowering in NVPTX backend, where we can still use larger operations.
1 parent 1c98e5d commit aed0f4a

File tree

2 files changed

+1718
-410
lines changed

2 files changed

+1718
-410
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 135 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -835,12 +835,22 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
835835
setTargetDAGCombine(ISD::SETCC);
836836

837837
// Vector reduction operations. These are transformed into a tree evaluation
838-
// of nodes which may or may not be legal.
838+
// of nodes which may initially be illegal.
839839
for (MVT VT : MVT::fixedlen_vector_valuetypes()) {
840-
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
841-
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
842-
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
843-
VT, Custom);
840+
MVT EltVT = VT.getVectorElementType();
841+
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
842+
EltVT == MVT::f64) {
843+
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
844+
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
845+
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
846+
VT, Custom);
847+
} else if (EltVT.isScalarInteger()) {
848+
setOperationAction(
849+
{ISD::VECREDUCE_ADD, ISD::VECREDUCE_MUL, ISD::VECREDUCE_AND,
850+
ISD::VECREDUCE_OR, ISD::VECREDUCE_XOR, ISD::VECREDUCE_SMAX,
851+
ISD::VECREDUCE_SMIN, ISD::VECREDUCE_UMAX, ISD::VECREDUCE_UMIN},
852+
VT, Custom);
853+
}
844854
}
845855

846856
// Promote fp16 arithmetic if fp16 hardware isn't available or the
@@ -2147,29 +2157,17 @@ NVPTXTargetLowering::LowerCONCAT_VECTORS(SDValue Op, SelectionDAG &DAG) const {
21472157
return DAG.getBuildVector(Node->getValueType(0), dl, Ops);
21482158
}
21492159

2150-
/// A generic routine for constructing a tree reduction for a vector operand.
2160+
/// A generic routine for constructing a tree reduction on a vector operand.
21512161
/// This method differs from iterative splitting in DAGTypeLegalizer by
2152-
/// first scalarizing the vector and then progressively grouping elements
2153-
/// bottom-up. This allows easily building the optimal (minimum) number of nodes
2154-
/// with different numbers of operands (eg. max3 vs max2).
2162+
/// progressively grouping elements bottom-up.
21552163
static SDValue BuildTreeReduction(
2156-
const SDValue &VectorOp,
2164+
const SmallVector<SDValue> &Elements, EVT EltTy,
21572165
ArrayRef<std::pair<unsigned /*NodeType*/, unsigned /*NumInputs*/>> Ops,
21582166
const SDLoc &DL, const SDNodeFlags Flags, SelectionDAG &DAG) {
2159-
EVT VectorTy = VectorOp.getValueType();
2160-
EVT EltTy = VectorTy.getVectorElementType();
2161-
const unsigned NumElts = VectorTy.getVectorNumElements();
2162-
2163-
// scalarize vector
2164-
SmallVector<SDValue> Elements(NumElts);
2165-
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2166-
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, VectorOp,
2167-
DAG.getConstant(I, DL, MVT::i64));
2168-
}
2169-
21702167
// now build the computation graph in place at each level
21712168
SmallVector<SDValue> Level = Elements;
2172-
for (unsigned OpIdx = 0; Level.size() > 1 && OpIdx < Ops.size();) {
2169+
unsigned OpIdx = 0;
2170+
while (Level.size() > 1) {
21732171
const auto [DefaultScalarOp, DefaultGroupSize] = Ops[OpIdx];
21742172

21752173
// partially reduce all elements in level
@@ -2201,52 +2199,139 @@ static SDValue BuildTreeReduction(
22012199
return *Level.begin();
22022200
}
22032201

2204-
/// Lower fadd/fmul vector reductions. Builds a computation graph (tree) and
2205-
/// serializes it.
2202+
/// Lower reductions to either a sequence of operations or a tree if
2203+
/// reassociations are allowed. This method will use larger operations like
2204+
/// max3/min3 when the target supports them.
22062205
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22072206
SelectionDAG &DAG) const {
2208-
// If we can't reorder sub-operations, let DAGTypeLegalizer lower this op.
2209-
if (DisableFOpTreeReduce || !Op->getFlags().hasAllowReassociation())
2207+
if (DisableFOpTreeReduce)
22102208
return SDValue();
22112209

2212-
EVT EltTy = Op.getOperand(0).getValueType().getVectorElementType();
2210+
SDLoc DL(Op);
2211+
const SDNodeFlags Flags = Op->getFlags();
2212+
const SDValue &Vector = Op.getOperand(0);
2213+
EVT EltTy = Vector.getValueType().getVectorElementType();
22132214
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22142215
STI.getPTXVersion() >= 88;
2215-
SDLoc DL(Op);
2216-
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> Operators;
2216+
2217+
// A list of SDNode opcodes with equivalent semantics, sorted descending by
2218+
// number of inputs they take.
2219+
SmallVector<std::pair<unsigned /*Op*/, unsigned /*NumIn*/>, 2> ScalarOps;
2220+
bool IsReassociatable;
2221+
22172222
switch (Op->getOpcode()) {
22182223
case ISD::VECREDUCE_FADD:
2219-
Operators = {{ISD::FADD, 2}};
2224+
ScalarOps = {{ISD::FADD, 2}};
2225+
IsReassociatable = false;
22202226
break;
22212227
case ISD::VECREDUCE_FMUL:
2222-
Operators = {{ISD::FMUL, 2}};
2228+
ScalarOps = {{ISD::FMUL, 2}};
2229+
IsReassociatable = false;
22232230
break;
22242231
case ISD::VECREDUCE_FMAX:
22252232
if (CanUseMinMax3)
2226-
Operators.push_back({NVPTXISD::FMAXNUM3, 3});
2227-
Operators.push_back({ISD::FMAXNUM, 2});
2233+
ScalarOps.push_back({NVPTXISD::FMAXNUM3, 3});
2234+
ScalarOps.push_back({ISD::FMAXNUM, 2});
2235+
IsReassociatable = false;
22282236
break;
22292237
case ISD::VECREDUCE_FMIN:
22302238
if (CanUseMinMax3)
2231-
Operators.push_back({NVPTXISD::FMINNUM3, 3});
2232-
Operators.push_back({ISD::FMINNUM, 2});
2239+
ScalarOps.push_back({NVPTXISD::FMINNUM3, 3});
2240+
ScalarOps.push_back({ISD::FMINNUM, 2});
2241+
IsReassociatable = false;
22332242
break;
22342243
case ISD::VECREDUCE_FMAXIMUM:
22352244
if (CanUseMinMax3)
2236-
Operators.push_back({NVPTXISD::FMAXIMUM3, 3});
2237-
Operators.push_back({ISD::FMAXIMUM, 2});
2245+
ScalarOps.push_back({NVPTXISD::FMAXIMUM3, 3});
2246+
ScalarOps.push_back({ISD::FMAXIMUM, 2});
2247+
IsReassociatable = false;
22382248
break;
22392249
case ISD::VECREDUCE_FMINIMUM:
22402250
if (CanUseMinMax3)
2241-
Operators.push_back({NVPTXISD::FMINIMUM3, 3});
2242-
Operators.push_back({ISD::FMINIMUM, 2});
2251+
ScalarOps.push_back({NVPTXISD::FMINIMUM3, 3});
2252+
ScalarOps.push_back({ISD::FMINIMUM, 2});
2253+
IsReassociatable = false;
2254+
break;
2255+
case ISD::VECREDUCE_ADD:
2256+
ScalarOps = {{ISD::ADD, 2}};
2257+
IsReassociatable = true;
2258+
break;
2259+
case ISD::VECREDUCE_MUL:
2260+
ScalarOps = {{ISD::MUL, 2}};
2261+
IsReassociatable = true;
2262+
break;
2263+
case ISD::VECREDUCE_UMAX:
2264+
ScalarOps = {{ISD::UMAX, 2}};
2265+
IsReassociatable = true;
2266+
break;
2267+
case ISD::VECREDUCE_UMIN:
2268+
ScalarOps = {{ISD::UMIN, 2}};
2269+
IsReassociatable = true;
2270+
break;
2271+
case ISD::VECREDUCE_SMAX:
2272+
ScalarOps = {{ISD::SMAX, 2}};
2273+
IsReassociatable = true;
2274+
break;
2275+
case ISD::VECREDUCE_SMIN:
2276+
ScalarOps = {{ISD::SMIN, 2}};
2277+
IsReassociatable = true;
2278+
break;
2279+
case ISD::VECREDUCE_AND:
2280+
ScalarOps = {{ISD::AND, 2}};
2281+
IsReassociatable = true;
2282+
break;
2283+
case ISD::VECREDUCE_OR:
2284+
ScalarOps = {{ISD::OR, 2}};
2285+
IsReassociatable = true;
2286+
break;
2287+
case ISD::VECREDUCE_XOR:
2288+
ScalarOps = {{ISD::XOR, 2}};
2289+
IsReassociatable = true;
22432290
break;
22442291
default:
22452292
llvm_unreachable("unhandled vecreduce operation");
22462293
}
22472294

2248-
return BuildTreeReduction(Op.getOperand(0), Operators, DL, Op->getFlags(),
2249-
DAG);
2295+
EVT VectorTy = Vector.getValueType();
2296+
const unsigned NumElts = VectorTy.getVectorNumElements();
2297+
2298+
// scalarize vector
2299+
SmallVector<SDValue> Elements(NumElts);
2300+
for (unsigned I = 0, E = NumElts; I != E; ++I) {
2301+
Elements[I] = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltTy, Vector,
2302+
DAG.getConstant(I, DL, MVT::i64));
2303+
}
2304+
2305+
// Lower to tree reduction.
2306+
if (IsReassociatable || Flags.hasAllowReassociation())
2307+
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2308+
2309+
// Lower to sequential reduction.
2310+
SDValue Accumulator;
2311+
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
2312+
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
2313+
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
2314+
2315+
if (!Accumulator) {
2316+
if (I + DefaultGroupSize <= NumElts) {
2317+
Accumulator = DAG.getNode(
2318+
DefaultScalarOp, DL, EltTy,
2319+
ArrayRef(Elements).slice(I, I + DefaultGroupSize), Flags);
2320+
I += DefaultGroupSize;
2321+
}
2322+
}
2323+
2324+
if (Accumulator) {
2325+
for (; I + (DefaultGroupSize - 1) <= NumElts; I += DefaultGroupSize - 1) {
2326+
SmallVector<SDValue> Operands = {Accumulator};
2327+
for (unsigned K = 0; K < DefaultGroupSize - 1; ++K)
2328+
Operands.push_back(Elements[I + K]);
2329+
Accumulator = DAG.getNode(DefaultScalarOp, DL, EltTy, Operands, Flags);
2330+
}
2331+
}
2332+
}
2333+
2334+
return Accumulator;
22502335
}
22512336

22522337
SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
@@ -3032,6 +3117,15 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30323117
case ISD::VECREDUCE_FMIN:
30333118
case ISD::VECREDUCE_FMAXIMUM:
30343119
case ISD::VECREDUCE_FMINIMUM:
3120+
case ISD::VECREDUCE_ADD:
3121+
case ISD::VECREDUCE_MUL:
3122+
case ISD::VECREDUCE_UMAX:
3123+
case ISD::VECREDUCE_UMIN:
3124+
case ISD::VECREDUCE_SMAX:
3125+
case ISD::VECREDUCE_SMIN:
3126+
case ISD::VECREDUCE_AND:
3127+
case ISD::VECREDUCE_OR:
3128+
case ISD::VECREDUCE_XOR:
30353129
return LowerVECREDUCE(Op, DAG);
30363130
case ISD::STORE:
30373131
return LowerSTORE(Op, DAG);

0 commit comments

Comments
 (0)