Skip to content

Commit 0da9aac

Browse files
authored
[InstCombine] Extend bitmask mul combine to handle independent operands (#142503)
This extends #136013 to capture cases where the combinable bitmask muls are nested under multiple or-disjoints. This PR is meant for commits starting at 8c403c9.

Before:
  op1 = or-disjoint mul(and (X, C1), D), reg1
  op2 = or-disjoint mul(and (X, C2), D), reg2
  out = or-disjoint op1, op2
After:
  temp1 = or-disjoint reg1, reg2
  out = or-disjoint mul(and (X, (C1 + C2)), D), temp1

Case1: https://alive2.llvm.org/ce/z/dHApyV
Case2: https://alive2.llvm.org/ce/z/Jz-Nag
Case3: https://alive2.llvm.org/ce/z/3xBnEV
1 parent ff1b37b commit 0da9aac

File tree

3 files changed

+509
-22
lines changed

3 files changed

+509
-22
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 63 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3602,6 +3602,11 @@ struct DecomposedBitMaskMul {
36023602
APInt Mask;
36033603
bool NUW;
36043604
bool NSW;
3605+
3606+
bool isCombineableWith(const DecomposedBitMaskMul Other) {
3607+
return X == Other.X && !Mask.intersects(Other.Mask) &&
3608+
Factor == Other.Factor;
3609+
}
36053610
};
36063611

36073612
static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
@@ -3659,6 +3664,59 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
36593664
return std::nullopt;
36603665
}
36613666

3667+
/// (A & N) * C + (A & M) * C -> (A & (N + M)) & C
3668+
/// This also accepts the equivalent select form of (A & N) * C
3669+
/// expressions i.e. !(A & N) ? 0 : N * C)
3670+
static Value *foldBitmaskMul(Value *Op0, Value *Op1,
3671+
InstCombiner::BuilderTy &Builder) {
3672+
auto Decomp1 = matchBitmaskMul(Op1);
3673+
if (!Decomp1)
3674+
return nullptr;
3675+
3676+
auto Decomp0 = matchBitmaskMul(Op0);
3677+
if (!Decomp0)
3678+
return nullptr;
3679+
3680+
if (Decomp0->isCombineableWith(*Decomp1)) {
3681+
Value *NewAnd = Builder.CreateAnd(
3682+
Decomp0->X,
3683+
ConstantInt::get(Decomp0->X->getType(), Decomp0->Mask + Decomp1->Mask));
3684+
3685+
return Builder.CreateMul(
3686+
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor), "",
3687+
Decomp0->NUW && Decomp1->NUW, Decomp0->NSW && Decomp1->NSW);
3688+
}
3689+
3690+
return nullptr;
3691+
}
3692+
3693+
/// Try the folds available for the operand pair (\p LHS, \p RHS) of an `or`.
/// At present the only fold attempted is foldBitmaskMul.
/// NOTE(review): nothing here checks that the `or` is disjoint; callers are
/// presumably expected to guarantee that - confirm at the call sites.
///
/// \returns the replacement value, or nullptr if no fold applied.
Value *InstCombinerImpl::foldDisjointOr(Value *LHS, Value *RHS) {
  return foldBitmaskMul(LHS, RHS, Builder);
}
3699+
3700+
/// Look through a single-use disjoint `or` on either side of (LHS, RHS) and
/// try to fold the opposite side with one of its operands:
///   LHS, (X | Y)  ->  fold(LHS, X) | Y,  else  fold(LHS, Y) | X
///   (X | Y), RHS  ->  fold(X, RHS) | Y,  else  fold(Y, RHS) | X
/// The nested `or` on RHS is examined before the one on LHS. The rebuilt `or`
/// is created with the disjoint flag set.
///
/// \returns the rebuilt value, or nullptr if none of the four attempts folds.
Value *InstCombinerImpl::reassociateDisjointOr(Value *LHS, Value *RHS) {
  // Fold (A, B); on success, or the untouched operand back onto the result.
  auto FoldAndRejoin = [this](Value *A, Value *B,
                              Value *Leftover) -> Value * {
    Value *Folded = foldDisjointOr(A, B);
    if (!Folded)
      return nullptr;
    return Builder.CreateOr(Folded, Leftover, "", /*IsDisjoint=*/true);
  };

  Value *Inner0, *Inner1;

  // Nested disjoint `or` on the right-hand side.
  if (match(RHS, m_OneUse(m_DisjointOr(m_Value(Inner0), m_Value(Inner1))))) {
    if (Value *Rebuilt = FoldAndRejoin(LHS, Inner0, Inner1))
      return Rebuilt;
    if (Value *Rebuilt = FoldAndRejoin(LHS, Inner1, Inner0))
      return Rebuilt;
  }

  // Nested disjoint `or` on the left-hand side.
  if (match(LHS, m_OneUse(m_DisjointOr(m_Value(Inner0), m_Value(Inner1))))) {
    if (Value *Rebuilt = FoldAndRejoin(Inner0, RHS, Inner1))
      return Rebuilt;
    if (Value *Rebuilt = FoldAndRejoin(Inner1, RHS, Inner0))
      return Rebuilt;
  }

  return nullptr;
}
3719+
36623720
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
36633721
// here. We should standardize that construct where it is needed or choose some
36643722
// other way to ensure that commutated variants of patterns are not missed.
@@ -3741,28 +3799,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
37413799
/*NSW=*/true, /*NUW=*/true))
37423800
return R;
37433801

