Skip to content

Commit 655d0d8

Browse files
committed
[DAGCombine] Move AVG combine to SimplifyDemandBits
This moves the matching of AVGFloor and AVGCeil into a place where demand bit are available, so that it can detect more cases for more folds. It changes the transform to start from a shift, not from a truncate. We match the pattern shr(add(ext(A), ext(B)), 1), transforming to ext(hadd(A, B)). For signed values, because only the bottom bits are demanded llvm will transform the above to use a lshr too, as opposed to ashr. In order to correctly detect the hadd we need to know the demanded bits to turn it back. Depending on whether the shift is signed (ashr) or logical (lshr), and the extensions are signed or unsigned we can create different nodes. If the shift is signed: Needs >= 2 sign bits. https://alive2.llvm.org/ce/z/h4gQAW generating signed rhadd. Needs >= 2 zero bits. https://alive2.llvm.org/ce/z/B64DUA generating unsigned rhadd. If the shift is unsigned: Needs >= 1 zero bits. https://alive2.llvm.org/ce/z/ByD8sj generating unsigned rhadd. Needs 1 demanded bit zero and >= 2 sign bits https://alive2.llvm.org/ce/z/hvPGxX and https://alive2.llvm.org/ce/z/32P5n1 generating signed rhadd. Differential Revision: https://reviews.llvm.org/D119072
1 parent e7dcf09 commit 655d0d8

File tree

5 files changed

+255
-372
lines changed

5 files changed

+255
-372
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 0 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -12734,87 +12734,6 @@ SDValue DAGCombiner::visitEXTEND_VECTOR_INREG(SDNode *N) {
1273412734
return SDValue();
1273512735
}
1273612736

12737-
// Attempt to form one of the avg patterns from:
12738-
// truncate(shr(add(zext(OpB), zext(OpA)), 1))
12739-
// Creating avgflooru/avgfloors/avgceilu/avgceils, with the ceiling having an
12740-
// extra rounding add:
12741-
// truncate(shr(add(zext(OpB), zext(OpA), 1), 1))
12742-
// This starts at a truncate, meaning the shift will always be shl, as the top
12743-
// bits are known to not be demanded.
12744-
static SDValue performAvgCombine(SDNode *N, SelectionDAG &DAG) {
12745-
assert(N->getOpcode() == ISD::TRUNCATE && "TRUNCATE node expected");
12746-
EVT VT = N->getValueType(0);
12747-
12748-
SDValue Shift = N->getOperand(0);
12749-
if (Shift.getOpcode() != ISD::SRL)
12750-
return SDValue();
12751-
12752-
// Is the right shift using an immediate value of 1?
12753-
ConstantSDNode *N1C = isConstOrConstSplat(Shift.getOperand(1));
12754-
if (!N1C || !N1C->isOne())
12755-
return SDValue();
12756-
12757-
// We are looking for an avgfloor
12758-
// add(ext, ext)
12759-
// or one of these as a avgceil
12760-
// add(add(ext, ext), 1)
12761-
// add(add(ext, 1), ext)
12762-
// add(ext, add(ext, 1))
12763-
SDValue Add = Shift.getOperand(0);
12764-
if (Add.getOpcode() != ISD::ADD)
12765-
return SDValue();
12766-
12767-
SDValue ExtendOpA = Add.getOperand(0);
12768-
SDValue ExtendOpB = Add.getOperand(1);
12769-
auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3) {
12770-
ConstantSDNode *ConstOp;
12771-
if ((ConstOp = isConstOrConstSplat(Op1)) && ConstOp->isOne()) {
12772-
ExtendOpA = Op2;
12773-
ExtendOpB = Op3;
12774-
return true;
12775-
}
12776-
if ((ConstOp = isConstOrConstSplat(Op2)) && ConstOp->isOne()) {
12777-
ExtendOpA = Op1;
12778-
ExtendOpB = Op3;
12779-
return true;
12780-
}
12781-
if ((ConstOp = isConstOrConstSplat(Op3)) && ConstOp->isOne()) {
12782-
ExtendOpA = Op1;
12783-
ExtendOpB = Op2;
12784-
return true;
12785-
}
12786-
return false;
12787-
};
12788-
bool IsCeil = (ExtendOpA.getOpcode() == ISD::ADD &&
12789-
MatchOperands(ExtendOpA.getOperand(0), ExtendOpA.getOperand(1),
12790-
ExtendOpB)) ||
12791-
(ExtendOpB.getOpcode() == ISD::ADD &&
12792-
MatchOperands(ExtendOpB.getOperand(0), ExtendOpB.getOperand(1),
12793-
ExtendOpA));
12794-
12795-
unsigned ExtendOpAOpc = ExtendOpA.getOpcode();
12796-
unsigned ExtendOpBOpc = ExtendOpB.getOpcode();
12797-
if (!(ExtendOpAOpc == ExtendOpBOpc &&
12798-
(ExtendOpAOpc == ISD::ZERO_EXTEND || ExtendOpAOpc == ISD::SIGN_EXTEND)))
12799-
return SDValue();
12800-
12801-
// Is the result of the right shift being truncated to the same value type as
12802-
// the original operands, OpA and OpB?
12803-
SDValue OpA = ExtendOpA.getOperand(0);
12804-
SDValue OpB = ExtendOpB.getOperand(0);
12805-
EVT OpAVT = OpA.getValueType();
12806-
if (VT != OpAVT || OpAVT != OpB.getValueType())
12807-
return SDValue();
12808-
12809-
bool IsSignExtend = ExtendOpAOpc == ISD::SIGN_EXTEND;
12810-
unsigned AVGOpc = IsSignExtend ? (IsCeil ? ISD::AVGCEILS : ISD::AVGFLOORS)
12811-
: (IsCeil ? ISD::AVGCEILU : ISD::AVGFLOORU);
12812-
if (!DAG.getTargetLoweringInfo().isOperationLegalOrCustom(AVGOpc, VT))
12813-
return SDValue();
12814-
12815-
return DAG.getNode(AVGOpc, SDLoc(N), VT, OpA, OpB);
12816-
}
12817-
1281812737
SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1281912738
SDValue N0 = N->getOperand(0);
1282012739
EVT VT = N->getValueType(0);
@@ -13101,8 +13020,6 @@ SDValue DAGCombiner::visitTRUNCATE(SDNode *N) {
1310113020

1310213021
if (SDValue NewVSel = matchVSelectOpSizesWithSetCC(N))
1310313022
return NewVSel;
13104-
if (SDValue M = performAvgCombine(N, DAG))
13105-
return M;
1310613023

1310713024
// Narrow a suitable binary operation with a non-opaque constant operand by
1310813025
// moving it ahead of the truncate. This is limited to pre-legalization

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -907,6 +907,132 @@ SDValue TargetLowering::SimplifyMultipleUseDemandedVectorElts(
907907
Depth);
908908
}
909909

