Skip to content

Commit fd2bb51

Browse files
committed
[ADT] Add APInt/MathExtras isShiftedMask variant returning mask offset/length
In many cases, calls to isShiftedMask are immediately followed with checks to determine the size and position of the bitmask. This patch adds variants of APInt::isShiftedMask, isShiftedMask_32 and isShiftedMask_64 that return these values as additional arguments. I've updated a number of cases that were either performing seperate size/position calculations or had created their own local wrapper versions of these. Differential Revision: https://reviews.llvm.org/D119019
1 parent 83f9b13 commit fd2bb51

File tree

8 files changed

+112
-36
lines changed

8 files changed

+112
-36
lines changed

llvm/include/llvm/ADT/APInt.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -496,6 +496,23 @@ class LLVM_NODISCARD APInt {
496496
return (Ones + LeadZ + countTrailingZeros()) == BitWidth;
497497
}
498498

499+
/// Return true if this APInt value contains a non-empty sequence of ones with
500+
/// the remainder zero. If true, \p MaskIdx will specify the index of the
501+
/// lowest set bit and \p MaskLen is updated to specify the length of the
502+
/// mask, else neither are updated.
503+
bool isShiftedMask(unsigned &MaskIdx, unsigned &MaskLen) const {
504+
if (isSingleWord())
505+
return isShiftedMask_64(U.VAL, MaskIdx, MaskLen);
506+
unsigned Ones = countPopulationSlowCase();
507+
unsigned LeadZ = countLeadingZerosSlowCase();
508+
unsigned TrailZ = countTrailingZerosSlowCase();
509+
if ((Ones + LeadZ + TrailZ) != BitWidth)
510+
return false;
511+
MaskLen = Ones;
512+
MaskIdx = TrailZ;
513+
return true;
514+
}
515+
499516
/// Compute an APInt containing numBits highbits from this APInt.
500517
///
501518
/// Get an APInt with the same BitWidth as this APInt, just zero mask the low

llvm/include/llvm/Support/MathExtras.h

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -571,6 +571,33 @@ inline unsigned countPopulation(T Value) {
571571
return detail::PopulationCounter<T, sizeof(T)>::count(Value);
572572
}
573573