3744-
// (A & N) * C + (A & M) * C -> (A & (N + M)) & C
3745-
// This also accepts the equivalent select form of (A & N) * C
3746-
// expressions i.e. !(A & N) ? 0 : N * C)
3747-
auto Decomp1 = matchBitmaskMul(I.getOperand(1));
3748-
if (Decomp1) {
3749-
auto Decomp0 = matchBitmaskMul(I.getOperand(0));
3750-
if (Decomp0 && Decomp0->X == Decomp1->X &&
3751-
(Decomp0->Mask & Decomp1->Mask).isZero() &&
3752-
Decomp0->Factor == Decomp1->Factor) {
3753-
3754-
Value *NewAnd = Builder.CreateAnd(
3755-
Decomp0->X, ConstantInt::get(Decomp0->X->getType(),
3756-
(Decomp0->Mask + Decomp1->Mask)));
3757-
3758-
auto *Combined = BinaryOperator::CreateMul(
3759-
NewAnd, ConstantInt::get(NewAnd->getType(), Decomp1->Factor));
3760-
3761-
Combined->setHasNoUnsignedWrap(Decomp0->NUW && Decomp1->NUW);
3762-
Combined->setHasNoSignedWrap(Decomp0->NSW && Decomp1->NSW);
3763-
return Combined;
3764-
}
3765-
}
3802+
if (Value *Res = foldBitmaskMul(I.getOperand(0), I.getOperand(1), Builder))
3803+
return replaceInstUsesWith(I, Res);
3804+
3805+
if (Value *Res = reassociateDisjointOr(I.getOperand(0), I.getOperand(1)))
3806+
return replaceInstUsesWith(I, Res);
37663807
}
37673808

37683809
Value *X, *Y;

llvm/lib/Transforms/InstCombine/InstCombineInternal.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,6 +439,10 @@ class LLVM_LIBRARY_VISIBILITY InstCombinerImpl final
439439
Value *reassociateBooleanAndOr(Value *LHS, Value *X, Value *Y, Instruction &I,
440440
bool IsAnd, bool RHSIsLogical);
441441

442+
/// Try to fold the operand pair (LHS, RHS) of an `or`; returns the
/// replacement value, or nullptr if nothing applied.
Value *foldDisjointOr(Value *LHS, Value *RHS);
443+
444+
/// Look through a one-use disjoint `or` in LHS or RHS and try foldDisjointOr
/// on the other side with one of its operands; returns the rebuilt disjoint
/// `or`, or nullptr if no fold applied.
Value *reassociateDisjointOr(Value *LHS, Value *RHS);
445+
442446
Instruction *
443447
canonicalizeConditionalNegationViaMathToSelect(BinaryOperator &i);
444448

0 commit comments

Comments
 (0)