910+
// Attempt to form ext(avgfloor(A, B)) from shr(add(ext(A), ext(B)), 1).
911+
// or to form ext(avgceil(A, B)) from shr(add(ext(A), ext(B), 1), 1).
912+
static SDValue combineShiftToAVG(SDValue Op, SelectionDAG &DAG,
913+
const TargetLowering &TLI,
914+
const APInt &DemandedBits,
915+
const APInt &DemandedElts,
916+
unsigned Depth) {
917+
assert((Op.getOpcode() == ISD::SRL || Op.getOpcode() == ISD::SRA) &&
918+
"SRL or SRA node is required here!");
919+
// Is the right shift using an immediate value of 1?
920+
ConstantSDNode *N1C = isConstOrConstSplat(Op.getOperand(1), DemandedElts);
921+
if (!N1C || !N1C->isOne())
922+
return SDValue();
923+
924+
// We are looking for an avgfloor
925+
// add(ext, ext)
926+
// or one of these as a avgceil
927+
// add(add(ext, ext), 1)
928+
// add(add(ext, 1), ext)
929+
// add(ext, add(ext, 1))
930+
SDValue Add = Op.getOperand(0);
931+
if (Add.getOpcode() != ISD::ADD)
932+
return SDValue();
933+
934+
SDValue ExtOpA = Add.getOperand(0);
935+
SDValue ExtOpB = Add.getOperand(1);
936+
auto MatchOperands = [&](SDValue Op1, SDValue Op2, SDValue Op3) {
937+
ConstantSDNode *ConstOp;
938+
if ((ConstOp = isConstOrConstSplat(Op1, DemandedElts)) &&
939+
ConstOp->isOne()) {
940+
ExtOpA = Op2;
941+
ExtOpB = Op3;
942+
return true;
943+
}
944+
if ((ConstOp = isConstOrConstSplat(Op2, DemandedElts)) &&
945+
ConstOp->isOne()) {
946+
ExtOpA = Op1;
947+
ExtOpB = Op3;
948+
return true;
949+
}
950+
if ((ConstOp = isConstOrConstSplat(Op3, DemandedElts)) &&
951+
ConstOp->isOne()) {
952+
ExtOpA = Op1;
953+
ExtOpB = Op2;
954+
return true;
955+
}
956+
return false;
957+
};
958+
bool IsCeil =
959+
(ExtOpA.getOpcode() == ISD::ADD &&
960+
MatchOperands(ExtOpA.getOperand(0), ExtOpA.getOperand(1), ExtOpB)) ||
961+
(ExtOpB.getOpcode() == ISD::ADD &&
962+
MatchOperands(ExtOpB.getOperand(0), ExtOpB.getOperand(1), ExtOpA));
963+
964+
// If the shift is signed (sra):
965+
// - Needs >= 2 sign bit for both operands.
966+
// - Needs >= 2 zero bits.
967+
// If the shift is unsigned (srl):
968+
// - Needs >= 1 zero bit for both operands.
969+
// - Needs 1 demanded bit zero and >= 2 sign bits.
970+
unsigned ShiftOpc = Op.getOpcode();
971+
bool IsSigned = false;
972+
unsigned KnownBits;
973+
unsigned NumSignedA = DAG.ComputeNumSignBits(ExtOpA, DemandedElts, Depth);
974+
unsigned NumSignedB = DAG.ComputeNumSignBits(ExtOpB, DemandedElts, Depth);
975+
unsigned NumSigned = std::min(NumSignedA, NumSignedB) - 1;
976+
unsigned NumZeroA =
977+
DAG.computeKnownBits(ExtOpA, DemandedElts, Depth).countMinLeadingZeros();
978+
unsigned NumZeroB =
979+
DAG.computeKnownBits(ExtOpB, DemandedElts, Depth).countMinLeadingZeros();
980+
unsigned NumZero = std::min(NumZeroA, NumZeroB);
981+
982+
switch (ShiftOpc) {
983+
default:
984+
llvm_unreachable("Unexpected ShiftOpc in combineShiftToAVG");
985+
case ISD::SRA: {
986+
if (NumZero >= 2 && NumSigned < NumZero) {
987+
IsSigned = false;
988+
KnownBits = NumZero;
989+
break;
990+
}
991+
if (NumSigned >= 1) {
992+
IsSigned = true;
993+
KnownBits = NumSigned;
994+
break;
995+
}
996+
return SDValue();
997+
}
998+
case ISD::SRL: {
999+
if (NumZero >= 1 && NumSigned < NumZero) {
1000+
IsSigned = false;
1001+
KnownBits = NumZero;
1002+
break;
1003+
}
1004+
if (NumSigned >= 1 && DemandedBits.isSignBitClear()) {
1005+
IsSigned = true;
1006+
KnownBits = NumSigned;
1007+
break;
1008+
}
1009+
return SDValue();
1010+
}
1011+
}
1012+
1013+
unsigned AVGOpc = IsCeil ? (IsSigned ? ISD::AVGCEILS : ISD::AVGCEILU)
1014+
: (IsSigned ? ISD::AVGFLOORS : ISD::AVGFLOORU);
1015+
1016+
// Find the smallest power-2 type that is legal for this vector size and
1017+
// operation, given the original type size and the number of known sign/zero
1018+
// bits.
1019+
EVT VT = Op.getValueType();
1020+
unsigned MinWidth =
1021+
std::max<unsigned>(VT.getScalarSizeInBits() - KnownBits, 8);
1022+
EVT NVT = EVT::getIntegerVT(*DAG.getContext(), PowerOf2Ceil(MinWidth));
1023+
if (VT.isVector())
1024+
NVT = EVT::getVectorVT(*DAG.getContext(), NVT, VT.getVectorElementCount());
1025+
if (!TLI.isOperationLegalOrCustom(AVGOpc, NVT))
1026+
return SDValue();
1027+
1028+
SDLoc DL(Op);
1029+
SDValue ResultAVG =
1030+
DAG.getNode(AVGOpc, DL, NVT, DAG.getNode(ISD::TRUNCATE, DL, NVT, ExtOpA),
1031+
DAG.getNode(ISD::TRUNCATE, DL, NVT, ExtOpB));
1032+
return DAG.getNode(IsSigned ? ISD::SIGN_EXTEND : ISD::ZERO_EXTEND, DL, VT,
1033+
ResultAVG);
1034+
}
1035+
9101036
/// Look at Op. At this point, we know that only the OriginalDemandedBits of the
9111037
/// result of Op are ever used downstream. If we can use this information to
9121038
/// simplify Op, create a new simplified DAG node and return true, returning the
@@ -1569,6 +1695,11 @@ bool TargetLowering::SimplifyDemandedBits(
15691695
SDValue Op1 = Op.getOperand(1);
15701696
EVT ShiftVT = Op1.getValueType();
15711697

1698+
// Try to match AVG patterns.
1699+
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
1700+
DemandedElts, Depth + 1))
1701+
return TLO.CombineTo(Op, AVG);
1702+
15721703
if (const APInt *SA =
15731704
TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
15741705
unsigned ShAmt = SA->getZExtValue();
@@ -1635,6 +1766,11 @@ bool TargetLowering::SimplifyDemandedBits(
16351766
if (DemandedBits.isOne())
16361767
return TLO.CombineTo(Op, TLO.DAG.getNode(ISD::SRL, dl, VT, Op0, Op1));
16371768

1769+
// Try to match AVG patterns.
1770+
if (SDValue AVG = combineShiftToAVG(Op, TLO.DAG, *this, DemandedBits,
1771+
DemandedElts, Depth + 1))
1772+
return TLO.CombineTo(Op, AVG);
1773+
16381774
if (const APInt *SA =
16391775
TLO.DAG.getValidShiftAmountConstant(Op, DemandedElts)) {
16401776
unsigned ShAmt = SA->getZExtValue();

0 commit comments

Comments
 (0)