@@ -3602,6 +3602,11 @@ struct DecomposedBitMaskMul {
3602
3602
APInt Mask;
3603
3603
bool NUW;
3604
3604
bool NSW;
3605
+
3606
+ bool isCombineableWith (const DecomposedBitMaskMul Other) {
3607
+ return X == Other.X && !Mask.intersects (Other.Mask ) &&
3608
+ Factor == Other.Factor ;
3609
+ }
3605
3610
};
3606
3611
3607
3612
static std::optional<DecomposedBitMaskMul> matchBitmaskMul (Value *V) {
@@ -3659,6 +3664,59 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
3659
3664
return std::nullopt;
3660
3665
}
3661
3666
3667
+ // / (A & N) * C + (A & M) * C -> (A & (N + M)) & C
3668
+ // / This also accepts the equivalent select form of (A & N) * C
3669
+ // / expressions i.e. !(A & N) ? 0 : N * C)
3670
+ static Value *foldBitmaskMul (Value *Op0, Value *Op1,
3671
+ InstCombiner::BuilderTy &Builder) {
3672
+ auto Decomp1 = matchBitmaskMul (Op1);
3673
+ if (!Decomp1)
3674
+ return nullptr ;
3675
+
3676
+ auto Decomp0 = matchBitmaskMul (Op0);
3677
+ if (!Decomp0)
3678
+ return nullptr ;
3679
+
3680
+ if (Decomp0->isCombineableWith (*Decomp1)) {
3681
+ Value *NewAnd = Builder.CreateAnd (
3682
+ Decomp0->X ,
3683
+ ConstantInt::get (Decomp0->X ->getType (), Decomp0->Mask + Decomp1->Mask ));
3684
+
3685
+ return Builder.CreateMul (
3686
+ NewAnd, ConstantInt::get (NewAnd->getType (), Decomp1->Factor ), " " ,
3687
+ Decomp0->NUW && Decomp1->NUW , Decomp0->NSW && Decomp1->NSW );
3688
+ }
3689
+
3690
+ return nullptr ;
3691
+ }
3692
+
3693
+ Value *InstCombinerImpl::foldDisjointOr (Value *LHS, Value *RHS) {
3694
+ if (Value *Res = foldBitmaskMul (LHS, RHS, Builder))
3695
+ return Res;
3696
+
3697
+ return nullptr ;
3698
+ }
3699
+
3700
+ Value *InstCombinerImpl::reassociateDisjointOr (Value *LHS, Value *RHS) {
3701
+
3702
+ Value *X, *Y;
3703
+ if (match (RHS, m_OneUse (m_DisjointOr (m_Value (X), m_Value (Y))))) {
3704
+ if (Value *Res = foldDisjointOr (LHS, X))
3705
+ return Builder.CreateOr (Res, Y, " " , /* IsDisjoint=*/ true );
3706
+ if (Value *Res = foldDisjointOr (LHS, Y))
3707
+ return Builder.CreateOr (Res, X, " " , /* IsDisjoint=*/ true );
3708
+ }
3709
+
3710
+ if (match (LHS, m_OneUse (m_DisjointOr (m_Value (X), m_Value (Y))))) {
3711
+ if (Value *Res = foldDisjointOr (X, RHS))
3712
+ return Builder.CreateOr (Res, Y, " " , /* IsDisjoint=*/ true );
3713
+ if (Value *Res = foldDisjointOr (Y, RHS))
3714
+ return Builder.CreateOr (Res, X, " " , /* IsDisjoint=*/ true );
3715
+ }
3716
+
3717
+ return nullptr ;
3718
+ }
3719
+
3662
3720
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
3663
3721
// here. We should standardize that construct where it is needed or choose some
3664
3722
// other way to ensure that commutated variants of patterns are not missed.
@@ -3741,28 +3799,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
3741
3799
/* NSW=*/ true , /* NUW=*/ true ))
3742
3800
return R;
3743
3801
3744
- // (A & N) * C + (A & M) * C -> (A & (N + M)) & C
3745
- // This also accepts the equivalent select form of (A & N) * C
3746
- // expressions i.e. !(A & N) ? 0 : N * C)
3747
- auto Decomp1 = matchBitmaskMul (I.getOperand (1 ));
3748
- if (Decomp1) {
3749
- auto Decomp0 = matchBitmaskMul (I.getOperand (0 ));
3750
- if (Decomp0 && Decomp0->X == Decomp1->X &&
3751
- (Decomp0->Mask & Decomp1->Mask ).isZero () &&
3752
- Decomp0->Factor == Decomp1->Factor ) {
3753
-
3754
- Value *NewAnd = Builder.CreateAnd (
3755
- Decomp0->X , ConstantInt::get (Decomp0->X ->getType (),
3756
- (Decomp0->Mask + Decomp1->Mask )));
3757
-
3758
- auto *Combined = BinaryOperator::CreateMul (
3759
- NewAnd, ConstantInt::get (NewAnd->getType (), Decomp1->Factor ));
3760
-
3761
- Combined->setHasNoUnsignedWrap (Decomp0->NUW && Decomp1->NUW );
3762
- Combined->setHasNoSignedWrap (Decomp0->NSW && Decomp1->NSW );
3763
- return Combined;
3764
- }
3765
- }
3802
+ if (Value *Res = foldBitmaskMul (I.getOperand (0 ), I.getOperand (1 ), Builder))
3803
+ return replaceInstUsesWith (I, Res);
3804
+
3805
+ if (Value *Res = reassociateDisjointOr (I.getOperand (0 ), I.getOperand (1 )))
3806
+ return replaceInstUsesWith (I, Res);
3766
3807
}
3767
3808
3768
3809
Value *X, *Y;
0 commit comments