574+
/// Return true if the argument contains a non-empty sequence of ones with the
575+
/// remainder zero (32 bit version.) Ex. isShiftedMask_32(0x0000FF00U) == true.
576+
/// If true, \p MaskIdx will specify the index of the lowest set bit and \p
577+
/// MaskLen is updated to specify the length of the mask, else neither are
578+
/// updated.
579+
inline bool isShiftedMask_32(uint32_t Value, unsigned &MaskIdx,
580+
unsigned &MaskLen) {
581+
if (!isShiftedMask_32(Value))
582+
return false;
583+
MaskIdx = countTrailingZeros(Value);
584+
MaskLen = countPopulation(Value);
585+
return true;
586+
}
587+
588+
/// Return true if the argument contains a non-empty sequence of ones with the
589+
/// remainder zero (64 bit version.) If true, \p MaskIdx will specify the index
590+
/// of the lowest set bit and \p MaskLen is updated to specify the length of the
591+
/// mask, else neither are updated.
592+
inline bool isShiftedMask_64(uint64_t Value, unsigned &MaskIdx,
593+
unsigned &MaskLen) {
594+
if (!isShiftedMask_64(Value))
595+
return false;
596+
MaskIdx = countTrailingZeros(Value);
597+
MaskLen = countPopulation(Value);
598+
return true;
599+
}
600+
574601
/// Compile time Log2.
575602
/// Valid only for positive powers of two.
576603
template <size_t kValue> constexpr inline size_t CTLog2() {

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12254,10 +12254,7 @@ SDValue DAGCombiner::reduceLoadWidth(SDNode *N) {
1225412254
unsigned ActiveBits = 0;
1225512255
if (Mask.isMask()) {
1225612256
ActiveBits = Mask.countTrailingOnes();
12257-
} else if (Mask.isShiftedMask()) {
12258-
ShAmt = Mask.countTrailingZeros();
12259-
APInt ShiftedMask = Mask.lshr(ShAmt);
12260-
ActiveBits = ShiftedMask.countTrailingOnes();
12257+
} else if (Mask.isShiftedMask(ShAmt, ActiveBits)) {
1226112258
HasShiftedOffset = true;
1226212259
} else
1226312260
return SDValue();

llvm/lib/Target/AMDGPU/AMDGPUISelLowering.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3281,8 +3281,9 @@ SDValue AMDGPUTargetLowering::performSrlCombine(SDNode *N,
32813281
// this improves the ability to match BFE patterns in isel.
32823282
if (LHS.getOpcode() == ISD::AND) {
32833283
if (auto *Mask = dyn_cast<ConstantSDNode>(LHS.getOperand(1))) {
3284-
if (Mask->getAPIntValue().isShiftedMask() &&
3285-
Mask->getAPIntValue().countTrailingZeros() == ShiftAmt) {
3284+
unsigned MaskIdx, MaskLen;
3285+
if (Mask->getAPIntValue().isShiftedMask(MaskIdx, MaskLen) &&
3286+
MaskIdx == ShiftAmt) {
32863287
return DAG.getNode(
32873288
ISD::AND, SL, VT,
32883289
DAG.getNode(ISD::SRL, SL, VT, LHS.getOperand(0), N->getOperand(1)),

llvm/lib/Target/Mips/MipsISelLowering.cpp

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -94,18 +94,6 @@ static const MCPhysReg Mips64DPRegs[8] = {
9494
Mips::D16_64, Mips::D17_64, Mips::D18_64, Mips::D19_64
9595
};
9696

97-
// If I is a shifted mask, set the size (Size) and the first bit of the
98-
// mask (Pos), and return true.
99-
// For example, if I is 0x003ff800, (Pos, Size) = (11, 11).
100-
static bool isShiftedMask(uint64_t I, uint64_t &Pos, uint64_t &Size) {
101-
if (!isShiftedMask_64(I))
102-
return false;
103-
104-
Size = countPopulation(I);
105-
Pos = countTrailingZeros(I);
106-
return true;
107-
}
108-
10997
// The MIPS MSA ABI passes vector arguments in the integer register set.
11098
// The number of integer registers used is dependant on the ABI used.
11199
MVT MipsTargetLowering::getRegisterTypeForCallingConv(LLVMContext &Context,
@@ -794,14 +782,15 @@ static SDValue performANDCombine(SDNode *N, SelectionDAG &DAG,
794782
EVT ValTy = N->getValueType(0);
795783
SDLoc DL(N);
796784

797-
uint64_t Pos = 0, SMPos, SMSize;
785+
uint64_t Pos = 0;
786+
unsigned SMPos, SMSize;
798787
ConstantSDNode *CN;
799788
SDValue NewOperand;
800789
unsigned Opc;
801790

802791
// Op's second operand must be a shifted mask.
803792
if (!(CN = dyn_cast<ConstantSDNode>(Mask)) ||
804-
!isShiftedMask(CN->getZExtValue(), SMPos, SMSize))
793+
!isShiftedMask_64(CN->getZExtValue(), SMPos, SMSize))
805794
return SDValue();
806795

807796
if (FirstOperandOpc == ISD::SRA || FirstOperandOpc == ISD::SRL) {
@@ -875,23 +864,23 @@ static SDValue performORCombine(SDNode *N, SelectionDAG &DAG,
875864
return SDValue();
876865

877866
SDValue And0 = N->getOperand(0), And1 = N->getOperand(1);
878-
uint64_t SMPos0, SMSize0, SMPos1, SMSize1;
867+
unsigned SMPos0, SMSize0, SMPos1, SMSize1;
879868
ConstantSDNode *CN, *CN1;
880869

881870
// See if Op's first operand matches (and $src1 , mask0).
882871
if (And0.getOpcode() != ISD::AND)
883872
return SDValue();
884873

885874
if (!(CN = dyn_cast<ConstantSDNode>(And0.getOperand(1))) ||
886-
!isShiftedMask(~CN->getSExtValue(), SMPos0, SMSize0))
875+
!isShiftedMask_64(~CN->getSExtValue(), SMPos0, SMSize0))
887876
return SDValue();
888877

889878
// See if Op's second operand matches (and (shl $src, pos), mask1).
890879
if (And1.getOpcode() == ISD::AND &&
891880
And1.getOperand(0).getOpcode() == ISD::SHL) {
892881

893882
if (!(CN = dyn_cast<ConstantSDNode>(And1.getOperand(1))) ||
894-
!isShiftedMask(CN->getZExtValue(), SMPos1, SMSize1))
883+
!isShiftedMask_64(CN->getZExtValue(), SMPos1, SMSize1))
895884
return SDValue();
896885

897886
// The shift masks must have the same position and size.
@@ -1118,7 +1107,8 @@ static SDValue performSHLCombine(SDNode *N, SelectionDAG &DAG,
11181107
EVT ValTy = N->getValueType(0);
11191108
SDLoc DL(N);
11201109

1121-
uint64_t Pos = 0, SMPos, SMSize;
1110+
uint64_t Pos = 0;
1111+
unsigned SMPos, SMSize;
11221112
ConstantSDNode *CN;
11231113
SDValue NewOperand;
11241114

@@ -1136,7 +1126,7 @@ static SDValue performSHLCombine(SDNode *N, SelectionDAG &DAG,
11361126

11371127
// AND's second operand must be a shifted mask.
11381128
if (!(CN = dyn_cast<ConstantSDNode>(FirstOperand.getOperand(1))) ||
1139-
!isShiftedMask(CN->getZExtValue(), SMPos, SMSize))
1129+
!isShiftedMask_64(CN->getZExtValue(), SMPos, SMSize))
11401130
return SDValue();
11411131

11421132
// Return if the shifted mask does not start at bit 0 or the sum of its size

llvm/lib/Target/X86/X86InstCombineIntrinsic.cpp

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -996,20 +996,18 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
996996
return IC.replaceInstUsesWith(II, II.getArgOperand(0));
997997
}
998998

999-
if (MaskC->getValue().isShiftedMask()) {
999+
unsigned MaskIdx, MaskLen;
1000+
if (MaskC->getValue().isShiftedMask(MaskIdx, MaskLen)) {
10001001
// any single contingous sequence of 1s anywhere in the mask simply
10011002
// describes a subset of the input bits shifted to the appropriate
10021003
// position. Replace with the straight forward IR.
1003-
unsigned ShiftAmount = MaskC->getValue().countTrailingZeros();
10041004
Value *Input = II.getArgOperand(0);
10051005
Value *Masked = IC.Builder.CreateAnd(Input, II.getArgOperand(1));
1006-
Value *Shifted = IC.Builder.CreateLShr(Masked,
1007-
ConstantInt::get(II.getType(),
1008-
ShiftAmount));
1006+
Value *ShiftAmt = ConstantInt::get(II.getType(), MaskIdx);
1007+
Value *Shifted = IC.Builder.CreateLShr(Masked, ShiftAmt);
10091008
return IC.replaceInstUsesWith(II, Shifted);
10101009
}
10111010

1012-
10131011
if (auto *SrcC = dyn_cast<ConstantInt>(II.getArgOperand(0))) {
10141012
uint64_t Src = SrcC->getZExtValue();
10151013
uint64_t Mask = MaskC->getZExtValue();
@@ -1041,15 +1039,15 @@ X86TTIImpl::instCombineIntrinsic(InstCombiner &IC, IntrinsicInst &II) const {
10411039
if (MaskC->isAllOnesValue()) {
10421040
return IC.replaceInstUsesWith(II, II.getArgOperand(0));
10431041
}
1044-
if (MaskC->getValue().isShiftedMask()) {
1042+
1043+
unsigned MaskIdx, MaskLen;
1044+
if (MaskC->getValue().isShiftedMask(MaskIdx, MaskLen)) {
10451045
// any single contingous sequence of 1s anywhere in the mask simply
10461046
// describes a subset of the input bits shifted to the appropriate
10471047
// position. Replace with the straight forward IR.
1048-
unsigned ShiftAmount = MaskC->getValue().countTrailingZeros();
10491048
Value *Input = II.getArgOperand(0);
1050-
Value *Shifted = IC.Builder.CreateShl(Input,
1051-
ConstantInt::get(II.getType(),
1052-
ShiftAmount));
1049+
Value *ShiftAmt = ConstantInt::get(II.getType(), MaskIdx);
1050+
Value *Shifted = IC.Builder.CreateShl(Input, ShiftAmt);
10531051
Value *Masked = IC.Builder.CreateAnd(Shifted, II.getArgOperand(1));
10541052
return IC.replaceInstUsesWith(II, Masked);
10551053
}

llvm/unittests/ADT/APIntTest.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1746,21 +1746,43 @@ TEST(APIntTest, isShiftedMask) {
17461746
EXPECT_TRUE(APInt(32, 0xffff0000).isShiftedMask());
17471747
EXPECT_TRUE(APInt(32, 0xff << 1).isShiftedMask());
17481748

1749+
unsigned MaskIdx, MaskLen;
1750+
EXPECT_FALSE(APInt(32, 0x01010101).isShiftedMask(MaskIdx, MaskLen));
1751+
EXPECT_TRUE(APInt(32, 0xf0000000).isShiftedMask(MaskIdx, MaskLen));
1752+
EXPECT_EQ(28, MaskIdx);
1753+
EXPECT_EQ(4, MaskLen);
1754+
EXPECT_TRUE(APInt(32, 0xffff0000).isShiftedMask(MaskIdx, MaskLen));
1755+
EXPECT_EQ(16, MaskIdx);
1756+
EXPECT_EQ(16, MaskLen);
1757+
EXPECT_TRUE(APInt(32, 0xff << 1).isShiftedMask(MaskIdx, MaskLen));
1758+
EXPECT_EQ(1, MaskIdx);
1759+
EXPECT_EQ(8, MaskLen);
1760+
17491761
for (int N : { 1, 2, 3, 4, 7, 8, 16, 32, 64, 127, 128, 129, 256 }) {
17501762
EXPECT_FALSE(APInt(N, 0).isShiftedMask());
1763+
EXPECT_FALSE(APInt(N, 0).isShiftedMask(MaskIdx, MaskLen));
17511764

17521765
APInt One(N, 1);
17531766
for (int I = 1; I < N; ++I) {
17541767
APInt MaskVal = One.shl(I) - 1;
17551768
EXPECT_TRUE(MaskVal.isShiftedMask());
1769+
EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
1770+
EXPECT_EQ(0, MaskIdx);
1771+
EXPECT_EQ(I, MaskLen);
17561772
}
17571773
for (int I = 1; I < N - 1; ++I) {
17581774
APInt MaskVal = One.shl(I);
17591775
EXPECT_TRUE(MaskVal.isShiftedMask());
1776+
EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
1777+
EXPECT_EQ(I, MaskIdx);
1778+
EXPECT_EQ(1, MaskLen);
17601779
}
17611780
for (int I = 1; I < N; ++I) {
17621781
APInt MaskVal = APInt::getHighBitsSet(N, I);
17631782
EXPECT_TRUE(MaskVal.isShiftedMask());
1783+
EXPECT_TRUE(MaskVal.isShiftedMask(MaskIdx, MaskLen));
1784+
EXPECT_EQ(N - I, MaskIdx);
1785+
EXPECT_EQ(I, MaskLen);
17641786
}
17651787
}
17661788
}

llvm/unittests/Support/MathExtrasTest.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -180,13 +180,37 @@ TEST(MathExtras, isShiftedMask_32) {
180180
EXPECT_TRUE(isShiftedMask_32(0xf0000000));
181181
EXPECT_TRUE(isShiftedMask_32(0xffff0000));
182182
EXPECT_TRUE(isShiftedMask_32(0xff << 1));
183+
184+
unsigned MaskIdx, MaskLen;
185+
EXPECT_FALSE(isShiftedMask_32(0x01010101, MaskIdx, MaskLen));
186+
EXPECT_TRUE(isShiftedMask_32(0xf0000000, MaskIdx, MaskLen));
187+
EXPECT_EQ(28, MaskIdx);
188+
EXPECT_EQ(4, MaskLen);
189+
EXPECT_TRUE(isShiftedMask_32(0xffff0000, MaskIdx, MaskLen));
190+
EXPECT_EQ(16, MaskIdx);
191+
EXPECT_EQ(16, MaskLen);
192+
EXPECT_TRUE(isShiftedMask_32(0xff << 1, MaskIdx, MaskLen));
193+
EXPECT_EQ(1, MaskIdx);
194+
EXPECT_EQ(8, MaskLen);
183195
}
184196

185197
TEST(MathExtras, isShiftedMask_64) {
186198
EXPECT_FALSE(isShiftedMask_64(0x0101010101010101ull));
187199
EXPECT_TRUE(isShiftedMask_64(0xf000000000000000ull));
188200
EXPECT_TRUE(isShiftedMask_64(0xffff000000000000ull));
189201
EXPECT_TRUE(isShiftedMask_64(0xffull << 55));
202+
203+
unsigned MaskIdx, MaskLen;
204+
EXPECT_FALSE(isShiftedMask_64(0x0101010101010101ull, MaskIdx, MaskLen));
205+
EXPECT_TRUE(isShiftedMask_64(0xf000000000000000ull, MaskIdx, MaskLen));
206+
EXPECT_EQ(60, MaskIdx);
207+
EXPECT_EQ(4, MaskLen);
208+
EXPECT_TRUE(isShiftedMask_64(0xffff000000000000ull, MaskIdx, MaskLen));
209+
EXPECT_EQ(48, MaskIdx);
210+
EXPECT_EQ(16, MaskLen);
211+
EXPECT_TRUE(isShiftedMask_64(0xffull << 55, MaskIdx, MaskLen));
212+
EXPECT_EQ(55, MaskIdx);
213+
EXPECT_EQ(8, MaskLen);
190214
}
191215

192216
TEST(MathExtras, isPowerOf2_32) {

0 commit comments

Comments
 (0)