Skip to content

Commit d2502eb

Browse files
committed
[KnownBits] Add support for nuw/nsw on shifts
Implement precise nuw/nsw support in the KnownBits implementation, replacing the rather crude handling in ValueTracking. Differential Revision: https://reviews.llvm.org/D151208
1 parent 660b3c8 commit d2502eb

File tree

6 files changed

+92
-29
lines changed

6 files changed

+92
-29
lines changed

llvm/include/llvm/Support/KnownBits.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -382,7 +382,8 @@ struct KnownBits {
382382

383383
/// Compute known bits for shl(LHS, RHS).
384384
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.
385-
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS);
385+
static KnownBits shl(const KnownBits &LHS, const KnownBits &RHS,
386+
bool NUW = false, bool NSW = false);
386387

387388
/// Compute known bits for lshr(LHS, RHS).
388389
/// NOTE: RHS (shift amount) bitwidth doesn't need to be the same as LHS.

llvm/lib/Analysis/ValueTracking.cpp

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1353,20 +1353,10 @@ static void computeKnownBitsFromOperator(const Operator *I,
13531353
break;
13541354
}
13551355
case Instruction::Shl: {
1356+
bool NUW = Q.IIQ.hasNoUnsignedWrap(cast<OverflowingBinaryOperator>(I));
13561357
bool NSW = Q.IIQ.hasNoSignedWrap(cast<OverflowingBinaryOperator>(I));
1357-
auto KF = [NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
1358-
KnownBits Result = KnownBits::shl(KnownVal, KnownAmt);
1359-
// If this shift has "nsw" keyword, then the result is either a poison
1360-
// value or has the same sign bit as the first operand.
1361-
if (NSW) {
1362-
if (KnownVal.Zero.isSignBitSet())
1363-
Result.Zero.setSignBit();
1364-
if (KnownVal.One.isSignBitSet())
1365-
Result.One.setSignBit();
1366-
if (Result.hasConflict())
1367-
Result.setAllZero();
1368-
}
1369-
return Result;
1358+
auto KF = [NUW, NSW](const KnownBits &KnownVal, const KnownBits &KnownAmt) {
1359+
return KnownBits::shl(KnownVal, KnownAmt, NUW, NSW);
13701360
};
13711361
computeKnownBitsFromShiftOperator(I, DemandedElts, Known, Known2, Depth, Q,
13721362
KF);

llvm/lib/Support/KnownBits.cpp

Lines changed: 50 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -164,21 +164,51 @@ KnownBits KnownBits::smin(const KnownBits &LHS, const KnownBits &RHS) {
164164
return Flip(umax(Flip(LHS), Flip(RHS)));
165165
}
166166

167-
KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
167+
KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS, bool NUW,
168+
bool NSW) {
168169
unsigned BitWidth = LHS.getBitWidth();
169-
KnownBits Known(BitWidth);
170+
auto ShiftByConst = [&](const KnownBits &LHS,
171+
uint64_t ShiftAmt) -> std::optional<KnownBits> {
172+
KnownBits Known;
173+
Known.Zero = LHS.Zero << ShiftAmt;
174+
Known.Zero.setLowBits(ShiftAmt);
175+
Known.One = LHS.One << ShiftAmt;
176+
if ((!NUW && !NSW) || ShiftAmt == 0)
177+
return Known;
178+
179+
KnownBits ShiftedOutBits = LHS.extractBits(ShiftAmt, BitWidth - ShiftAmt);
180+
if (NUW && !ShiftedOutBits.One.isZero())
181+
// One bit has been shifted out.
182+
return std::nullopt;
183+
if (NSW) {
184+
if (!ShiftedOutBits.Zero.isZero() && !ShiftedOutBits.One.isZero())
185+
// Both zeros and ones have been shifted out.
186+
return std::nullopt;
187+
if (NUW || !ShiftedOutBits.Zero.isZero()) {
188+
if (Known.isNegative())
189+
// Zero bit has been shifted out, but result sign is negative.
190+
return std::nullopt;
191+
Known.makeNonNegative();
192+
} else if (!ShiftedOutBits.One.isZero()) {
193+
if (Known.isNonNegative())
194+
// One bit has been shifted out, but result sign is negative.
195+
return std::nullopt;
196+
Known.makeNegative();
197+
}
198+
}
199+
return Known;
200+
};
170201

171202
// If the shift amount is a valid constant then transform LHS directly.
172203
if (RHS.isConstant() && RHS.getConstant().ult(BitWidth)) {
173-
unsigned Shift = RHS.getConstant().getZExtValue();
174-
Known = LHS;
175-
Known.Zero <<= Shift;
176-
Known.One <<= Shift;
177-
// Low bits are known zero.
178-
Known.Zero.setLowBits(Shift);
204+
if (auto Res = ShiftByConst(LHS, RHS.getConstant().getZExtValue()))
205+
return *Res;
206+
KnownBits Known(BitWidth);
207+
Known.setAllZero();
179208
return Known;
180209
}
181210

211+
KnownBits Known(BitWidth);
182212
APInt MinShiftAmount = RHS.getMinValue();
183213
if (MinShiftAmount.uge(BitWidth)) {
184214
// Always poison. Return zero because we don't like returning conflict.
@@ -193,6 +223,8 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
193223
MinTrailingZeros += MinShiftAmount.getZExtValue();
194224
MinTrailingZeros = std::min(MinTrailingZeros, BitWidth);
195225
Known.Zero.setLowBits(MinTrailingZeros);
226+
if (NUW && NSW && !MinShiftAmount.isZero())
227+
Known.makeNonNegative();
196228
return Known;
197229
}
198230

@@ -210,15 +242,20 @@ KnownBits KnownBits::shl(const KnownBits &LHS, const KnownBits &RHS) {
210242
if ((ShiftAmtZeroMask & ShiftAmt) != ShiftAmt ||
211243
(ShiftAmtOneMask | ShiftAmt) != ShiftAmt)
212244
continue;
213-
KnownBits SpecificShift;
214-
SpecificShift.Zero = LHS.Zero << ShiftAmt;
215-
SpecificShift.Zero.setLowBits(ShiftAmt);
216-
SpecificShift.One = LHS.One << ShiftAmt;
217-
Known = Known.intersectWith(SpecificShift);
245+
auto Res = ShiftByConst(LHS, ShiftAmt);
246+
if (!Res)
247+
// All larger shift amounts will overflow as well.
248+
break;
249+
Known = Known.intersectWith(*Res);
218250
if (Known.isUnknown())
219251
break;
220252
}
221253

254+
// All shift amounts may result in poison.
255+
if (Known.hasConflict()) {
256+
assert((NUW || NSW) && "Can only happen with nowrap flags");
257+
Known.setAllZero();
258+
}
222259
return Known;
223260
}
224261

llvm/test/Transforms/ConstraintElimination/geps-unsigned-predicates.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,7 +513,7 @@ define void @test.ult.gep.shl(ptr readonly %src, ptr readnone %max, i8 %idx) {
513513
; CHECK-NEXT: [[IDX_SHL_1:%.*]] = shl nuw nsw i8 [[IDX]], 1
514514
; CHECK-NEXT: [[ADD_PTR_SHL_1:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_1]]
515515
; CHECK-NEXT: [[C_MAX_0:%.*]] = icmp ult ptr [[ADD_PTR_SHL_1]], [[MAX]]
516-
; CHECK-NEXT: call void @use(i1 [[C_MAX_0]])
516+
; CHECK-NEXT: call void @use(i1 true)
517517
; CHECK-NEXT: [[IDX_SHL_2:%.*]] = shl nuw i8 [[IDX]], 2
518518
; CHECK-NEXT: [[ADD_PTR_SHL_2:%.*]] = getelementptr inbounds i32, ptr [[SRC]], i8 [[IDX_SHL_2]]
519519
; CHECK-NEXT: [[C_MAX_1:%.*]] = icmp ult ptr [[ADD_PTR_SHL_2]], [[MAX]]

llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ define void @test_array_load2_store2(i32 %C, i32 %D) #1 {
8282
; CHECK-NEXT: [[ARRAYIDX3:%.*]] = getelementptr inbounds [1024 x i32], ptr @CD, i64 0, i64 [[OR]]
8383
; CHECK-NEXT: store i32 [[MUL]], ptr [[ARRAYIDX3]], align 4
8484
; CHECK-NEXT: [[INDVARS_IV_NEXT]] = add nuw nsw i64 [[INDVARS_IV]], 2
85-
; CHECK-NEXT: [[CMP:%.*]] = icmp slt i64 [[INDVARS_IV]], 1022
85+
; CHECK-NEXT: [[CMP:%.*]] = icmp ult i64 [[INDVARS_IV]], 1022
8686
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY]], label [[FOR_END]], !llvm.loop [[LOOP3:![0-9]+]]
8787
; CHECK: for.end:
8888
; CHECK-NEXT: ret void

