Skip to content

Commit 2ebe573

Browse files
committed
[InstCombine] Dropping redundant masking before left-shift [2/5] (PR42563)
Summary: If we have some pattern that leaves only some low bits set, and then performs left-shift of those bits, if none of the bits that are left after the final shift are modified by the mask, we can omit the mask. There are many variants to this pattern: c. `(x & (-1 >> MaskShAmt)) << ShiftShAmt` All these patterns can be simplified to just: `x << ShiftShAmt` iff: c. `(ShiftShAmt-MaskShAmt) s>= 0` (i.e. `ShiftShAmt u>= MaskShAmt`) alive proofs: c: https://rise4fun.com/Alive/RgJh For now let's start with patterns where both shift amounts are variable, with trivial constant "offset" between them, since i believe this is both simplest to handle and i think this is most common. But again, there are likely other variants where we could use ValueTracking/ConstantRange to handle more cases. https://bugs.llvm.org/show_bug.cgi?id=42563 Differential Revision: https://reviews.llvm.org/D64517 llvm-svn: 366537
1 parent 4422a16 commit 2ebe573

File tree

2 files changed

+42
-26
lines changed

2 files changed

+42
-26
lines changed

llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,12 @@ reassociateShiftAmtsOfTwoSameDirectionShifts(BinaryOperator *Sh0,
7272
// There are many variants to this pattern:
7373
// a) (x & ((1 << MaskShAmt) - 1)) << ShiftShAmt
7474
// b) (x & (~(-1 << MaskShAmt))) << ShiftShAmt
75+
// c) (x & (-1 >> MaskShAmt)) << ShiftShAmt
7576
// All these patterns can be simplified to just:
7677
// x << ShiftShAmt
7778
// iff:
7879
// a,b) (MaskShAmt+ShiftShAmt) u>= bitwidth(x)
80+
// c) (ShiftShAmt-MaskShAmt) s>= 0 (i.e. ShiftShAmt u>= MaskShAmt)
7981
static Instruction *
8082
dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
8183
const SimplifyQuery &SQ) {
@@ -91,24 +93,38 @@ dropRedundantMaskingOfLeftShiftInput(BinaryOperator *OuterShift,
9193
auto MaskA = m_Add(m_Shl(m_One(), m_Value(MaskShAmt)), m_AllOnes());
9294
// (~(-1 << maskNbits))
9395
auto MaskB = m_Xor(m_Shl(m_AllOnes(), m_Value(MaskShAmt)), m_AllOnes());
96+
// (-1 >> MaskShAmt)
97+
auto MaskC = m_Shr(m_AllOnes(), m_Value(MaskShAmt));
9498

9599
Value *X;
96-
if (!match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X))))
97-
return nullptr;
98-
99-
// Can we simplify (MaskShAmt+ShiftShAmt) ?
100-
Value *SumOfShAmts =
101-
SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
102-
SQ.getWithInstruction(OuterShift));
103-
if (!SumOfShAmts)
104-
return nullptr; // Did not simplify.
105-
// Is the total shift amount *not* smaller than the bit width?
106-
// FIXME: could also rely on ConstantRange.
107-
unsigned BitWidth = X->getType()->getScalarSizeInBits();
108-
if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
109-
APInt(BitWidth, BitWidth))))
110-
return nullptr;
111-
// All good, we can do this fold.
100+
if (match(Masked, m_c_And(m_CombineOr(MaskA, MaskB), m_Value(X)))) {
101+
// Can we simplify (MaskShAmt+ShiftShAmt) ?
102+
Value *SumOfShAmts =
103+
SimplifyAddInst(MaskShAmt, ShiftShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
104+
SQ.getWithInstruction(OuterShift));
105+
if (!SumOfShAmts)
106+
return nullptr; // Did not simplify.
107+
// Is the total shift amount *not* smaller than the bit width?
108+
// FIXME: could also rely on ConstantRange.
109+
unsigned BitWidth = X->getType()->getScalarSizeInBits();
110+
if (!match(SumOfShAmts, m_SpecificInt_ICMP(ICmpInst::Predicate::ICMP_UGE,
111+
APInt(BitWidth, BitWidth))))
112+
return nullptr;
113+
// All good, we can do this fold.
114+
} else if (match(Masked, m_c_And(MaskC, m_Value(X)))) {
115+
// Can we simplify (ShiftShAmt-MaskShAmt) ?
116+
Value *ShAmtsDiff =
117+
SimplifySubInst(ShiftShAmt, MaskShAmt, /*IsNSW=*/false, /*IsNUW=*/false,
118+
SQ.getWithInstruction(OuterShift));
119+
if (!ShAmtsDiff)
120+
return nullptr; // Did not simplify.
121+
// Is the difference non-negative? (is ShiftShAmt u>= MaskShAmt ?)
122+
// FIXME: could also rely on ConstantRange.
123+
if (!match(ShAmtsDiff, m_NonNegative()))
124+
return nullptr;
125+
// All good, we can do this fold.
126+
} else
127+
return nullptr; // Don't know anything about this pattern.
112128

113129
// No 'NUW'/'NSW'!
114130
// We no longer know that we won't shift-out non-0 bits.

