
Commit b7bf96a

[LegalizeTypes][VP] Add widening support for vp.reduce.*
When widening these intrinsics, we do not have to insert neutral elements at the end of the vector, as we do when widening vector.reduce.* intrinsics, thanks to vector predication semantics.

Reviewed By: craig.topper

Differential Revision: https://reviews.llvm.org/D117467
1 parent 0861fbe commit b7bf96a
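
For context (not part of the commit itself), here is a minimal sketch of the situation this patch handles, modeled on the tests added below; the function name is made up. On an RVV target, <3 x i8> is not a legal type, so type legalization widens the reduction to operate on <4 x i8>. Because a VP reduction only combines lanes that are both active in the mask and below %evl, and %evl may be at most 3 for a <3 x i8> source, the extra widened lane can never contribute, so no neutral element has to be inserted:

declare i8 @llvm.vp.reduce.umin.v3i8(i8, <3 x i8>, <3 x i1>, i32)

define i8 @widen_vp_reduce_example(i8 %start, <3 x i8> %v, <3 x i1> %m, i32 zeroext %evl) {
  ; After widening, the reduction sees a <4 x i8> operand and a <4 x i1> mask,
  ; but still the original %evl, so the fourth lane is always inactive.
  %r = call i8 @llvm.vp.reduce.umin.v3i8(i8 %start, <3 x i8> %v, <3 x i1> %m, i32 %evl)
  ret i8 %r
}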

File tree

8 files changed: +210 -51 lines


llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 18 additions & 0 deletions
@@ -905,6 +905,23 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   }
   void SetWidenedVector(SDValue Op, SDValue Result);
 
+  /// Given a mask Mask, returns the larger vector into which Mask was widened.
+  SDValue GetWidenedMask(SDValue Mask, ElementCount EC) {
+    // For VP operations, we must also widen the mask. Note that the mask type
+    // may not actually need widening, leading it be split along with the VP
+    // operation.
+    // FIXME: This could lead to an infinite split/widen loop. We only handle
+    // the case where the mask needs widening to an identically-sized type as
+    // the vector inputs.
+    assert(getTypeAction(Mask.getValueType()) ==
+               TargetLowering::TypeWidenVector &&
+           "Unable to widen binary VP op");
+    Mask = GetWidenedVector(Mask);
+    assert(Mask.getValueType().getVectorElementCount() == EC &&
+           "Unable to widen binary VP op");
+    return Mask;
+  }
+
   // Widen Vector Result Promotion.
   void WidenVectorResult(SDNode *N, unsigned ResNo);
   SDValue WidenVecRes_MERGE_VALUES(SDNode* N, unsigned ResNo);
@@ -964,6 +981,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
   SDValue WidenVecOp_FCOPYSIGN(SDNode *N);
   SDValue WidenVecOp_VECREDUCE(SDNode *N);
   SDValue WidenVecOp_VECREDUCE_SEQ(SDNode *N);
+  SDValue WidenVecOp_VP_REDUCE(SDNode *N);
 
   /// Helper function to generate a set of operations to perform
   /// a vector operation for a wider type.

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 32 additions & 14 deletions
@@ -3445,20 +3445,8 @@ SDValue DAGTypeLegalizer::WidenVecRes_Binary(SDNode *N) {
   assert(N->getNumOperands() == 4 && "Unexpected number of operands!");
   assert(N->isVPOpcode() && "Expected VP opcode");
 
-  // For VP operations, we must also widen the mask. Note that the mask type
-  // may not actually need widening, leading it be split along with the VP
-  // operation.
-  // FIXME: This could lead to an infinite split/widen loop. We only handle the
-  // case where the mask needs widening to an identically-sized type as the
-  // vector inputs.
-  SDValue Mask = N->getOperand(2);
-  assert(getTypeAction(Mask.getValueType()) ==
-             TargetLowering::TypeWidenVector &&
-         "Unable to widen binary VP op");
-  Mask = GetWidenedVector(Mask);
-  assert(Mask.getValueType().getVectorElementCount() ==
-             WidenVT.getVectorElementCount() &&
-         "Unable to widen binary VP op");
+  SDValue Mask =
+      GetWidenedMask(N->getOperand(2), WidenVT.getVectorElementCount());
   return DAG.getNode(N->getOpcode(), dl, WidenVT,
                      {InOp1, InOp2, Mask, N->getOperand(3)}, N->getFlags());
 }
@@ -4978,6 +4966,23 @@ bool DAGTypeLegalizer::WidenVectorOperand(SDNode *N, unsigned OpNo) {
   case ISD::VECREDUCE_SEQ_FMUL:
     Res = WidenVecOp_VECREDUCE_SEQ(N);
     break;
+  case ISD::VP_REDUCE_FADD:
+  case ISD::VP_REDUCE_SEQ_FADD:
+  case ISD::VP_REDUCE_FMUL:
+  case ISD::VP_REDUCE_SEQ_FMUL:
+  case ISD::VP_REDUCE_ADD:
+  case ISD::VP_REDUCE_MUL:
+  case ISD::VP_REDUCE_AND:
+  case ISD::VP_REDUCE_OR:
+  case ISD::VP_REDUCE_XOR:
+  case ISD::VP_REDUCE_SMAX:
+  case ISD::VP_REDUCE_SMIN:
+  case ISD::VP_REDUCE_UMAX:
+  case ISD::VP_REDUCE_UMIN:
+  case ISD::VP_REDUCE_FMAX:
+  case ISD::VP_REDUCE_FMIN:
+    Res = WidenVecOp_VP_REDUCE(N);
+    break;
   }
 
   // If Res is null, the sub-method took care of registering the result.
@@ -5571,6 +5576,19 @@ SDValue DAGTypeLegalizer::WidenVecOp_VECREDUCE_SEQ(SDNode *N) {
   return DAG.getNode(Opc, dl, N->getValueType(0), AccOp, Op, Flags);
 }
 
+SDValue DAGTypeLegalizer::WidenVecOp_VP_REDUCE(SDNode *N) {
+  assert(N->isVPOpcode() && "Expected VP opcode");
+
+  SDLoc dl(N);
+  SDValue Op = GetWidenedVector(N->getOperand(1));
+  SDValue Mask = GetWidenedMask(N->getOperand(2),
+                                Op.getValueType().getVectorElementCount());
+
+  return DAG.getNode(N->getOpcode(), dl, N->getValueType(0),
+                     {N->getOperand(0), Op, Mask, N->getOperand(3)},
+                     N->getFlags());
+}
+
 SDValue DAGTypeLegalizer::WidenVecOp_VSELECT(SDNode *N) {
   // This only gets called in the case that the left and right inputs and
   // result are of a legal odd vector type, and the condition is illegal i1 of

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-fp-vp.ll

Lines changed: 28 additions & 0 deletions
@@ -210,6 +210,34 @@ define double @vpreduce_ord_fadd_v2f64(double %s, <2 x double> %v, <2 x i1> %m,
   ret double %r
 }
 
+declare double @llvm.vp.reduce.fadd.v3f64(double, <3 x double>, <3 x i1>, i32)
+
+define double @vpreduce_fadd_v3f64(double %s, <3 x double> %v, <3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_fadd_v3f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, mu
+; CHECK-NEXT:    vfmv.s.f v10, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e64, m2, tu, mu
+; CHECK-NEXT:    vfredusum.vs v10, v8, v10, v0.t
+; CHECK-NEXT:    vfmv.f.s fa0, v10
+; CHECK-NEXT:    ret
+  %r = call reassoc double @llvm.vp.reduce.fadd.v3f64(double %s, <3 x double> %v, <3 x i1> %m, i32 %evl)
+  ret double %r
+}
+
+define double @vpreduce_ord_fadd_v3f64(double %s, <3 x double> %v, <3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_ord_fadd_v3f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, mu
+; CHECK-NEXT:    vfmv.s.f v10, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e64, m2, tu, mu
+; CHECK-NEXT:    vfredosum.vs v10, v8, v10, v0.t
+; CHECK-NEXT:    vfmv.f.s fa0, v10
+; CHECK-NEXT:    ret
+  %r = call double @llvm.vp.reduce.fadd.v3f64(double %s, <3 x double> %v, <3 x i1> %m, i32 %evl)
+  ret double %r
+}
+
 declare double @llvm.vp.reduce.fadd.v4f64(double, <4 x double>, <4 x i1>, i32)
 
 define double @vpreduce_fadd_v4f64(double %s, <4 x double> %v, <4 x i1> %m, i32 zeroext %evl) {

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-int-vp.ll

Lines changed: 20 additions & 4 deletions
@@ -126,6 +126,22 @@ define signext i8 @vpreduce_xor_v2i8(i8 signext %s, <2 x i8> %v, <2 x i1> %m, i3
   ret i8 %r
 }
 
+declare i8 @llvm.vp.reduce.umin.v3i8(i8, <3 x i8>, <3 x i1>, i32)
+
+define signext i8 @vpreduce_umin_v3i8(i8 signext %s, <3 x i8> %v, <3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_umin_v3i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    andi a0, a0, 255
+; CHECK-NEXT:    vsetivli zero, 1, e8, m1, ta, mu
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf4, tu, mu
+; CHECK-NEXT:    vredminu.vs v9, v8, v9, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v9
+; CHECK-NEXT:    ret
+  %r = call i8 @llvm.vp.reduce.umin.v3i8(i8 %s, <3 x i8> %v, <3 x i1> %m, i32 %evl)
+  ret i8 %r
+}
+
 declare i8 @llvm.vp.reduce.add.v4i8(i8, <4 x i8>, <4 x i1>, i32)
 
 define signext i8 @vpreduce_add_v4i8(i8 signext %s, <4 x i8> %v, <4 x i1> %m, i32 zeroext %evl) {
@@ -831,17 +847,17 @@ define signext i32 @vpreduce_xor_v64i32(i32 signext %s, <64 x i32> %v, <64 x i1>
 ; CHECK:       # %bb.0:
 ; CHECK-NEXT:    addi a3, a1, -32
 ; CHECK-NEXT:    li a2, 0
-; CHECK-NEXT:    bltu a1, a3, .LBB48_2
+; CHECK-NEXT:    bltu a1, a3, .LBB49_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a2, a3
-; CHECK-NEXT:  .LBB48_2:
+; CHECK-NEXT:  .LBB49_2:
 ; CHECK-NEXT:    vsetivli zero, 4, e8, mf2, ta, mu
 ; CHECK-NEXT:    li a3, 32
 ; CHECK-NEXT:    vslidedown.vi v24, v0, 4
-; CHECK-NEXT:    bltu a1, a3, .LBB48_4
+; CHECK-NEXT:    bltu a1, a3, .LBB49_4
 ; CHECK-NEXT:  # %bb.3:
 ; CHECK-NEXT:    li a1, 32
-; CHECK-NEXT:  .LBB48_4:
+; CHECK-NEXT:  .LBB49_4:
 ; CHECK-NEXT:    vsetivli zero, 1, e32, m1, ta, mu
 ; CHECK-NEXT:    vmv.s.x v25, a0
 ; CHECK-NEXT:    vsetvli zero, a1, e32, m8, tu, mu

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-reduction-mask-vp.ll

Lines changed: 21 additions & 4 deletions
@@ -212,6 +212,23 @@ define signext i1 @vpreduce_xor_v8i1(i1 signext %s, <8 x i1> %v, <8 x i1> %m, i3
   ret i1 %r
 }
 
+declare i1 @llvm.vp.reduce.and.v10i1(i1, <10 x i1>, <10 x i1>, i32)
+
+define signext i1 @vpreduce_and_v10i1(i1 signext %s, <10 x i1> %v, <10 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_and_v10i1:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e8, m1, ta, mu
+; CHECK-NEXT:    vmnand.mm v9, v0, v0
+; CHECK-NEXT:    vmv1r.v v0, v8
+; CHECK-NEXT:    vcpop.m a1, v9, v0.t
+; CHECK-NEXT:    seqz a1, a1
+; CHECK-NEXT:    and a0, a1, a0
+; CHECK-NEXT:    neg a0, a0
+; CHECK-NEXT:    ret
+  %r = call i1 @llvm.vp.reduce.and.v10i1(i1 %s, <10 x i1> %v, <10 x i1> %m, i32 %evl)
+  ret i1 %r
+}
+
 declare i1 @llvm.vp.reduce.and.v16i1(i1, <16 x i1>, <16 x i1>, i32)
 
 define signext i1 @vpreduce_and_v16i1(i1 signext %s, <16 x i1> %v, <16 x i1> %m, i32 zeroext %evl) {
@@ -237,20 +254,20 @@ define signext i1 @vpreduce_and_v256i1(i1 signext %s, <256 x i1> %v, <256 x i1>
 ; CHECK-NEXT:    addi a2, a1, -128
 ; CHECK-NEXT:    vmv1r.v v11, v0
 ; CHECK-NEXT:    li a3, 0
-; CHECK-NEXT:    bltu a1, a2, .LBB13_2
+; CHECK-NEXT:    bltu a1, a2, .LBB14_2
 ; CHECK-NEXT:  # %bb.1:
 ; CHECK-NEXT:    mv a3, a2
-; CHECK-NEXT:  .LBB13_2:
+; CHECK-NEXT:  .LBB14_2:
 ; CHECK-NEXT:    vsetvli zero, a3, e8, m8, ta, mu
 ; CHECK-NEXT:    vmnand.mm v8, v8, v8
 ; CHECK-NEXT:    vmv1r.v v0, v10
 ; CHECK-NEXT:    vcpop.m a2, v8, v0.t
 ; CHECK-NEXT:    li a3, 128
 ; CHECK-NEXT:    seqz a2, a2
-; CHECK-NEXT:    bltu a1, a3, .LBB13_4
+; CHECK-NEXT:    bltu a1, a3, .LBB14_4
 ; CHECK-NEXT:  # %bb.3:
 ; CHECK-NEXT:    li a1, 128
-; CHECK-NEXT:  .LBB13_4:
+; CHECK-NEXT:  .LBB14_4:
 ; CHECK-NEXT:    vsetvli zero, a1, e8, m8, ta, mu
 ; CHECK-NEXT:    vmnand.mm v8, v11, v11
 ; CHECK-NEXT:    vmv1r.v v0, v9

llvm/test/CodeGen/RISCV/rvv/vreductions-fp-vp.ll

Lines changed: 29 additions & 0 deletions
@@ -300,6 +300,35 @@ define double @vpreduce_ord_fadd_nxv2f64(double %s, <vscale x 2 x double> %v, <v
   ret double %r
 }
 
+declare double @llvm.vp.reduce.fadd.nxv3f64(double, <vscale x 3 x double>, <vscale x 3 x i1>, i32)
+
+define double @vpreduce_fadd_nxv3f64(double %s, <vscale x 3 x double> %v, <vscale x 3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_fadd_nxv3f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, mu
+; CHECK-NEXT:    vfmv.s.f v12, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e64, m4, tu, mu
+; CHECK-NEXT:    vfredusum.vs v12, v8, v12, v0.t
+; CHECK-NEXT:    vfmv.f.s fa0, v12
+; CHECK-NEXT:    ret
+  %r = call reassoc double @llvm.vp.reduce.fadd.nxv3f64(double %s, <vscale x 3 x double> %v, <vscale x 3 x i1> %m, i32 %evl)
+  ret double %r
+}
+
+define double @vpreduce_ord_fadd_nxv3f64(double %s, <vscale x 3 x double> %v, <vscale x 3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_ord_fadd_nxv3f64:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e64, m1, ta, mu
+; CHECK-NEXT:    vfmv.s.f v12, fa0
+; CHECK-NEXT:    vsetvli zero, a0, e64, m4, tu, mu
+; CHECK-NEXT:    vfredosum.vs v12, v8, v12, v0.t
+; CHECK-NEXT:    vfmv.f.s fa0, v12
+; CHECK-NEXT:    ret
+  %r = call double @llvm.vp.reduce.fadd.nxv3f64(double %s, <vscale x 3 x double> %v, <vscale x 3 x i1> %m, i32 %evl)
+  ret double %r
+}
+
+
 declare double @llvm.vp.reduce.fadd.nxv4f64(double, <vscale x 4 x double>, <vscale x 4 x i1>, i32)
 
 define double @vpreduce_fadd_nxv4f64(double %s, <vscale x 4 x double> %v, <vscale x 4 x i1> %m, i32 zeroext %evl) {

llvm/test/CodeGen/RISCV/rvv/vreductions-int-vp.ll

Lines changed: 23 additions & 8 deletions
@@ -248,6 +248,21 @@ define signext i8 @vpreduce_xor_nxv2i8(i8 signext %s, <vscale x 2 x i8> %v, <vsc
   ret i8 %r
 }
 
+declare i8 @llvm.vp.reduce.smax.nxv3i8(i8, <vscale x 3 x i8>, <vscale x 3 x i1>, i32)
+
+define signext i8 @vpreduce_smax_nxv3i8(i8 signext %s, <vscale x 3 x i8> %v, <vscale x 3 x i1> %m, i32 zeroext %evl) {
+; CHECK-LABEL: vpreduce_smax_nxv3i8:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetivli zero, 1, e8, m1, ta, mu
+; CHECK-NEXT:    vmv.s.x v9, a0
+; CHECK-NEXT:    vsetvli zero, a1, e8, mf2, tu, mu
+; CHECK-NEXT:    vredmax.vs v9, v8, v9, v0.t
+; CHECK-NEXT:    vmv.x.s a0, v9
+; CHECK-NEXT:    ret
+  %r = call i8 @llvm.vp.reduce.smax.nxv3i8(i8 %s, <vscale x 3 x i8> %v, <vscale x 3 x i1> %m, i32 %evl)
+  ret i8 %r
+}
+
 declare i8 @llvm.vp.reduce.add.nxv4i8(i8, <vscale x 4 x i8>, <vscale x 4 x i1>, i32)
 
 define signext i8 @vpreduce_add_nxv4i8(i8 signext %s, <vscale x 4 x i8> %v, <vscale x 4 x i1> %m, i32 zeroext %evl) {
@@ -1144,10 +1159,10 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
 ; RV32-NEXT:    slli a3, a3, 1
 ; RV32-NEXT:    vmv.s.x v25, a0
 ; RV32-NEXT:    mv a0, a1
-; RV32-NEXT:    bltu a1, a3, .LBB66_2
+; RV32-NEXT:    bltu a1, a3, .LBB67_2
 ; RV32-NEXT:  # %bb.1:
 ; RV32-NEXT:    mv a0, a3
-; RV32-NEXT:  .LBB66_2:
+; RV32-NEXT:  .LBB67_2:
 ; RV32-NEXT:    li a4, 0
 ; RV32-NEXT:    vsetvli a5, zero, e8, mf2, ta, mu
 ; RV32-NEXT:    vslidedown.vx v24, v0, a2
@@ -1157,10 +1172,10 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
 ; RV32-NEXT:    vsetivli zero, 1, e32, m1, ta, mu
 ; RV32-NEXT:    sub a0, a1, a3
 ; RV32-NEXT:    vmv.s.x v8, a2
-; RV32-NEXT:    bltu a1, a0, .LBB66_4
+; RV32-NEXT:    bltu a1, a0, .LBB67_4
 ; RV32-NEXT:  # %bb.3:
 ; RV32-NEXT:    mv a4, a0
-; RV32-NEXT:  .LBB66_4:
+; RV32-NEXT:  .LBB67_4:
 ; RV32-NEXT:    vsetvli zero, a4, e32, m8, tu, mu
 ; RV32-NEXT:    vmv1r.v v0, v24
 ; RV32-NEXT:    vredmaxu.vs v8, v16, v8, v0.t
@@ -1175,10 +1190,10 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
 ; RV64-NEXT:    slli a0, a3, 1
 ; RV64-NEXT:    srli a3, a4, 32
 ; RV64-NEXT:    mv a4, a1
-; RV64-NEXT:    bltu a1, a0, .LBB66_2
+; RV64-NEXT:    bltu a1, a0, .LBB67_2
 ; RV64-NEXT:  # %bb.1:
 ; RV64-NEXT:    mv a4, a0
-; RV64-NEXT:  .LBB66_2:
+; RV64-NEXT:  .LBB67_2:
 ; RV64-NEXT:    li a5, 0
 ; RV64-NEXT:    vsetvli a2, zero, e8, mf2, ta, mu
 ; RV64-NEXT:    vslidedown.vx v24, v0, a6
@@ -1190,10 +1205,10 @@ define signext i32 @vpreduce_umax_nxv32i32(i32 signext %s, <vscale x 32 x i32> %
 ; RV64-NEXT:    vsetivli zero, 1, e32, m1, ta, mu
 ; RV64-NEXT:    sub a0, a1, a0
 ; RV64-NEXT:    vmv.s.x v8, a2
-; RV64-NEXT:    bltu a1, a0, .LBB66_4
+; RV64-NEXT:    bltu a1, a0, .LBB67_4
 ; RV64-NEXT:  # %bb.3:
 ; RV64-NEXT:    mv a5, a0
-; RV64-NEXT:  .LBB66_4:
+; RV64-NEXT:  .LBB67_4:
 ; RV64-NEXT:    vsetvli zero, a5, e32, m8, tu, mu
 ; RV64-NEXT:    vmv1r.v v0, v24
 ; RV64-NEXT:    vredmaxu.vs v8, v16, v8, v0.t