llvm/unittests/Support/KnownBitsTest.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,41 @@ TEST(KnownBitsTest, BinaryExhaustive) {
343343
return std::nullopt;
344344
return N1.shl(N2);
345345
});
346+
testBinaryOpExhaustive(
347+
[](const KnownBits &Known1, const KnownBits &Known2) {
348+
return KnownBits::shl(Known1, Known2, /* NUW */ true);
349+
},
350+
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
351+
bool Overflow;
352+
APInt Res = N1.ushl_ov(N2, Overflow);
353+
if (Overflow)
354+
return std::nullopt;
355+
return Res;
356+
});
357+
testBinaryOpExhaustive(
358+
[](const KnownBits &Known1, const KnownBits &Known2) {
359+
return KnownBits::shl(Known1, Known2, /* NUW */ false, /* NSW */ true);
360+
},
361+
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
362+
bool Overflow;
363+
APInt Res = N1.sshl_ov(N2, Overflow);
364+
if (Overflow)
365+
return std::nullopt;
366+
return Res;
367+
});
368+
testBinaryOpExhaustive(
369+
[](const KnownBits &Known1, const KnownBits &Known2) {
370+
return KnownBits::shl(Known1, Known2, /* NUW */ true, /* NSW */ true);
371+
},
372+
[](const APInt &N1, const APInt &N2) -> std::optional<APInt> {
373+
bool OverflowUnsigned, OverflowSigned;
374+
APInt Res = N1.ushl_ov(N2, OverflowUnsigned);
375+
(void)N1.sshl_ov(N2, OverflowSigned);
376+
if (OverflowUnsigned || OverflowSigned)
377+
return std::nullopt;
378+
return Res;
379+
});
380+
346381
testBinaryOpExhaustive(
347382
[](const KnownBits &Known1, const KnownBits &Known2) {
348383
return KnownBits::lshr(Known1, Known2);

0 commit comments

Comments
 (0)