Skip to content

[InstCombine] Fold lshr -> zext -> shl patterns #147737

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -978,6 +978,47 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
return new ZExtInst(Overflow, Ty);
}

/// Try to cancel the left shift in `shl (zext V), ShlAmt` against a logical
/// right shift that feeds \p V.
///
/// The handled shape is
///   V = and (... (and (lshr X, InnerShrAmt), C1) ...), Cn
/// i.e. an `lshr` by a constant, optionally wrapped in single-use
/// `and`-with-constant masks, with InnerShrAmt < ShlAmt. In that case
///   shl (zext V), ShlAmt
/// is rewritten to
///   shl (zext (V << InnerShrAmt)), ShlAmt - InnerShrAmt
/// where `V << InnerShrAmt` is produced by getShiftedValue().
///
/// Soundness relies on the top InnerShrAmt bits of V being zero, so that the
/// left shift performed in the narrow type discards nothing. That is why only
/// `and` masks are looked through: other binary operators with a constant
/// (e.g. `or`, `xor`, `add`) may set those high bits, and evaluating them
/// shifted in the narrow type would silently drop bits that the original
/// wide shift keeps.
///
/// \param DestTy  Type of the zext result and of the outer shl.
/// \param V       Operand of the zext.
/// \param ShlAmt  Constant shift amount of the outer shl.
/// \param IC      Combiner forwarded to canEvaluateShifted/getShiftedValue and
///                whose Builder is used to create the new zext.
/// \param DL      Data layout forwarded to getShiftedValue.
/// \returns the replacement shl for the caller to return, or nullptr if the
///          pattern does not apply.
static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
                                            unsigned ShlAmt,
                                            InstCombinerImpl &IC,
                                            const DataLayout &DL) {
  auto *I = dyn_cast<Instruction>(V);
  if (!I)
    return nullptr;

  // Dig through single-use constant masks until the first shift. An `and` can
  // only clear bits, so the leading zeros introduced by the lshr survive.
  while (!I->isShift()) {
    Instruction *Masked = nullptr;
    if (!match(I, m_And(m_OneUse(m_Instruction(Masked)), m_Constant())))
      return nullptr;
    I = Masked;
  }

  // Fold only if the inner shift is a logical right-shift by a constant.
  uint64_t InnerShrAmt;
  if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt))))
    return nullptr;

  // A shift amount that is not smaller than the bit width makes the lshr
  // poison; leave such IR alone rather than feeding it to the helpers below.
  const unsigned SrcWidth = V->getType()->getScalarSizeInBits();
  if (InnerShrAmt >= SrcWidth)
    return nullptr;

  // When the inner shift already covers the outer one, the whole shl fits in
  // the narrow type and no residual wide shl is needed. The caller's
  // `shl (zext X), C --> zext (shl X, C)` fold is the right tool for that
  // case, so do not attempt a rewrite here.
  if (InnerShrAmt >= ShlAmt)
    return nullptr;

  // Undo the inner lshr inside the narrow type (turning it into a mask) ...
  if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
    return nullptr;

  // ... and keep only the remaining part of the left shift in the wide type.
  const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
  Value *NewInner =
      getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
  Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy);
  return BinaryOperator::CreateShl(NewZExt,
                                   ConstantInt::get(DestTy, ReducedShlAmt));
}

// Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
assert(I.isShift() && "Expected a shift as input");
Expand Down Expand Up @@ -1062,14 +1103,18 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
if (match(Op1, m_APInt(C))) {
unsigned ShAmtC = C->getZExtValue();

// shl (zext X), C --> zext (shl X, C)
// This is only valid if X would have zeros shifted out.
Value *X;
if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
// shl (zext X), C --> zext (shl X, C)
// This is only valid if X would have zeros shifted out.
unsigned SrcWidth = X->getType()->getScalarSizeInBits();
if (ShAmtC < SrcWidth &&
MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), &I))
return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);

// Otherwise, try to cancel the outer shl with a lshr inside the zext.
if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL))
return V;
}

