diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index 550f095b26ba4..b0b1301cd2580 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -978,6 +978,47 @@ Instruction *InstCombinerImpl::foldLShrOverflowBit(BinaryOperator &I) {
   return new ZExtInst(Overflow, Ty);
 }
 
+/// If the operand of a zext-ed left shift \p V is a logically right-shifted
+/// value, try to fold the opposing shifts.
+static Instruction *foldShrThroughZExtedShl(Type *DestTy, Value *V,
+                                            unsigned ShlAmt,
+                                            InstCombinerImpl &IC,
+                                            const DataLayout &DL) {
+  auto *I = dyn_cast<Instruction>(V);
+  if (!I)
+    return nullptr;
+
+  // Dig through operations until the first shift.
+  while (!I->isShift())
+    if (!match(I, m_BinOp(m_OneUse(m_Instruction(I)), m_Constant())))
+      return nullptr;
+
+  // Fold only if the inner shift is a logical right-shift.
+  uint64_t InnerShrAmt;
+  if (!match(I, m_LShr(m_Value(), m_ConstantInt(InnerShrAmt))))
+    return nullptr;
+
+  if (InnerShrAmt >= ShlAmt) {
+    const uint64_t ReducedShrAmt = InnerShrAmt - ShlAmt;
+    if (!canEvaluateShifted(V, ReducedShrAmt, /*IsLeftShift=*/false, IC,
+                            nullptr))
+      return nullptr;
+    Value *NewInner =
+        getShiftedValue(V, ReducedShrAmt, /*isLeftShift=*/false, IC, DL);
+    return new ZExtInst(NewInner, DestTy);
+  }
+
+  if (!canEvaluateShifted(V, InnerShrAmt, /*IsLeftShift=*/true, IC, nullptr))
+    return nullptr;
+
+  const uint64_t ReducedShlAmt = ShlAmt - InnerShrAmt;
+  Value *NewInner =
+      getShiftedValue(V, InnerShrAmt, /*isLeftShift=*/true, IC, DL);
+  Value *NewZExt = IC.Builder.CreateZExt(NewInner, DestTy);
+  return BinaryOperator::CreateShl(NewZExt,
+                                   ConstantInt::get(DestTy, ReducedShlAmt));
+}
+
 // Try to set nuw/nsw flags on shl or exact flag on lshr/ashr using knownbits.
 static bool setShiftFlags(BinaryOperator &I, const SimplifyQuery &Q) {
   assert(I.isShift() && "Expected a shift as input");
@@ -1062,14 +1103,18 @@ Instruction *InstCombinerImpl::visitShl(BinaryOperator &I) {
   if (match(Op1, m_APInt(C))) {
     unsigned ShAmtC = C->getZExtValue();
 
-    // shl (zext X), C --> zext (shl X, C)
-    // This is only valid if X would have zeros shifted out.
     Value *X;
     if (match(Op0, m_OneUse(m_ZExt(m_Value(X))))) {
+      // shl (zext X), C --> zext (shl X, C)
+      // This is only valid if X would have zeros shifted out.
       unsigned SrcWidth = X->getType()->getScalarSizeInBits();
       if (ShAmtC < SrcWidth &&
           MaskedValueIsZero(X, APInt::getHighBitsSet(SrcWidth, ShAmtC), &I))
         return new ZExtInst(Builder.CreateShl(X, ShAmtC), Ty);
+
+      // Otherwise, try to cancel the outer shl with a lshr inside the zext.
+      if (Instruction *V = foldShrThroughZExtedShl(Ty, X, ShAmtC, *this, DL))
+        return V;
     }
 
     // (X >> C) << C --> X & (-1 << C)
diff --git a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
index 5224d75a157d5..8330fd09090c8 100644
--- a/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
+++ b/llvm/test/Analysis/ValueTracking/numsignbits-shl.ll
@@ -101,9 +101,9 @@ define void @numsignbits_shl_zext_extended_bits_remains(i8 %x) {
 define void @numsignbits_shl_zext_all_bits_shifted_out(i8 %x) {
 ; CHECK-LABEL: define void @numsignbits_shl_zext_all_bits_shifted_out(
 ; CHECK-SAME: i8 [[X:%.*]]) {
-; CHECK-NEXT:    [[ASHR:%.*]] = lshr i8 [[X]], 5
-; CHECK-NEXT:    [[ZEXT:%.*]] = zext nneg i8 [[ASHR]] to i16
-; CHECK-NEXT:    [[NSB1:%.*]] = shl i16 [[ZEXT]], 14
+; CHECK-NEXT:    [[ASHR:%.*]] = and i8 [[X]], 96
+; CHECK-NEXT:    [[TMP1:%.*]] = zext nneg i8 [[ASHR]] to i16
+; CHECK-NEXT:    [[NSB1:%.*]] = shl nuw i16 [[TMP1]], 9
 ; CHECK-NEXT:    [[AND14:%.*]] = and i16 [[NSB1]], 16384
 ; CHECK-NEXT:    [[ADD14:%.*]] = add i16 [[AND14]], [[NSB1]]
 ; CHECK-NEXT:    call void @escape(i16 [[ADD14]])
diff --git a/llvm/test/Transforms/InstCombine/iX-ext-split.ll b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
index fc804df0e4bec..b8e056725f122 100644
--- a/llvm/test/Transforms/InstCombine/iX-ext-split.ll
+++ b/llvm/test/Transforms/InstCombine/iX-ext-split.ll
@@ -197,9 +197,9 @@ define i128 @i128_ext_split_neg4(i32 %x) {
 ; CHECK-NEXT:  [[ENTRY:.*:]]
 ; CHECK-NEXT:    [[LOWERSRC:%.*]] = sext i32 [[X]] to i64
 ; CHECK-NEXT:    [[LO:%.*]] = zext i64 [[LOWERSRC]] to i128
-; CHECK-NEXT:    [[SIGN:%.*]] = lshr i32 [[X]], 31
-; CHECK-NEXT:    [[WIDEN:%.*]] = zext nneg i32 [[SIGN]] to i128
-; CHECK-NEXT:    [[HI:%.*]] = shl nuw nsw i128 [[WIDEN]], 64
+; CHECK-NEXT:    [[SIGN:%.*]] = and i32 [[X]], -2147483648
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[SIGN]] to i128
+; CHECK-NEXT:    [[HI:%.*]] = shl nuw nsw i128 [[TMP0]], 33
 ; CHECK-NEXT:    [[RES:%.*]] = or disjoint i128 [[HI]], [[LO]]
 ; CHECK-NEXT:    ret i128 [[RES]]
 ;
diff --git a/llvm/test/Transforms/InstCombine/shifts-around-zext.ll b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
new file mode 100644
index 0000000000000..517783fcbcb5c
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/shifts-around-zext.ll
@@ -0,0 +1,69 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine %s | FileCheck %s
+
+define i64 @simple(i32 %x) {
+; CHECK-LABEL: define i64 @simple(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[LSHR:%.*]] = and i32 [[X]], -256
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[LSHR]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 24
+; CHECK-NEXT:    ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 8
+  %zext = zext i32 %lshr to i64
+  %shl = shl i64 %zext, 32
+  ret i64 %shl
+}
+
+;; u0xff0 = 4080
+define i64 @masked(i32 %x) {
+; CHECK-LABEL: define i64 @masked(
+; CHECK-SAME: i32 [[X:%.*]]) {
+; CHECK-NEXT:    [[MASK:%.*]] = and i32 [[X]], 4080
+; CHECK-NEXT:    [[TMP1:%.*]] = zext nneg i32 [[MASK]] to i64
+; CHECK-NEXT:    [[SHL:%.*]] = shl nuw nsw i64 [[TMP1]], 44
+; CHECK-NEXT:    ret i64 [[SHL]]
+;
+  %lshr = lshr i32 %x, 4
+  %mask = and i32 %lshr, u0xff
+  %zext = zext i32 %mask to i64
+  %shl = shl i64 %zext, 48
+  ret i64 %shl
+}
+
+define i64 @combine(i32 %lower, i32 %upper) {
+; CHECK-LABEL: define i64 @combine(
+; CHECK-SAME: i32 [[LOWER:%.*]], i32 [[UPPER:%.*]]) {
+; CHECK-NEXT:    [[BASE:%.*]] = zext i32 [[LOWER]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i32 [[UPPER]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = shl nuw i64 [[TMP1]], 32
+; CHECK-NEXT:    [[O_3:%.*]] = or disjoint i64 [[TMP2]], [[BASE]]
+; CHECK-NEXT:    ret i64 [[O_3]]
+;
+  %base = zext i32 %lower to i64
+
+  %u.0 = and i32 %upper, u0xff
+  %z.0 = zext i32 %u.0 to i64
+  %s.0 = shl i64 %z.0, 32
+  %o.0 = or i64 %base, %s.0
+
+  %r.1 = lshr i32 %upper, 8
+  %u.1 = and i32 %r.1, u0xff
+  %z.1 = zext i32 %u.1 to i64
+  %s.1 = shl i64 %z.1, 40
+  %o.1 = or i64 %o.0, %s.1
+
+  %r.2 = lshr i32 %upper, 16
+  %u.2 = and i32 %r.2, u0xff
+  %z.2 = zext i32 %u.2 to i64
+  %s.2 = shl i64 %z.2, 48
+  %o.2 = or i64 %o.1, %s.2
+
+  %r.3 = lshr i32 %upper, 24
+  %u.3 = and i32 %r.3, u0xff
+  %z.3 = zext i32 %u.3 to i64
+  %s.3 = shl i64 %z.3, 56
+  %o.3 = or i64 %o.2, %s.3
+
+  ret i64 %o.3
+}