Commit c51df14

[NVPTX] support VECREDUCE_SEQ ops and remove option
1 parent ec897d0 commit c51df14
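
ISD::VECREDUCE_SEQ_FADD and ISD::VECREDUCE_SEQ_FMUL are the strict, in-order variants of the floating-point reduction nodes: they carry the start (accumulator) value as operand 0 and the vector as operand 1, and the lanes must be combined left to right. This commit teaches NVPTXTargetLowering::LowerVECREDUCE to handle them directly and drops the nvptx-disable-fop-tree-reduce escape hatch at the same time. As an illustrative sketch (not taken from the patch; the function name is made up, and the intrinsic-to-node mapping described in the comments is the usual SelectionDAG behaviour rather than something this commit introduces), an ordered reduce intrinsic without the reassoc fast-math flag is what typically reaches the backend as one of these SEQ nodes:

    ; Illustrative only; not part of this commit.
    define half @ordered_fadd(<8 x half> %in) {
      ; No 'reassoc' flag: the reduction is sequential, with 0.0 as the start
      ; value, and is typically selected as VECREDUCE_SEQ_FADD(0.0, %in).
      %r = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
      ret half %r
    }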

File tree

2 files changed (+44 -36 lines):

  llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
  llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 20 additions & 12 deletions
@@ -85,12 +85,6 @@ static cl::opt<unsigned> FMAContractLevelOpt(
              " 1: do it 2: do it aggressively"),
     cl::init(2));
 
-static cl::opt<bool> DisableFOpTreeReduce(
-    "nvptx-disable-fop-tree-reduce", cl::Hidden,
-    cl::desc("NVPTX Specific: don't emit tree reduction for floating-point "
-             "reduction operations"),
-    cl::init(false));
-
 static cl::opt<int> UsePrecDivF32(
     "nvptx-prec-divf32", cl::Hidden,
     cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -841,6 +835,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
         EltVT == MVT::f64) {
       setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
                           ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
                           ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
                          VT, Custom);
@@ -2217,12 +2212,19 @@ static SDValue BuildTreeReduction(
 /// max3/min3 when the target supports them.
 SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
                                             SelectionDAG &DAG) const {
-  if (DisableFOpTreeReduce)
-    return SDValue();
-
   SDLoc DL(Op);
   const SDNodeFlags Flags = Op->getFlags();
-  const SDValue &Vector = Op.getOperand(0);
+  SDValue Vector;
+  SDValue Accumulator;
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
   EVT EltTy = Vector.getValueType().getVectorElementType();
   const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
                              STI.getPTXVersion() >= 88;
@@ -2234,10 +2236,12 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
 
   switch (Op->getOpcode()) {
   case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
     ScalarOps = {{ISD::FADD, 2}};
     IsReassociatable = false;
     break;
   case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
     ScalarOps = {{ISD::FMUL, 2}};
     IsReassociatable = false;
     break;
@@ -2316,11 +2320,13 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
   }
 
   // Lower to tree reduction.
-  if (IsReassociatable || Flags.hasAllowReassociation())
+  if (IsReassociatable || Flags.hasAllowReassociation()) {
+    // we don't expect an accumulator for reassociatable vector reduction ops
+    assert(!Accumulator && "unexpected accumulator");
     return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+  }
 
   // Lower to sequential reduction.
-  SDValue Accumulator;
   for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
     assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
     const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
@@ -3137,6 +3143,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerCONCAT_VECTORS(Op, DAG);
   case ISD::VECREDUCE_FADD:
   case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
   case ISD::VECREDUCE_FMAX:
   case ISD::VECREDUCE_FMIN:
  case ISD::VECREDUCE_FMAXIMUM:
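
With the option gone, LowerVECREDUCE chooses between the two strategies purely from the opcode and the node's fast-math flags: reassociatable reductions (or reductions carrying the reassoc flag) still build a tree via BuildTreeReduction, while the SEQ opcodes now seed the existing sequential loop with their explicit accumulator operand instead of starting from an empty SDValue. A hedged IR-level illustration of the two flavours (these functions are not in the patch, and the stated node selection is the usual SelectionDAG mapping rather than part of this change):

    ; Illustrative sketch only; not part of this commit.
    define half @relaxed_fmul(<8 x half> %v) {
      ; With 'reassoc' the call is typically selected as VECREDUCE_FMUL (the
      ; start value is folded in with a separate fmul), so it takes the tree
      ; path above and the new accumulator assert holds trivially.
      %r = call reassoc half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %v)
      ret half %r
    }

    define half @ordered_fmul(<8 x half> %v) {
      ; Without 'reassoc' the call is typically selected as VECREDUCE_SEQ_FMUL
      ; and now lowers as a left-to-right chain seeded with the 1.0 accumulator.
      %r = call half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %v)
      ret half %r
    }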

llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll

Lines changed: 24 additions & 24 deletions
@@ -23,19 +23,19 @@ define half @reduce_fadd_half(<8 x half> %in) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_param_0];
-; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
-; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r2;
-; CHECK-NEXT:    mov.b32 {%rs7, %rs8}, %r1;
-; CHECK-NEXT:    mov.b16 %rs9, 0x0000;
-; CHECK-NEXT:    add.rn.f16 %rs10, %rs7, %rs9;
-; CHECK-NEXT:    add.rn.f16 %rs11, %rs10, %rs8;
-; CHECK-NEXT:    add.rn.f16 %rs12, %rs11, %rs5;
-; CHECK-NEXT:    add.rn.f16 %rs13, %rs12, %rs6;
-; CHECK-NEXT:    add.rn.f16 %rs14, %rs13, %rs3;
-; CHECK-NEXT:    add.rn.f16 %rs15, %rs14, %rs4;
-; CHECK-NEXT:    add.rn.f16 %rs16, %rs15, %rs1;
-; CHECK-NEXT:    add.rn.f16 %rs17, %rs16, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
+; CHECK-NEXT:    mov.b16 %rs3, 0x0000;
+; CHECK-NEXT:    add.rn.f16 %rs4, %rs1, %rs3;
+; CHECK-NEXT:    add.rn.f16 %rs5, %rs4, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs6, %rs7}, %r2;
+; CHECK-NEXT:    add.rn.f16 %rs8, %rs5, %rs6;
+; CHECK-NEXT:    add.rn.f16 %rs9, %rs8, %rs7;
+; CHECK-NEXT:    mov.b32 {%rs10, %rs11}, %r3;
+; CHECK-NEXT:    add.rn.f16 %rs12, %rs9, %rs10;
+; CHECK-NEXT:    add.rn.f16 %rs13, %rs12, %rs11;
+; CHECK-NEXT:    mov.b32 {%rs14, %rs15}, %r4;
+; CHECK-NEXT:    add.rn.f16 %rs16, %rs13, %rs14;
+; CHECK-NEXT:    add.rn.f16 %rs17, %rs16, %rs15;
 ; CHECK-NEXT:    st.param.b16 [func_retval0], %rs17;
 ; CHECK-NEXT:    ret;
   %res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
@@ -174,17 +174,17 @@ define half @reduce_fmul_half(<8 x half> %in) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fmul_half_param_0];
-; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
-; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r2;
-; CHECK-NEXT:    mov.b32 {%rs7, %rs8}, %r1;
-; CHECK-NEXT:    mul.rn.f16 %rs9, %rs7, %rs8;
-; CHECK-NEXT:    mul.rn.f16 %rs10, %rs9, %rs5;
-; CHECK-NEXT:    mul.rn.f16 %rs11, %rs10, %rs6;
-; CHECK-NEXT:    mul.rn.f16 %rs12, %rs11, %rs3;
-; CHECK-NEXT:    mul.rn.f16 %rs13, %rs12, %rs4;
-; CHECK-NEXT:    mul.rn.f16 %rs14, %rs13, %rs1;
-; CHECK-NEXT:    mul.rn.f16 %rs15, %rs14, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
+; CHECK-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
+; CHECK-NEXT:    mul.rn.f16 %rs5, %rs3, %rs4;
+; CHECK-NEXT:    mul.rn.f16 %rs6, %rs5, %rs1;
+; CHECK-NEXT:    mul.rn.f16 %rs7, %rs6, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs8, %rs9}, %r3;
+; CHECK-NEXT:    mul.rn.f16 %rs10, %rs7, %rs8;
+; CHECK-NEXT:    mul.rn.f16 %rs11, %rs10, %rs9;
+; CHECK-NEXT:    mov.b32 {%rs12, %rs13}, %r4;
+; CHECK-NEXT:    mul.rn.f16 %rs14, %rs11, %rs12;
+; CHECK-NEXT:    mul.rn.f16 %rs15, %rs14, %rs13;
 ; CHECK-NEXT:    st.param.b16 [func_retval0], %rs15;
 ; CHECK-NEXT:    ret;
   %res = call half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %in)
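
The updated CHECK lines keep the same left-to-right evaluation order; what changes is the instruction schedule, now that the SEQ nodes are custom-lowered by LowerVECREDUCE instead of being left to the generic expansion. The test's actual RUN lines are not shown in this diff; one plausible way to inspect the same codegen locally is a small .ll file along these lines (the RUN line, sm_80, and ptx70 are assumptions for illustration, not values taken from the test):

    ; Reproduction sketch; RUN line and target features are assumptions.
    ; RUN: llc < %s -mtriple=nvptx64-nvidia-cuda -mcpu=sm_80 -mattr=+ptx70 | FileCheck %s
    define half @reduce_fadd_half(<8 x half> %in) {
    ; CHECK-LABEL: reduce_fadd_half
    ; CHECK: add.rn.f16
      %res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
      ret half %res
    }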
