Skip to content

Commit 338fd8b

Browse files
authored
[SimplifyCFG] Transform switch to select when common bits uniquely identify one case (#145233)
Fix #141753 . This patch introduces a new check, that tries to decide if the conjunction of all the values uniquely identify the accepted values by the switch.
1 parent 68173c8 commit 338fd8b

File tree

2 files changed

+468
-4
lines changed

2 files changed

+468
-4
lines changed

llvm/lib/Transforms/Utils/SimplifyCFG.cpp

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6198,7 +6198,7 @@ static bool initializeUniqueCases(SwitchInst *SI, PHINode *&PHI,
61986198
// TODO: Handle switches with more than 2 cases that map to the same result.
61996199
static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
62006200
Constant *DefaultResult, Value *Condition,
6201-
IRBuilder<> &Builder) {
6201+
IRBuilder<> &Builder, const DataLayout &DL) {
62026202
// If we are selecting between only two cases transform into a simple
62036203
// select or a two-way select if default is possible.
62046204
// Example:
@@ -6234,10 +6234,33 @@ static Value *foldSwitchToSelect(const SwitchCaseResultVectorTy &ResultVector,
62346234
// case 0,2,8,10 -> Cond & 0b1..0101 == 0 ? result : default
62356235
if (isPowerOf2_32(CaseCount)) {
62366236
ConstantInt *MinCaseVal = CaseValues[0];
6237-
// Find mininal value.
6238-
for (auto *Case : CaseValues)
6237+
// If there are bits that are set exclusively by CaseValues, we
6238+
// can transform the switch into a select if the conjunction of
6239+
// all the values uniquely identify CaseValues.
6240+
APInt AndMask = APInt::getAllOnes(MinCaseVal->getBitWidth());
6241+
6242+
// Find the minimum value and compute the and of all the case values.
6243+
for (auto *Case : CaseValues) {
62396244
if (Case->getValue().slt(MinCaseVal->getValue()))
62406245
MinCaseVal = Case;
6246+
AndMask &= Case->getValue();
6247+
}
6248+
KnownBits Known = computeKnownBits(Condition, DL);
6249+
6250+
if (!AndMask.isZero() && Known.getMaxValue().uge(AndMask)) {
6251+
// Compute the number of bits that are free to vary.
6252+
unsigned FreeBits = Known.countMaxActiveBits() - AndMask.popcount();
6253+
6254+
// Check if the number of values covered by the mask is equal
6255+
// to the number of cases.
6256+
if (FreeBits == Log2_32(CaseCount)) {
6257+
Value *And = Builder.CreateAnd(Condition, AndMask);
6258+
Value *Cmp = Builder.CreateICmpEQ(
6259+
And, Constant::getIntegerValue(And->getType(), AndMask));
6260+
return Builder.CreateSelect(Cmp, ResultVector[0].first,
6261+
DefaultResult);
6262+
}
6263+
}
62416264

62426265
// Mark the bits case number touched.
62436266
APInt BitMask = APInt::getZero(MinCaseVal->getBitWidth());
@@ -6325,7 +6348,7 @@ static bool trySwitchToSelect(SwitchInst *SI, IRBuilder<> &Builder,
63256348
assert(PHI != nullptr && "PHI for value select not found");
63266349
Builder.SetInsertPoint(SI);
63276350
Value *SelectValue =
6328-
foldSwitchToSelect(UniqueResults, DefaultResult, Cond, Builder);
6351+
foldSwitchToSelect(UniqueResults, DefaultResult, Cond, Builder, DL);
63296352
if (!SelectValue)
63306353
return false;
63316354

0 commit comments

Comments
 (0)