Skip to content

Commit c80fa23

Browse files
authored
[DAG] SDPatternMatch m_Zero/m_One/m_AllOnes have inconsistent undef h… (#147044)
### Summary This PR resolves #146871 This PR resolves #140745 Refactor m_Zero/m_One/m_AllOnes all use struct template function to match and AllowUndefs=false as default.
1 parent acb4fff commit c80fa23

File tree

5 files changed

+143
-7
lines changed

5 files changed

+143
-7
lines changed

llvm/include/llvm/CodeGen/SDPatternMatch.h

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1100,19 +1100,46 @@ inline SpecificInt_match m_SpecificInt(uint64_t V) {
11001100
return SpecificInt_match(APInt(64, V));
11011101
}
11021102

1103-
inline SpecificInt_match m_Zero() { return m_SpecificInt(0U); }
1104-
inline SpecificInt_match m_One() { return m_SpecificInt(1U); }
1103+
struct Zero_match {
1104+
bool AllowUndefs;
1105+
1106+
explicit Zero_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1107+
1108+
template <typename MatchContext>
1109+
bool match(const MatchContext &, SDValue N) const {
1110+
return isZeroOrZeroSplat(N, AllowUndefs);
1111+
}
1112+
};
1113+
1114+
struct Ones_match {
1115+
bool AllowUndefs;
1116+
1117+
Ones_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
1118+
1119+
template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1120+
return isOnesOrOnesSplat(N, AllowUndefs);
1121+
}
1122+
};
11051123

11061124
struct AllOnes_match {
1125+
bool AllowUndefs;
11071126

1108-
AllOnes_match() = default;
1127+
AllOnes_match(bool AllowUndefs) : AllowUndefs(AllowUndefs) {}
11091128

11101129
template <typename MatchContext> bool match(const MatchContext &, SDValue N) {
1111-
return isAllOnesOrAllOnesSplat(N);
1130+
return isAllOnesOrAllOnesSplat(N, AllowUndefs);
11121131
}
11131132
};
11141133

1115-
inline AllOnes_match m_AllOnes() { return AllOnes_match(); }
1134+
inline Ones_match m_One(bool AllowUndefs = false) {
1135+
return Ones_match(AllowUndefs);
1136+
}
1137+
inline Zero_match m_Zero(bool AllowUndefs = false) {
1138+
return Zero_match(AllowUndefs);
1139+
}
1140+
inline AllOnes_match m_AllOnes(bool AllowUndefs = false) {
1141+
return AllOnes_match(AllowUndefs);
1142+
}
11161143

11171144
/// Match true boolean value based on the information provided by
11181145
/// TargetLowering.
@@ -1189,7 +1216,7 @@ inline CondCode_match m_SpecificCondCode(ISD::CondCode CC) {
11891216

11901217
/// Match a negate as a sub(0, v)
11911218
template <typename ValTy>
1192-
inline BinaryOpc_match<SpecificInt_match, ValTy> m_Neg(const ValTy &V) {
1219+
inline BinaryOpc_match<Zero_match, ValTy, false> m_Neg(const ValTy &V) {
11931220
return m_Sub(m_Zero(), V);
11941221
}
11951222

llvm/include/llvm/CodeGen/SelectionDAGNodes.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1937,6 +1937,16 @@ LLVM_ABI bool isOneOrOneSplat(SDValue V, bool AllowUndefs = false);
19371937
/// Does not permit build vector implicit truncation.
19381938
LLVM_ABI bool isAllOnesOrAllOnesSplat(SDValue V, bool AllowUndefs = false);
19391939

1940+
/// Return true if the value is a constant 1 integer or a splatted vector of a
1941+
/// constant 1 integer (with no undefs).
1942+
/// Does not permit build vector implicit truncation.
1943+
LLVM_ABI bool isOnesOrOnesSplat(SDValue N, bool AllowUndefs = false);
1944+
1945+
/// Return true if the value is a constant 0 integer or a splatted vector of a
1946+
/// constant 0 integer (with no undefs).
1947+
/// Does not permit build vector implicit truncation.
1948+
LLVM_ABI bool isZeroOrZeroSplat(SDValue N, bool AllowUndefs = false);
1949+
19401950
/// Return true if \p V is either a integer or FP constant.
19411951
inline bool isIntOrFPConstant(SDValue V) {
19421952
return isa<ConstantSDNode>(V) || isa<ConstantFPSDNode>(V);

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4280,7 +4280,8 @@ SDValue DAGCombiner::visitSUB(SDNode *N) {
42804280
return V;
42814281

42824282
// (A - B) - 1 -> add (xor B, -1), A
4283-
if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))), m_One())))
4283+
if (sd_match(N, m_Sub(m_OneUse(m_Sub(m_Value(A), m_Value(B))),
4284+
m_One(/*AllowUndefs=*/true))))
42844285
return DAG.getNode(ISD::ADD, DL, VT, A, DAG.getNOT(DL, B, VT));
42854286

42864287
// Look for:

llvm/lib/CodeGen/SelectionDAG/SelectionDAG.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12644,6 +12644,18 @@ bool llvm::isAllOnesOrAllOnesSplat(SDValue N, bool AllowUndefs) {
1264412644
return C && C->isAllOnes() && C->getValueSizeInBits(0) == BitWidth;
1264512645
}
1264612646

