
Commit f8d09af

[NVPTX] support VECREDUCE_SEQ ops and remove option
1 parent 6ac46e5 commit f8d09af

File tree

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp
llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll

2 files changed: +44 -36 lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 20 additions & 12 deletions

@@ -85,12 +85,6 @@ static cl::opt<unsigned> FMAContractLevelOpt(
              " 1: do it 2: do it aggressively"),
     cl::init(2));
 
-static cl::opt<bool> DisableFOpTreeReduce(
-    "nvptx-disable-fop-tree-reduce", cl::Hidden,
-    cl::desc("NVPTX Specific: don't emit tree reduction for floating-point "
-             "reduction operations"),
-    cl::init(false));
-
 static cl::opt<int> UsePrecDivF32(
     "nvptx-prec-divf32", cl::Hidden,
     cl::desc("NVPTX Specifies: 0 use div.approx, 1 use div.full, 2 use"
@@ -847,6 +841,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
     if (EltVT == MVT::f16 || EltVT == MVT::bf16 || EltVT == MVT::f32 ||
         EltVT == MVT::f64) {
       setOperationAction({ISD::VECREDUCE_FADD, ISD::VECREDUCE_FMUL,
+                          ISD::VECREDUCE_SEQ_FADD, ISD::VECREDUCE_SEQ_FMUL,
                           ISD::VECREDUCE_FMAX, ISD::VECREDUCE_FMIN,
                           ISD::VECREDUCE_FMAXIMUM, ISD::VECREDUCE_FMINIMUM},
                          VT, Custom);
@@ -2223,12 +2218,19 @@ static SDValue BuildTreeReduction(
 /// max3/min3 when the target supports them.
 SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
                                             SelectionDAG &DAG) const {
-  if (DisableFOpTreeReduce)
-    return SDValue();
-
   SDLoc DL(Op);
   const SDNodeFlags Flags = Op->getFlags();
-  const SDValue &Vector = Op.getOperand(0);
+  SDValue Vector;
+  SDValue Accumulator;
+  if (Op->getOpcode() == ISD::VECREDUCE_SEQ_FADD ||
+      Op->getOpcode() == ISD::VECREDUCE_SEQ_FMUL) {
+    // special case with accumulator as first arg
+    Accumulator = Op.getOperand(0);
+    Vector = Op.getOperand(1);
+  } else {
+    // default case
+    Vector = Op.getOperand(0);
+  }
   EVT EltTy = Vector.getValueType().getVectorElementType();
   const bool CanUseMinMax3 = EltTy == MVT::f32 && STI.getSmVersion() >= 100 &&
                              STI.getPTXVersion() >= 88;
@@ -2240,10 +2242,12 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
 
   switch (Op->getOpcode()) {
   case ISD::VECREDUCE_FADD:
+  case ISD::VECREDUCE_SEQ_FADD:
     ScalarOps = {{ISD::FADD, 2}};
     IsReassociatable = false;
     break;
   case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FMUL:
     ScalarOps = {{ISD::FMUL, 2}};
     IsReassociatable = false;
     break;
@@ -2322,11 +2326,13 @@ SDValue NVPTXTargetLowering::LowerVECREDUCE(SDValue Op,
   }
 
   // Lower to tree reduction.
-  if (IsReassociatable || Flags.hasAllowReassociation())
+  if (IsReassociatable || Flags.hasAllowReassociation()) {
+    // we don't expect an accumulator for reassociatable vector reduction ops
+    assert(!Accumulator && "unexpected accumulator");
     return BuildTreeReduction(Elements, EltTy, ScalarOps, DL, Flags, DAG);
+  }
 
   // Lower to sequential reduction.
-  SDValue Accumulator;
   for (unsigned OpIdx = 0, I = 0; I < NumElts; ++OpIdx) {
     assert(OpIdx < ScalarOps.size() && "no smaller operators for reduction");
     const auto [DefaultScalarOp, DefaultGroupSize] = ScalarOps[OpIdx];
@@ -3143,6 +3149,8 @@ NVPTXTargetLowering::LowerOperation(SDValue Op, SelectionDAG &DAG) const {
     return LowerCONCAT_VECTORS(Op, DAG);
   case ISD::VECREDUCE_FADD:
   case ISD::VECREDUCE_FMUL:
+  case ISD::VECREDUCE_SEQ_FADD:
+  case ISD::VECREDUCE_SEQ_FMUL:
   case ISD::VECREDUCE_FMAX:
   case ISD::VECREDUCE_FMIN:
   case ISD::VECREDUCE_FMAXIMUM:
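
Note on the new nodes: ISD::VECREDUCE_SEQ_FADD and ISD::VECREDUCE_SEQ_FMUL carry the start (accumulator) value as operand 0 and the vector as operand 1, which is why LowerVECREDUCE now peels off an Accumulator before collecting the vector elements, and why the tree-reduction path asserts that no accumulator is present. A minimal LLVM IR sketch of the two flavours involved (function names are illustrative, not part of this patch):

declare float @llvm.vector.reduce.fadd.v4f32(float, <4 x float>)

; Without 'reassoc' the ordered intrinsic keeps its start value and is
; normally built as a VECREDUCE_SEQ_FADD node, so it takes the sequential
; lowering path above.
define float @reduce_fadd_ordered(<4 x float> %v) {
  %r = call float @llvm.vector.reduce.fadd.v4f32(float 0.0, <4 x float> %v)
  ret float %r
}

; With 'reassoc' the reduction is free to reassociate, so the tree-reduction
; path in LowerVECREDUCE can be used instead.
define float @reduce_fadd_reassoc(<4 x float> %v) {
  %r = call reassoc float @llvm.vector.reduce.fadd.v4f32(float 0.0, <4 x float> %v)
  ret float %r
}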

llvm/test/CodeGen/NVPTX/reduction-intrinsics.ll

Lines changed: 24 additions & 24 deletions

@@ -23,19 +23,19 @@ define half @reduce_fadd_half(<8 x half> %in) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fadd_half_param_0];
-; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
-; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r2;
-; CHECK-NEXT:    mov.b32 {%rs7, %rs8}, %r1;
-; CHECK-NEXT:    mov.b16 %rs9, 0x0000;
-; CHECK-NEXT:    add.rn.f16 %rs10, %rs7, %rs9;
-; CHECK-NEXT:    add.rn.f16 %rs11, %rs10, %rs8;
-; CHECK-NEXT:    add.rn.f16 %rs12, %rs11, %rs5;
-; CHECK-NEXT:    add.rn.f16 %rs13, %rs12, %rs6;
-; CHECK-NEXT:    add.rn.f16 %rs14, %rs13, %rs3;
-; CHECK-NEXT:    add.rn.f16 %rs15, %rs14, %rs4;
-; CHECK-NEXT:    add.rn.f16 %rs16, %rs15, %rs1;
-; CHECK-NEXT:    add.rn.f16 %rs17, %rs16, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r1;
+; CHECK-NEXT:    mov.b16 %rs3, 0x0000;
+; CHECK-NEXT:    add.rn.f16 %rs4, %rs1, %rs3;
+; CHECK-NEXT:    add.rn.f16 %rs5, %rs4, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs6, %rs7}, %r2;
+; CHECK-NEXT:    add.rn.f16 %rs8, %rs5, %rs6;
+; CHECK-NEXT:    add.rn.f16 %rs9, %rs8, %rs7;
+; CHECK-NEXT:    mov.b32 {%rs10, %rs11}, %r3;
+; CHECK-NEXT:    add.rn.f16 %rs12, %rs9, %rs10;
+; CHECK-NEXT:    add.rn.f16 %rs13, %rs12, %rs11;
+; CHECK-NEXT:    mov.b32 {%rs14, %rs15}, %r4;
+; CHECK-NEXT:    add.rn.f16 %rs16, %rs13, %rs14;
+; CHECK-NEXT:    add.rn.f16 %rs17, %rs16, %rs15;
 ; CHECK-NEXT:    st.param.b16 [func_retval0], %rs17;
 ; CHECK-NEXT:    ret;
   %res = call half @llvm.vector.reduce.fadd(half 0.0, <8 x half> %in)
@@ -174,17 +174,17 @@ define half @reduce_fmul_half(<8 x half> %in) {
 ; CHECK-EMPTY:
 ; CHECK-NEXT:  // %bb.0:
 ; CHECK-NEXT:    ld.param.v4.u32 {%r1, %r2, %r3, %r4}, [reduce_fmul_half_param_0];
-; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r4;
-; CHECK-NEXT:    mov.b32 {%rs3, %rs4}, %r3;
-; CHECK-NEXT:    mov.b32 {%rs5, %rs6}, %r2;
-; CHECK-NEXT:    mov.b32 {%rs7, %rs8}, %r1;
-; CHECK-NEXT:    mul.rn.f16 %rs9, %rs7, %rs8;
-; CHECK-NEXT:    mul.rn.f16 %rs10, %rs9, %rs5;
-; CHECK-NEXT:    mul.rn.f16 %rs11, %rs10, %rs6;
-; CHECK-NEXT:    mul.rn.f16 %rs12, %rs11, %rs3;
-; CHECK-NEXT:    mul.rn.f16 %rs13, %rs12, %rs4;
-; CHECK-NEXT:    mul.rn.f16 %rs14, %rs13, %rs1;
-; CHECK-NEXT:    mul.rn.f16 %rs15, %rs14, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs1, %rs2}, %r2;
+; CHECK-NEXT:    mov.b32 {%rs3, %rs4}, %r1;
+; CHECK-NEXT:    mul.rn.f16 %rs5, %rs3, %rs4;
+; CHECK-NEXT:    mul.rn.f16 %rs6, %rs5, %rs1;
+; CHECK-NEXT:    mul.rn.f16 %rs7, %rs6, %rs2;
+; CHECK-NEXT:    mov.b32 {%rs8, %rs9}, %r3;
+; CHECK-NEXT:    mul.rn.f16 %rs10, %rs7, %rs8;
+; CHECK-NEXT:    mul.rn.f16 %rs11, %rs10, %rs9;
+; CHECK-NEXT:    mov.b32 {%rs12, %rs13}, %r4;
+; CHECK-NEXT:    mul.rn.f16 %rs14, %rs11, %rs12;
+; CHECK-NEXT:    mul.rn.f16 %rs15, %rs14, %rs13;
 ; CHECK-NEXT:    st.param.b16 [func_retval0], %rs15;
 ; CHECK-NEXT:    ret;
   %res = call half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %in)
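
The updated CHECK lines reflect the strictly ordered chain (((acc op e0) op e1) ... op e7): the lanes of each packed 32-bit register are now unpacked and consumed in order, interleaved with the adds/muls, instead of all being unpacked up front. For reference, a hand-expanded IR sketch of the ordered semantics the fmul test exercises (illustrative only, not compiler output):

; Ordered (no 'reassoc') semantics of
;   call half @llvm.vector.reduce.fmul(half 1.0, <8 x half> %in)
; written out as a single left-to-right chain starting from the accumulator.
define half @reduce_fmul_half_expanded(<8 x half> %in) {
  %e0 = extractelement <8 x half> %in, i32 0
  %e1 = extractelement <8 x half> %in, i32 1
  %e2 = extractelement <8 x half> %in, i32 2
  %e3 = extractelement <8 x half> %in, i32 3
  %e4 = extractelement <8 x half> %in, i32 4
  %e5 = extractelement <8 x half> %in, i32 5
  %e6 = extractelement <8 x half> %in, i32 6
  %e7 = extractelement <8 x half> %in, i32 7
  %a0 = fmul half 1.0, %e0
  %a1 = fmul half %a0, %e1
  %a2 = fmul half %a1, %e2
  %a3 = fmul half %a2, %e3
  %a4 = fmul half %a3, %e4
  %a5 = fmul half %a4, %e5
  %a6 = fmul half %a5, %e6
  %a7 = fmul half %a6, %e7
  ret half %a7
}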
