Skip to content

Commit d1fe7a2

Browse files
[LLVM][DAGCombiner][SVE] Fold vselect into merge_pasthru_op. (#146917)
vselect A, (merge_pasthru_op all_active, B,{Bn,} -), C vselect A, (merge_pasthru_op -, B,{Bn,} undef), C vselect A, (merge_pasthru_op A, B,{Bn,} -), C -> merge_pasthru_op A, B,{Bn,} C
1 parent e55b194 commit d1fe7a2

File tree

3 files changed

+199
-347
lines changed

3 files changed

+199
-347
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25523,6 +25523,9 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
2552325523
return SwapResult;
2552425524

2552525525
SDValue N0 = N->getOperand(0);
25526+
SDValue IfTrue = N->getOperand(1);
25527+
SDValue IfFalse = N->getOperand(2);
25528+
EVT ResVT = N->getValueType(0);
2552625529
EVT CCVT = N0.getValueType();
2552725530

2552825531
if (isAllActivePredicate(DAG, N0))
@@ -25531,6 +25534,22 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
2553125534
if (isAllInactivePredicate(N0))
2553225535
return N->getOperand(2);
2553325536

25537+
if (isMergePassthruOpcode(IfTrue.getOpcode()) && IfTrue.hasOneUse()) {
25538+
// vselect A, (merge_pasthru_op all_active, B,{Bn,} -), C
25539+
// vselect A, (merge_pasthru_op -, B,{Bn,} undef), C
25540+
// vselect A, (merge_pasthru_op A, B,{Bn,} -), C
25541+
// -> merge_pasthru_op A, B,{Bn,} C
25542+
if (isAllActivePredicate(DAG, IfTrue->getOperand(0)) ||
25543+
IfTrue->getOperand(IfTrue.getNumOperands() - 1).isUndef() ||
25544+
IfTrue->getOperand(0) == N0) {
25545+
SmallVector<SDValue, 4> Ops(IfTrue->op_values());
25546+
Ops[0] = N0;
25547+
Ops[IfTrue.getNumOperands() - 1] = IfFalse;
25548+
25549+
return DAG.getNode(IfTrue.getOpcode(), SDLoc(N), ResVT, Ops);
25550+
}
25551+
}
25552+
2553425553
// Check for sign pattern (VSELECT setgt, iN lhs, -1, 1, -1) and transform
2553525554
// into (OR (ASR lhs, N-1), 1), which requires less instructions for the
2553625555
// supported types.
@@ -25570,14 +25589,11 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
2557025589
CmpVT.getVectorElementType().isFloatingPoint())
2557125590
return SDValue();
2557225591

25573-
EVT ResVT = N->getValueType(0);
2557425592
// Only combine when the result type is of the same size as the compared
2557525593
// operands.
2557625594
if (ResVT.getSizeInBits() != CmpVT.getSizeInBits())
2557725595
return SDValue();
2557825596

25579-
SDValue IfTrue = N->getOperand(1);
25580-
SDValue IfFalse = N->getOperand(2);
2558125597
SetCC = DAG.getSetCC(SDLoc(N), CmpVT.changeVectorElementTypeToInteger(),
2558225598
N0.getOperand(0), N0.getOperand(1),
2558325599
cast<CondCodeSDNode>(N0.getOperand(2))->get());

llvm/lib/Target/AArch64/AArch64SVEInstrInfo.td

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2445,14 +2445,23 @@ let Predicates = [HasSVE_or_SME] in {
24452445
defm FCVTZU_ZPmZ_DtoD : sve_fp_2op_p_zd< 0b1111111, "fcvtzu", ZPR64, ZPR64, null_frag, AArch64fcvtzu_mt, nxv2i64, nxv2i1, nxv2f64, ElementSizeD>;
24462446

24472447
//These patterns exist to improve the code quality of conversions on unpacked types.
2448+
def : Pat<(nxv2f32 (AArch64fcvte_mt nxv2i1:$Pg, nxv2f16:$Zs, nxv2f32:$Zd)),
2449+
(FCVT_ZPmZ_HtoS ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
24482450
def : Pat<(nxv2f32 (AArch64fcvte_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f16:$Zs, nxv2f32:$Zd)),
24492451
(FCVT_ZPmZ_HtoS_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
24502452

24512453
// FP_ROUND has an additional 'precise' flag which indicates the type of rounding.
24522454
// This is ignored by the pattern below where it is matched by (i64 timm0_1)
2455+
def : Pat<(nxv2f16 (AArch64fcvtr_mt nxv2i1:$Pg, nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
2456+
(FCVT_ZPmZ_StoH ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
24532457
def : Pat<(nxv2f16 (AArch64fcvtr_mt (nxv2i1 (SVEAllActive:$Pg)), nxv2f32:$Zs, (i64 timm0_1), nxv2f16:$Zd)),
24542458
(FCVT_ZPmZ_StoH_UNDEF ZPR:$Zd, PPR:$Pg, ZPR:$Zs)>;
24552459

2460+
def : Pat<(nxv4f32 (AArch64fcvte_mt nxv4i1:$Pg, nxv4bf16:$Zs, nxv4f32:$Zd)),
2461+
(SEL_ZPZZ_S $Pg, (LSL_ZZI_S $Zs, (i32 16)), $Zd)>;
2462+
def : Pat<(nxv2f32 (AArch64fcvte_mt nxv2i1:$Pg, nxv2bf16:$Zs, nxv2f32:$Zd)),
2463+
(SEL_ZPZZ_D $Pg, (LSL_ZZI_S $Zs, (i32 16)), $Zd)>;
2464+
24562465
def : Pat<(nxv4f32 (AArch64fcvte_mt (SVEAnyPredicate), nxv4bf16:$op, undef)),
24572466
(LSL_ZZI_S $op, (i32 16))>;
24582467
def : Pat<(nxv2f32 (AArch64fcvte_mt (SVEAnyPredicate), nxv2bf16:$op, undef)),

0 commit comments

Comments
 (0)