Skip to content

Commit 1f1348c

Browse files
committed
[NVPTX] support VECREDUCE_SEQ ops and remove option
1 parent 7f5440b commit 1f1348c

File tree

2 files changed

+44
-36
lines changed

2 files changed

+44
-36
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -85,12 +85,6 @@ static cl::opt<unsigned> FMAContractLevelOpt(
8585
" 1: do it 2: do it aggressively"),
8686
cl::init(2));
8787

88-
static cl::opt<bool> DisableFOpTreeReduce(
89-
"nvptx-disable-fop-tree-reduce", cl::Hidden,
90-
cl::desc("NVPTX Specific: don't emit tree reduction for floating-point "
91-
"reduction operations"),
92-
cl::init(false));
93-
9488
static cl::opt<int> UsePrecDivF32(
9589
"nvptx-prec-divf32", cl::Hidden,
9690
cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -844,6 +838,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
844838
if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
845839
EltVT == MVT::f64) {
846840
setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
841+
ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
847842
ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
848843
ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
849844
VT, Custom);
@@ -2212,12 +2207,19 @@ static SDValue BuildTreeReduction(
22122207
/// max3/min3 when the target supports them.
22132208
SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22142209
SelectionDAG &DAG) const {
2215-
if (DisableFOpTreeReduce)
2216-
return SDValue();
2217-
22182210
SDLoc DL(Op);
22192211
const SDNodeFlags Flags = Op->getFlags();
2220-
const SDValue &Vector = Op.getOperand(0);
2212+
SDValue Vector;
2213+
SDValue Accumulator;
2214+
if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
2215+
Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
2216+
// special case with accumulator as first arg
2217+
Accumulator = Op.getOperand(0);
2218+
Vector = Op.getOperand(1);
2219+
} else {
2220+
// default case
2221+
Vector = Op.getOperand(0);
2222+
}
22212223
EVT EltTy = Vector.getValueType().getVectorElementType();
22222224
const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
22232225
STI.getPTXVersion() >= 88;
@@ -2229,10 +2231,12 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
22292231

22302232
switch (Op->getOpcode()) {
22312233
case ISD::VECREDUCE_FADD:
2234+
case ISD::VECREDUCE_SEQ_FADD:
22322235
ScalarOps = {{ISD::FADD, 2}};
22332236
IsReassociatable = false;
22342237
break;
22352238
case ISD::VECREDUCE_FMUL:
2239+
case ISD::VECREDUCE_SEQ_FMUL:
22362240
ScalarOps = {{ISD::FMUL, 2}};
22372241
IsReassociatable = false;
22382242
break;
@@ -2311,11 +2315,13 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
23112315
}
23122316

23132317
// Lower to tree reduction.
2314-
if (IsReassociatable || Flags.hasAllowReassociation())
2318+
if (IsReassociatable || Flags.hasAllowReassociation()) {
2319+
// we don't expect an accumulator for reassociatable vector reduction ops
2320+
assert(!Accumulator && "unexpected accumulator");
23152321
return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
2322+
}
23162323

23172324
// Lower to sequential reduction.
2318-
SDValue Accumulator;
23192325
for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
23202326
assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
23212327
const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
@@ -3087,6 +3093,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
30873093
return LowerCONCAT_VECTORS(Op, DAG);
30883094
case ISD::VECREDUCE_FADD:
30893095
case ISD::VECREDUCE_FMUL:
3096+
case ISD::VECREDUCE_SEQ_FADD:
3097+
case ISD::VECREDUCE_SEQ_FMUL:
30903098
case ISD::VECREDUCE_FMAX:
30913099
case ISD::VECREDUCE_FMIN:
30923100
case ISD::VECREDUCE_FMAXIMUM:

llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,19 @@ define half @reduce_fadd_half(<8 x half> %in) {
2323
; CHECK-EMPTY:
2424
; CHECK-NEXT: // %bb.0:
2525
; CHECK-NEXT: ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_param_0];
26-
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
27-
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r3;
28-
; CHECK-NEXT: mov.b32 {%rs5, %rs6}, %r2;
29-
; CHECK-NEXT: mov.b32 {%rs7, %rs8}, %r1;
30-
; CHECK-NEXT: mov.b16 %rs9, 0x0000;
31-
; CHECK-NEXT: add.rn.f16 %rs10, %rs7, %rs9;
32-
; CHECK-NEXT: add.rn.f16 %rs11, %rs10, %rs8;
33-
; CHECK-NEXT: add.rn.f16 %rs12, %rs11, %rs5;
34-
; CHECK-NEXT: add.rn.f16 %rs13, %rs12, %rs6;
35-
; CHECK-NEXT: add.rn.f16 %rs14, %rs13, %rs3;
36-
; CHECK-NEXT: add.rn.f16 %rs15, %rs14, %rs4;
37-
; CHECK-NEXT: add.rn.f16 %rs16, %rs15, %rs1;
38-
; CHECK-NEXT: add.rn.f16 %rs17, %rs16, %rs2;
26+
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r1;
27+
; CHECK-NEXT: mov.b16 %rs3, 0x0000;
28+
; CHECK-NEXT: add.rn.f16 %rs4, %rs1, %rs3;
29+
; CHECK-NEXT: add.rn.f16 %rs5, %rs4, %rs2;
30+
; CHECK-NEXT: mov.b32 {%rs6, %rs7}, %r2;
31+
; CHECK-NEXT: add.rn.f16 %rs8, %rs5, %rs6;
32+
; CHECK-NEXT: add.rn.f16 %rs9, %rs8, %rs7;
33+
; CHECK-NEXT: mov.b32 {%rs10, %rs11}, %r3;
34+
; CHECK-NEXT: add.rn.f16 %rs12, %rs9, %rs10;
35+
; CHECK-NEXT: add.rn.f16 %rs13, %rs12, %rs11;
36+
; CHECK-NEXT: mov.b32 {%rs14, %rs15}, %r4;
37+
; CHECK-NEXT: add.rn.f16 %rs16, %rs13, %rs14;
38+
; CHECK-NEXT: add.rn.f16 %rs17, %rs16, %rs15;
3939
; CHECK-NEXT: st.param.b16 [func_retval0], %rs17;
4040
; CHECK-NEXT: ret;
4141
%res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
@@ -174,17 +174,17 @@ define half @reduce_fmul_half(<8 x half> %in) {
174174
; CHECK-EMPTY:
175175
; CHECK-NEXT: // %bb.0:
176176
; CHECK-NEXT: ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fmul_half_param_0];
177-
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r4;
178-
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r3;
179-
; CHECK-NEXT: mov.b32 {%rs5, %rs6}, %r2;
180-
; CHECK-NEXT: mov.b32 {%rs7, %rs8}, %r1;
181-
; CHECK-NEXT: mul.rn.f16 %rs9, %rs7, %rs8;
182-
; CHECK-NEXT: mul.rn.f16 %rs10, %rs9, %rs5;
183-
; CHECK-NEXT: mul.rn.f16 %rs11, %rs10, %rs6;
184-
; CHECK-NEXT: mul.rn.f16 %rs12, %rs11, %rs3;
185-
; CHECK-NEXT: mul.rn.f16 %rs13, %rs12, %rs4;
186-
; CHECK-NEXT: mul.rn.f16 %rs14, %rs13, %rs1;
187-
; CHECK-NEXT: mul.rn.f16 %rs15, %rs14, %rs2;
177+
; CHECK-NEXT: mov.b32 {%rs1, %rs2}, %r2;
178+
; CHECK-NEXT: mov.b32 {%rs3, %rs4}, %r1;
179+
; CHECK-NEXT: mul.rn.f16 %rs5, %rs3, %rs4;
180+
; CHECK-NEXT: mul.rn.f16 %rs6, %rs5, %rs1;
181+
; CHECK-NEXT: mul.rn.f16 %rs7, %rs6, %rs2;
182+
; CHECK-NEXT: mov.b32 {%rs8, %rs9}, %r3;
183+
; CHECK-NEXT: mul.rn.f16 %rs10, %rs7, %rs8;
184+
; CHECK-NEXT: mul.rn.f16 %rs11, %rs10, %rs9;
185+
; CHECK-NEXT: mov.b32 {%rs12, %rs13}, %r4;
186+
; CHECK-NEXT: mul.rn.f16 %rs14, %rs11, %rs12;
187+
; CHECK-NEXT: mul.rn.f16 %rs15, %rs14, %rs13;
188188
; CHECK-NEXT: st.param.b16 [func_retval0], %rs15;
189189
; CHECK-NEXT: ret;
190190
%res = call half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %in)

0 commit comments

Comments
 (0)