[InstCombine] Fold umul.overflow(x, c1) | (x*c1 > c2) to x > c2/c1 #147327
Conversation
@llvm/pr-subscribers-llvm-transforms

Author: Marius Kamp (mskamp)

Changes

The motivation of this pattern is to check whether the product of a variable and a constant would be mathematically (i.e., as integer numbers instead of bit vectors) greater than a given constant bound. The pattern appears to occur when compiling several Rust projects (it seems to originate from the `smallvec` crate but I have not checked this further).

Unless `c1` is `0`, we can transform this pattern into `x > c2/c1` with all operations working on unsigned integers. Due to undefined behavior when an element of a non-splat vector is `0`, the transform is only implemented for scalars and splat vectors.

Alive proof: https://alive2.llvm.org/ce/z/LawTkm

Closes #142674

Full diff: https://github.com/llvm/llvm-project/pull/147327.diff

2 Files Affected:
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index dd16cfaeecd45..72da6f44be182 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -3659,6 +3659,32 @@ static std::optional<DecomposedBitMaskMul> matchBitmaskMul(Value *V) {
return std::nullopt;
}
+/// Fold Res, Overflow = (umul.with.overflow x c1); (or Overflow (ugt Res c2))
+/// --> (ugt x (c2/c1)). This code checks whether a multiplication of two
+/// unsigned numbers (one is a constant) is mathematically greater than a
+/// second constant.
+static Value *foldOrUnsignedUMulOverflowICmp(BinaryOperator &I,
+ InstCombiner::BuilderTy &Builder,
+ const DataLayout &DL) {
+ const WithOverflowInst *WO;
+ const Value *WOV;
+ Constant *C1, *C2;
+ if (match(&I, m_c_Or(m_OneUse(m_ExtractValue<1>(
+ m_CombineAnd(m_WithOverflowInst(WO), m_Value(WOV)))),
+ m_OneUse(m_SpecificCmp(
+ ICmpInst::ICMP_UGT,
+ m_OneUse(m_ExtractValue<0>(m_Deferred(WOV))),
+ m_ImmConstant(C2))))) &&
+ WO->getIntrinsicID() == Intrinsic::umul_with_overflow &&
+ match(WO->getRHS(), m_ImmConstant(C1)) && WO->hasNUses(2)) {
+ assert(!C1->isNullValue()); // This should have been folded away.
+ Constant *NewC =
+ ConstantFoldBinaryOpOperands(Instruction::UDiv, C2, C1, DL);
+ return Builder.CreateICmp(ICmpInst::ICMP_UGT, WO->getLHS(), NewC);
+ }
+ return nullptr;
+}
+
// FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
// here. We should standardize that construct where it is needed or choose some
// other way to ensure that commutated variants of patterns are not missed.
@@ -4109,6 +4135,11 @@ Instruction *InstCombinerImpl::visitOr(BinaryOperator &I) {
}
}
+ // Try to fold the pattern "Overflow | icmp pred Res, C2" into a single
+ // comparison instruction for umul.with.overflow.
+ if (Value *R = foldOrUnsignedUMulOverflowICmp(I, Builder, DL))
+ return replaceInstUsesWith(I, R);
+
// (~x) | y --> ~(x & (~y)) iff that gets rid of inversions
if (sinkNotIntoOtherHandOfLogicalOp(I))
return &I;
diff --git a/llvm/test/Transforms/InstCombine/icmp_or_umul_overflow.ll b/llvm/test/Transforms/InstCombine/icmp_or_umul_overflow.ll
new file mode 100644
index 0000000000000..ac900b4e9591e
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/icmp_or_umul_overflow.ll
@@ -0,0 +1,226 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -S -passes=instcombine < %s | FileCheck %s
+
+declare void @use.i1(i1 %x)
+declare void @use.i64(i64 %x)
+declare void @use.i64i1({i64, i1} %x)
+
+define i1 @umul_greater_than_or_overflow_const(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP6:%.*]] = icmp ugt i64 [[IN]], 109802048057794950
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 168)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ugt i64 %mul, -16
+ %ret = or i1 %ovf, %cmp
+ ret i1 %ret
+}
+
+define i1 @umul_greater_than_or_overflow_const_i8(i8 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const_i8(
+; CHECK-SAME: i8 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP6:%.*]] = icmp ugt i8 [[IN]], 10
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = call { i8, i1 } @llvm.umul.with.overflow.i8(i8 %in, i8 24)
+ %mul = extractvalue { i8, i1 } %mwo, 0
+ %ovf = extractvalue { i8, i1 } %mwo, 1
+ %cmp = icmp ugt i8 %mul, -16
+ %ret = or i1 %ovf, %cmp
+ ret i1 %ret
+}
+
+define i1 @umul_greater_than_or_overflow_const_commuted(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const_commuted(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP6:%.*]] = icmp ugt i64 [[IN]], 192153584101141162
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 48)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ugt i64 %mul, 9223372036854775800
+ %ret = or i1 %cmp, %ovf
+ ret i1 %ret
+}
+
+define i1 @umul_greater_than_or_overflow_const_disjoint(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const_disjoint(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP6:%.*]] = icmp ugt i64 [[IN]], 230584300921369395
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 40)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ugt i64 %mul, 9223372036854775800
+ %ret = or disjoint i1 %ovf, %cmp
+ ret i1 %ret
+}
+
+define <2 x i1> @umul_greater_than_or_overflow_const_vector_splat(<2 x i64> %in) {
+; CHECK-LABEL: define <2 x i1> @umul_greater_than_or_overflow_const_vector_splat(
+; CHECK-SAME: <2 x i64> [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP6:%.*]] = icmp ugt <2 x i64> [[IN]], splat (i64 6477087104532848)
+; CHECK-NEXT: ret <2 x i1> [[TMP6]]
+;
+ %mwo = tail call { <2 x i64>, <2 x i1> } @llvm.umul.with.overflow.v2i64(<2 x i64> %in, <2 x i64> <i64 1424, i64 1424>)
+ %mul = extractvalue { <2 x i64>, <2 x i1> } %mwo, 0
+ %ovf = extractvalue { <2 x i64>, <2 x i1> } %mwo, 1
+ %cmp = icmp ugt <2 x i64> %mul, <i64 9223372036854775800, i64 9223372036854775800>
+ %ret = or <2 x i1> %ovf, %cmp
+ ret <2 x i1> %ret
+}
+
+define <2 x i1> @umul_greater_than_or_overflow_const_vector_non_splat(<2 x i64> %in) {
+; CHECK-LABEL: define <2 x i1> @umul_greater_than_or_overflow_const_vector_non_splat(
+; CHECK-SAME: <2 x i64> [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP6:%.*]] = icmp ugt <2 x i64> [[IN]], <i64 384307168202282291, i64 6477087104532848>
+; CHECK-NEXT: ret <2 x i1> [[TMP6]]
+;
+ %mwo = tail call { <2 x i64>, <2 x i1> } @llvm.umul.with.overflow.v2i64(<2 x i64> %in, <2 x i64> <i64 24, i64 1424>)
+ %mul = extractvalue { <2 x i64>, <2 x i1> } %mwo, 0
+ %ovf = extractvalue { <2 x i64>, <2 x i1> } %mwo, 1
+ %cmp = icmp ugt <2 x i64> %mul, <i64 9223372036854775000, i64 9223372036854775800>
+ %ret = or <2 x i1> %ovf, %cmp
+ ret <2 x i1> %ret
+}
+
+; Negative test
+define i1 @umul_greater_than_and_overflow_const_negative(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_and_overflow_const_negative(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[IN]], i64 48)
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ult i64 [[TMP3]], 9223372036854775800
+; CHECK-NEXT: [[TMP6:%.*]] = and i1 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 48)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ult i64 %mul, 9223372036854775800
+ %ret = and i1 %ovf, %cmp
+ ret i1 %ret
+}
+
+; Negative test
+define i1 @umul_less_than_or_overflow_const_negative(i64 %in) {
+; CHECK-LABEL: define i1 @umul_less_than_or_overflow_const_negative(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[IN]], i64 48)
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ult i64 [[TMP3]], 9223372036854775800
+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 48)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ult i64 %mul, 9223372036854775800
+ %ret = or i1 %ovf, %cmp
+ ret i1 %ret
+}
+
+; Negative test
+define i1 @umul_greater_than_or_overflow_const_multiuse_add_negative(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const_multiuse_add_negative(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[IN]], i64 48)
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ugt i64 [[TMP3]], 9223372036854775800
+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: tail call void @use.i64(i64 [[TMP3]])
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 48)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ugt i64 %mul, 9223372036854775800
+ %ret = or i1 %ovf, %cmp
+ tail call void @use.i64(i64 %mul)
+ ret i1 %ret
+}
+
+; Negative test
+define i1 @umul_greater_than_or_overflow_const_multiuse_overflow_negative(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const_multiuse_overflow_negative(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[IN]], i64 48)
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ugt i64 [[TMP3]], 9223372036854775800
+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: tail call void @use.i1(i1 [[TMP4]])
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 48)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ugt i64 %mul, 9223372036854775800
+ %ret = or i1 %ovf, %cmp
+ tail call void @use.i1(i1 %ovf)
+ ret i1 %ret
+}
+
+; Negative test
+define i1 @umul_greater_than_or_overflow_const_multiuse_icmp_negative(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const_multiuse_icmp_negative(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[IN]], i64 48)
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ugt i64 [[TMP3]], 9223372036854775800
+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: tail call void @use.i1(i1 [[TMP5]])
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 48)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ugt i64 %mul, 9223372036854775800
+ %ret = or i1 %ovf, %cmp
+ tail call void @use.i1(i1 %cmp)
+ ret i1 %ret
+}
+
+; Negative test
+define i1 @umul_greater_than_or_overflow_const_multiuse_umul_call_negative(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const_multiuse_umul_call_negative(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: [[TMP2:%.*]] = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 [[IN]], i64 48)
+; CHECK-NEXT: [[TMP3:%.*]] = extractvalue { i64, i1 } [[TMP2]], 0
+; CHECK-NEXT: [[TMP4:%.*]] = extractvalue { i64, i1 } [[TMP2]], 1
+; CHECK-NEXT: [[TMP5:%.*]] = icmp ugt i64 [[TMP3]], 9223372036854775800
+; CHECK-NEXT: [[TMP6:%.*]] = or i1 [[TMP4]], [[TMP5]]
+; CHECK-NEXT: tail call void @use.i64i1({ i64, i1 } [[TMP2]])
+; CHECK-NEXT: ret i1 [[TMP6]]
+;
+ %mwo = tail call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 48)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ugt i64 %mul, 9223372036854775800
+ %ret = or i1 %ovf, %cmp
+ tail call void @use.i64i1({ i64, i1 } %mwo)
+ ret i1 %ret
+}
+
+; Negative test. The umul.with.overflow should be folded away before.
+define i1 @umul_greater_than_or_overflow_const_0_negative(i64 %in) {
+; CHECK-LABEL: define i1 @umul_greater_than_or_overflow_const_0_negative(
+; CHECK-SAME: i64 [[IN:%.*]]) {
+; CHECK-NEXT: ret i1 false
+;
+ %mwo = call { i64, i1 } @llvm.umul.with.overflow.i64(i64 %in, i64 0)
+ %mul = extractvalue { i64, i1 } %mwo, 0
+ %ovf = extractvalue { i64, i1 } %mwo, 1
+ %cmp = icmp ugt i64 %mul, 0
+ %ret = or i1 %ovf, %cmp
+ ret i1 %ret
+}
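To sanity-check the fold outside of InstCombine, here is a small standalone C++ sketch (my own, not part of the patch) that exhaustively verifies the equivalence for the i8 configuration used in `umul_greater_than_or_overflow_const_i8` (`c1 = 24`, `c2 = -16`, i.e. 240 as an unsigned i8):

```cpp
// Standalone sketch (not part of the patch): exhaustively check that
//   Overflow | (Res u> C2)  ==  X u> C2/C1
// for the i8 test case above, where C1 = 24 and C2 = -16 (240 as u8).
#include <cassert>

int main() {
  const unsigned C1 = 24, C2 = 240;
  for (unsigned X = 0; X <= 255; ++X) {
    unsigned Wide = X * C1;               // mathematically exact product
    bool Overflow = Wide > 255;           // the umul.with.overflow flag
    unsigned Res = Wide & 255;            // the truncated i8 product
    bool Original = Overflow || Res > C2; // Overflow | (Res u> C2)
    bool Folded = X > C2 / C1;            // X u> 10, since 240/24 == 10
    assert(Original == Folded);
  }
  return 0;
}
```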
Looks like this needs a rebase.
I've pushed twice for an easier review: the first push comprises the changes, the second push does the rebase.
Review comment on:

```cpp
                         m_ExtractValue<0>(m_Deferred(WOV)),
                         m_APInt(C2))))) &&
      WO->getIntrinsicID() == Intrinsic::umul_with_overflow &&
      match(WO->getRHS(), m_APInt(C1)) && !C1->isZero() && WO->hasNUses(2)) {
```
For the same reason as the others, I'd drop the use check on WO as well. Even if there are other uses, we still at least simplify or+icmp to icmp.
(More than two uses would be somewhat unusual anyway, because it means there are redundant extracts. May happen if they don't dominate each other.)
Review comment on:

```cpp
  const APInt *C1, *C2;
  if (match(&I,
            m_c_Or(m_ExtractValue<1>(
                       m_CombineAnd(m_WithOverflowInst(WO), m_Value(WOV))),
```
Could replace `m_WithOverflowInst(WO)` with `m_Intrinsic<Intrinsic::umul_with_overflow>(m_Value(X), m_APInt(C1))` here.
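A rough sketch of what that suggestion could look like, based on the snippets above (`X` is a new `Value *` binding; this is my reading of the suggestion, not the final committed code):

```cpp
// Sketch of the suggested matcher (not the final committed form): match the
// umul.with.overflow intrinsic directly and bind its operands in one step,
// instead of matching any WithOverflowInst and checking the ID afterwards.
Value *X, *WOV;
const APInt *C1, *C2;
if (match(&I, m_c_Or(m_ExtractValue<1>(m_CombineAnd(
                         m_Intrinsic<Intrinsic::umul_with_overflow>(
                             m_Value(X), m_APInt(C1)),
                         m_Value(WOV))),
                     m_SpecificCmp(ICmpInst::ICMP_UGT,
                                   m_ExtractValue<0>(m_Deferred(WOV)),
                                   m_APInt(C2)))) &&
    !C1->isZero()) {
  // ... create the `icmp ugt X, C2->udiv(*C1)` replacement as in the patch ...
}
```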
The motivation of this pattern is to check whether the product of a variable and a constant would be mathematically (i.e., as integer numbers instead of bit vectors) greater than a given constant bound. The pattern appears to occur when compiling several Rust projects (it seems to originate from the `smallvec` crate but I have not checked this further).

Unless `c1` is `0`, we can transform this pattern into `x > c2/c1` with all operations working on unsigned integers. Due to undefined behavior when an element of a non-splat vector is `0`, the transform is only implemented for scalars and splat vectors.

Alive proof: https://alive2.llvm.org/ce/z/LawTkm

Closes #142674
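As a concrete instance of the `c2/c1` arithmetic (my own arithmetic, not from the PR): in the first i64 test, `c1 = 168` and `c2 = -16`, i.e. `2^64 - 16` as an unsigned 64-bit value, and the unsigned division yields exactly the constant in the CHECK line:

```cpp
// My own verification of the CHECK constant in
// @umul_greater_than_or_overflow_const: (2^64 - 16) / 168.
#include <cassert>
#include <cstdint>

int main() {
  const uint64_t C1 = 168;
  const uint64_t C2 = static_cast<uint64_t>(-16); // 2^64 - 16
  assert(C2 / C1 == 109802048057794950ULL);
  return 0;
}
```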