llvm/test/Transforms/InstCombine/redundant-left-shift-input-masking-variant-c.ll

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ define i32 @t0_basic(i32 %x, i32 %nbits) {
2121
; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
2222
; CHECK-NEXT: call void @use32(i32 [[T0]])
2323
; CHECK-NEXT: call void @use32(i32 [[T1]])
24-
; CHECK-NEXT: [[T2:%.*]] = shl i32 [[T1]], [[NBITS]]
24+
; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
2525
; CHECK-NEXT: ret i32 [[T2]]
2626
;
2727
%t0 = lshr i32 -1, %nbits
@@ -40,7 +40,7 @@ define i32 @t1_bigger_shift(i32 %x, i32 %nbits) {
4040
; CHECK-NEXT: call void @use32(i32 [[T0]])
4141
; CHECK-NEXT: call void @use32(i32 [[T1]])
4242
; CHECK-NEXT: call void @use32(i32 [[T2]])
43-
; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T1]], [[T2]]
43+
; CHECK-NEXT: [[T3:%.*]] = shl i32 [[X]], [[T2]]
4444
; CHECK-NEXT: ret i32 [[T3]]
4545
;
4646
%t0 = lshr i32 -1, %nbits
@@ -65,7 +65,7 @@ define <3 x i32> @t2_vec_splat(<3 x i32> %x, <3 x i32> %nbits) {
6565
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
6666
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
6767
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
68-
; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
68+
; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
6969
; CHECK-NEXT: ret <3 x i32> [[T3]]
7070
;
7171
%t0 = lshr <3 x i32> <i32 -1, i32 -1, i32 -1>, %nbits
@@ -86,7 +86,7 @@ define <3 x i32> @t3_vec_nonsplat(<3 x i32> %x, <3 x i32> %nbits) {
8686
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
8787
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
8888
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
89-
; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
89+
; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
9090
; CHECK-NEXT: ret <3 x i32> [[T3]]
9191
;
9292
%t0 = lshr <3 x i32> <i32 -1, i32 -1, i32 -1>, %nbits
@@ -107,7 +107,7 @@ define <3 x i32> @t4_vec_undef(<3 x i32> %x, <3 x i32> %nbits) {
107107
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T0]])
108108
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T1]])
109109
; CHECK-NEXT: call void @use3xi32(<3 x i32> [[T2]])
110-
; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[T1]], [[T2]]
110+
; CHECK-NEXT: [[T3:%.*]] = shl <3 x i32> [[X]], [[T2]]
111111
; CHECK-NEXT: ret <3 x i32> [[T3]]
112112
;
113113
%t0 = lshr <3 x i32> <i32 -1, i32 undef, i32 -1>, %nbits
@@ -131,7 +131,7 @@ define i32 @t5_commutativity0(i32 %nbits) {
131131
; CHECK-NEXT: [[T1:%.*]] = and i32 [[X]], [[T0]]
132132
; CHECK-NEXT: call void @use32(i32 [[T0]])
133133
; CHECK-NEXT: call void @use32(i32 [[T1]])
134-
; CHECK-NEXT: [[T2:%.*]] = shl i32 [[T1]], [[NBITS]]
134+
; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
135135
; CHECK-NEXT: ret i32 [[T2]]
136136
;
137137
%x = call i32 @gen32()
@@ -151,7 +151,7 @@ define i32 @t6_commutativity1(i32 %nbits0, i32 %nbits1) {
151151
; CHECK-NEXT: call void @use32(i32 [[T0]])
152152
; CHECK-NEXT: call void @use32(i32 [[T1]])
153153
; CHECK-NEXT: call void @use32(i32 [[T2]])
154-
; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T2]], [[NBITS0]]
154+
; CHECK-NEXT: [[T3:%.*]] = shl i32 [[T1]], [[NBITS0]]
155155
; CHECK-NEXT: ret i32 [[T3]]
156156
;
157157
%t0 = lshr i32 -1, %nbits0
@@ -192,7 +192,7 @@ define i32 @t8_nuw(i32 %x, i32 %nbits) {
192192
; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
193193
; CHECK-NEXT: call void @use32(i32 [[T0]])
194194
; CHECK-NEXT: call void @use32(i32 [[T1]])
195-
; CHECK-NEXT: [[T2:%.*]] = shl nuw i32 [[T1]], [[NBITS]]
195+
; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
196196
; CHECK-NEXT: ret i32 [[T2]]
197197
;
198198
%t0 = lshr i32 -1, %nbits
@@ -209,7 +209,7 @@ define i32 @t9_nsw(i32 %x, i32 %nbits) {
209209
; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
210210
; CHECK-NEXT: call void @use32(i32 [[T0]])
211211
; CHECK-NEXT: call void @use32(i32 [[T1]])
212-
; CHECK-NEXT: [[T2:%.*]] = shl nsw i32 [[T1]], [[NBITS]]
212+
; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
213213
; CHECK-NEXT: ret i32 [[T2]]
214214
;
215215
%t0 = lshr i32 -1, %nbits
@@ -226,7 +226,7 @@ define i32 @t10_nuw_nsw(i32 %x, i32 %nbits) {
226226
; CHECK-NEXT: [[T1:%.*]] = and i32 [[T0]], [[X:%.*]]
227227
; CHECK-NEXT: call void @use32(i32 [[T0]])
228228
; CHECK-NEXT: call void @use32(i32 [[T1]])
229-
; CHECK-NEXT: [[T2:%.*]] = shl nuw nsw i32 [[T1]], [[NBITS]]
229+
; CHECK-NEXT: [[T2:%.*]] = shl i32 [[X]], [[NBITS]]
230230
; CHECK-NEXT: ret i32 [[T2]]
231231
;
232232
%t0 = lshr i32 -1, %nbits

0 commit comments

Comments (0)