// (X >> C) << C --> X & (-1 << C)
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
Original file line number Diff line number Diff line change
Expand Up @@ -101,9 +101,9 @@ define void @numsignbits_shl_zext_extended_bits_remains(i8 %x) {
define void @numsignbits_shl_zext_all_bits_shifted_out(i8 %x) {
; CHECK-LABEL: define void @numsignbits_shl_zext_all_bits_shifted_out(
; CHECK-SAME: i8 [[X:%.*]]) {
; CHECK-NEXT: [[ASHR:%.*]] = lshr i8 [[X]], 5
; CHECK-NEXT: [[ZEXT:%.*]] = zext nneg i8 [[ASHR]] to i16
; CHECK-NEXT: [[NSB1:%.*]] = shl i16 [[ZEXT]], 14
; CHECK-NEXT: [[ASHR:%.*]] = and i8 [[X]], 96
; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i8 [[ASHR]] to i16
; CHECK-NEXT: [[NSB1:%.*]] = shl nuw i16 [[TMP1]], 9
; CHECK-NEXT: [[AND14:%.*]] = and i16 [[NSB1]], 16384
; CHECK-NEXT: [[ADD14:%.*]] = add i16 [[AND14]], [[NSB1]]
; CHECK-NEXT: call void @escape(i16 [[ADD14]])
Expand Down
6 changes: 3 additions & 3 deletions llvm/test/Transforms/InstCombine/iX-ext-split.ll
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,9 @@ define i128 @i128_ext_split_neg4(i32 %x) {
; CHECK-NEXT: [[ENTRY:.*:]]
; CHECK-NEXT: [[LOWERSRC:%.*]] = sext i32 [[X]] to i64
; CHECK-NEXT: [[LO:%.*]] = zext i64 [[LOWERSRC]] to i128
; CHECK-NEXT: [[SIGN:%.*]] = lshr i32 [[X]], 31
; CHECK-NEXT: [[WIDEN:%.*]] = zext nneg i32 [[SIGN]] to i128
; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[WIDEN]], 64
; CHECK-NEXT: [[SIGN:%.*]] = and i32 [[X]], -2147483648
; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[SIGN]] to i128
; CHECK-NEXT: [[HI:%.*]] = shl nuw nsw i128 [[TMP0]], 33
; CHECK-NEXT: [[RES:%.*]] = or disjoint i128 [[HI]], [[LO]]
; CHECK-NEXT: ret i128 [[RES]]
;
Expand Down
69 changes: 69 additions & 0 deletions llvm/test/Transforms/InstCombine/shifts-around-zext.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
; RUN: opt -S -passes=instcombine %s | FileCheck %s

;; A bare lshr (by 8) feeding a zext i32 -> i64 and then a shl (by 32).
;; The CHECK lines expect the lshr to turn into a mask clearing the low 8 bits
;; (and with -256) and the outer shl to shrink to 32 - 8 = 24.
define i64 @simple(i32 %x) {
; CHECK-LABEL: define i64 @simple(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[LSHR:%.*]] = and i32 [[X]], -256
; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[LSHR]] to i64
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24
; CHECK-NEXT: ret i64 [[SHL]]
;
%lshr = lshr i32 %x, 8
%zext = zext i32 %lshr to i64
%shl = shl i64 %zext, 32
ret i64 %shl
}

;; lshr by 4, then an 8-bit mask (u0xff), then zext i32 -> i64 and shl by 48.
;; The CHECK lines expect the mask to absorb the undone lshr
;; (u0xff << 4 = u0xff0 = 4080) and the outer shl to shrink to 48 - 4 = 44.
define i64 @masked(i32 %x) {
; CHECK-LABEL: define i64 @masked(
; CHECK-SAME: i32 [[X:%.*]]) {
; CHECK-NEXT: [[MASK:%.*]] = and i32 [[X]], 4080
; CHECK-NEXT: [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64
; CHECK-NEXT: [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44
; CHECK-NEXT: ret i64 [[SHL]]
;
%lshr = lshr i32 %x, 4
%mask = and i32 %lshr, u0xff
%zext = zext i32 %mask to i64
%shl = shl i64 %zext, 48
ret i64 %shl
}

;; Builds an i64 with %lower in the low 32 bits and the four bytes of %upper
;; extracted one at a time (mask with u0xff, after an lshr by 8/16/24 for the
;; upper three bytes) and placed at bit offsets 32/40/48/56.
;; The CHECK lines expect the byte-wise reassembly to collapse into a single
;; zext of %upper shifted left by 32 and or-ed with the zext of %lower.
define i64 @combine(i32 %lower, i32 %upper) {
; CHECK-LABEL: define i64 @combine(
; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
; CHECK-NEXT: [[BASE:%.*]] = zext i32 [[LOWER]] to i64
; CHECK-NEXT: [[TMP1:%.*]] = zext i32 [[UPPER]] to i64
; CHECK-NEXT: [[TMP2:%.*]] = shl nuw i64 [[TMP1]], 32
; CHECK-NEXT: [[O_3:%.*]] = or disjoint i64 [[TMP2]], [[BASE]]
; CHECK-NEXT: ret i64 [[O_3]]
;
%base = zext i32 %lower to i64

;; Byte 0 of %upper -> bits 32..39.
%u.0 = and i32 %upper, u0xff
%z.0 = zext i32 %u.0 to i64
%s.0 = shl i64 %z.0, 32
%o.0 = or i64 %base, %s.0

;; Byte 1 of %upper -> bits 40..47.
%r.1 = lshr i32 %upper, 8
%u.1 = and i32 %r.1, u0xff
%z.1 = zext i32 %u.1 to i64
%s.1 = shl i64 %z.1, 40
%o.1 = or i64 %o.0, %s.1

;; Byte 2 of %upper -> bits 48..55.
%r.2 = lshr i32 %upper, 16
%u.2 = and i32 %r.2, u0xff
%z.2 = zext i32 %u.2 to i64
%s.2 = shl i64 %z.2, 48
%o.2 = or i64 %o.1, %s.2

;; Byte 3 of %upper -> bits 56..63.
%r.3 = lshr i32 %upper, 24
%u.3 = and i32 %r.3, u0xff
%z.3 = zext i32 %u.3 to i64
%s.3 = shl i64 %z.3, 56
%o.3 = or i64 %o.2, %s.3

ret i64 %o.3
}