12647+
bool llvm::isOnesOrOnesSplat(SDValue N, bool AllowUndefs) {
12648+
ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs);
12649+
return C && APInt::isSameValue(C->getAPIntValue(),
12650+
APInt(C->getAPIntValue().getBitWidth(), 1));
12651+
}
12652+
12653+
bool llvm::isZeroOrZeroSplat(SDValue N, bool AllowUndefs) {
12654+
N = peekThroughBitcasts(N);
12655+
ConstantSDNode *C = isConstOrConstSplat(N, AllowUndefs, true);
12656+
return C && C->isZero();
12657+
}
12658+
1264712659
HandleSDNode::~HandleSDNode() {
1264812660
DropOperands();
1264912661
}

llvm/unittests/CodeGen/SelectionDAGPatternMatchTest.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,3 +825,89 @@ TEST_F(SelectionDAGPatternMatchTest, matchReassociatableOp) {
825825
EXPECT_FALSE(sd_match(
826826
ORS0123, m_ReassociatableOr(m_Value(), m_Value(), m_Value(), m_Value())));
827827
}
828+
829+
TEST_F(SelectionDAGPatternMatchTest, MatchZeroOneAllOnes) {
830+
using namespace SDPatternMatch;
831+
832+
SDLoc DL;
833+
EVT VT = EVT::getIntegerVT(Context, 32);
834+
835+
// Scalar constant 0
836+
SDValue Zero = DAG->getConstant(0, DL, VT);
837+
EXPECT_TRUE(sd_match(Zero, DAG.get(), llvm::SDPatternMatch::m_Zero()));
838+
EXPECT_FALSE(sd_match(Zero, DAG.get(), m_One()));
839+
EXPECT_FALSE(sd_match(Zero, DAG.get(), m_AllOnes()));
840+
841+
// Scalar constant 1
842+
SDValue One = DAG->getConstant(1, DL, VT);
843+
EXPECT_FALSE(sd_match(One, DAG.get(), m_Zero()));
844+
EXPECT_TRUE(sd_match(One, DAG.get(), m_One()));
845+
EXPECT_FALSE(sd_match(One, DAG.get(), m_AllOnes()));
846+
847+
// Scalar constant -1
848+
SDValue AllOnes =
849+
DAG->getConstant(APInt::getAllOnes(VT.getSizeInBits()), DL, VT);
850+
EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_Zero()));
851+
EXPECT_FALSE(sd_match(AllOnes, DAG.get(), m_One()));
852+
EXPECT_TRUE(sd_match(AllOnes, DAG.get(), m_AllOnes()));
853+
854+
EVT VecF32 = EVT::getVectorVT(Context, MVT::f32, 4);
855+
EVT VecVT = EVT::getVectorVT(Context, MVT::i32, 4);
856+
857+
// m_Zero: splat vector of 0 → bitcast
858+
{
859+
SDValue SplatVal = DAG->getConstant(0, DL, MVT::i32);
860+
SDValue VecSplat = DAG->getSplatBuildVector(VecVT, DL, SplatVal);
861+
SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, VecSplat);
862+
EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_Zero()));
863+
}
864+
865+
// m_One: splat vector of 1 → bitcast
866+
{
867+
SDValue SplatVal = DAG->getConstant(1, DL, MVT::i32);
868+
SDValue VecSplat = DAG->getSplatBuildVector(VecVT, DL, SplatVal);
869+
SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, VecSplat);
870+
EXPECT_FALSE(sd_match(Bitcasted, DAG.get(), m_One()));
871+
}
872+
873+
// m_AllOnes: splat vector of -1 → bitcast
874+
{
875+
SDValue SplatVal = DAG->getConstant(APInt::getAllOnes(32), DL, MVT::i32);
876+
SDValue VecSplat = DAG->getSplatBuildVector(VecVT, DL, SplatVal);
877+
SDValue Bitcasted = DAG->getNode(ISD::BITCAST, DL, VecF32, VecSplat);
878+
EXPECT_TRUE(sd_match(Bitcasted, DAG.get(), m_AllOnes()));
879+
}
880+
881+
// splat vector with one undef → default should NOT match
882+
SDValue Undef = DAG->getUNDEF(MVT::i32);
883+
884+
{
885+
// m_Zero: Undef + constant 0
886+
SDValue Zero = DAG->getConstant(0, DL, MVT::i32);
887+
SmallVector<SDValue, 4> Ops(4, Zero);
888+
Ops[2] = Undef;
889+
SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops);
890+
EXPECT_FALSE(sd_match(Vec, DAG.get(), m_Zero()));
891+
EXPECT_TRUE(sd_match(Vec, DAG.get(), m_Zero(true)));
892+
}
893+
894+
{
895+
// m_One: Undef + constant 1
896+
SDValue One = DAG->getConstant(1, DL, MVT::i32);
897+
SmallVector<SDValue, 4> Ops(4, One);
898+
Ops[1] = Undef;
899+
SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops);
900+
EXPECT_FALSE(sd_match(Vec, DAG.get(), m_One()));
901+
EXPECT_TRUE(sd_match(Vec, DAG.get(), m_One(true)));
902+
}
903+
904+
{
905+
// m_AllOnes: Undef + constant -1
906+
SDValue AllOnes = DAG->getConstant(APInt::getAllOnes(32), DL, MVT::i32);
907+
SmallVector<SDValue, 4> Ops(4, AllOnes);
908+
Ops[0] = Undef;
909+
SDValue Vec = DAG->getBuildVector(VecVT, DL, Ops);
910+
EXPECT_FALSE(sd_match(Vec, DAG.get(), m_AllOnes()));
911+
EXPECT_TRUE(sd_match(Vec, DAG.get(), m_AllOnes(true)));
912+
}
913+
}

0 commit comments

Comments
 (0)