Skip to content

[InstCombine] fold icmp of select with constants and invertible op #147182

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4332,6 +4332,98 @@ Instruction *InstCombinerImpl::foldICmpInstWithConstantNotInt(ICmpInst &I) {
return nullptr;
}

/// If the APInt C has the same invertible function with Operator RefOp in Pred,
/// return the operands of the function corresponding to each input. Otherwise,
/// return std::nullopt. This is equivalent to saying that Op1 pred Op2 is true
/// exactly when the specified pair of RefOp pred C is true.
/// alive2: https://alive2.llvm.org/ce/z/4jniEb
static std::optional<std::pair<Value *, Value *>>
getInvertibleOperandsWithPredicte(const Operator *RefOp, const APInt C,
CmpInst::Predicate Pred) {
APInt Op1C;
// for BinaryOperator just handle RefOp with constant Operand(1)
if (isa<BinaryOperator>(RefOp)) {
if (isa<ConstantInt>(RefOp->getOperand(1)))
Op1C = cast<ConstantInt>(RefOp->getOperand(1))->getValue();
else
return std::nullopt;
}

auto getOperands = [&](APInt A) -> auto {
return std::make_pair(RefOp->getOperand(0),
ConstantInt::get(RefOp->getOperand(0)->getType(), A));
};
switch (RefOp->getOpcode()) {
default:
break;
case Instruction::Or:
if (cast<PossiblyDisjointInst>(RefOp)->isDisjoint() && ((C & Op1C) == Op1C))
return getOperands(C ^ Op1C);
break;
case Instruction::Add: {
// TODO: add/sub could support nsw/nuw for scmp/ucmp
if (CmpInst::isEquality(Pred))
return getOperands(C - Op1C);
break;
}
case Instruction::Xor: {
if (CmpInst::isEquality(Pred))
return getOperands(C ^ Op1C);
break;
}
case Instruction::Sub: {
if (CmpInst::isEquality(Pred))
return getOperands(C + Op1C);
break;
}
// alive2: https://alive2.llvm.org/ce/z/WPQznV
case Instruction::Shl: {
// Z = shl nsw X, Y <=> X = ashr exact Z, Y
// Z = shl nuw X, Y <=> X = lshr exact Z, Y
if (C.ashr(Op1C).shl(Op1C) == C) {
auto *OBO1 = cast<OverflowingBinaryOperator>(RefOp);
if (OBO1->hasNoSignedWrap())
return getOperands(C.ashr(Op1C));
else if (OBO1->hasNoUnsignedWrap() && !ICmpInst::isSigned(Pred))
return getOperands(C.lshr(Op1C));
}
break;
}
case Instruction::AShr: {
// Z = ashr exact X, Y <=> X = shl nsw Z, Y
auto *PEO1 = cast<PossiblyExactOperator>(RefOp);
if (PEO1->isExact() && C.shl(Op1C).ashr(Op1C) == C)
return getOperands(C.shl(Op1C));
break;
}
case Instruction::LShr: {
// Z = lshr exact X, Y <=> X = shl nuw Z, Y
auto *PEO1 = cast<PossiblyExactOperator>(RefOp);
if (PEO1->isExact() && C.shl(Op1C).lshr(Op1C) == C &&
!ICmpInst::isSigned(Pred))
return getOperands(C.shl(Op1C));
break;
}
case Instruction::SExt: {
unsigned NumBits = RefOp->getType()->getScalarSizeInBits();
unsigned NumBitsOp0 =
RefOp->getOperand(0)->getType()->getScalarSizeInBits();
if (C.trunc(NumBitsOp0).sext(NumBits) == C)
return getOperands(C.trunc(NumBitsOp0));
break;
}
case Instruction::ZExt: {
unsigned NumBits = RefOp->getType()->getScalarSizeInBits();
unsigned NumBitsOp0 =
RefOp->getOperand(0)->getType()->getScalarSizeInBits();
if (C.trunc(NumBitsOp0).zext(NumBits) == C && !ICmpInst::isSigned(Pred))
return getOperands(C.trunc(NumBitsOp0));
break;
}
}
return std::nullopt;
}

Instruction *InstCombinerImpl::foldSelectICmp(CmpPredicate Pred, SelectInst *SI,
Value *RHS, const ICmpInst &I) {
// Try to fold the comparison into the select arms, which will cause the
Expand Down Expand Up @@ -4391,6 +4483,24 @@ Instruction *InstCombinerImpl::foldSelectICmp(CmpPredicate Pred, SelectInst *SI,
return SelectInst::Create(SI->getOperand(0), Op1, Op2);
}

// fold select with constants and invertible op
Value *Cond;
const APInt *C1, *C2;
auto *RHSOp = dyn_cast<Operator>(RHS);
if (RHSOp &&
match(SI, m_OneUse(m_Select(m_Value(Cond), m_APInt(C1), m_APInt(C2))))) {
if (auto Values0 = getInvertibleOperandsWithPredicte(RHSOp, *C1, Pred)) {
if (auto Values1 = getInvertibleOperandsWithPredicte(RHSOp, *C2, Pred)) {
assert(Values0->first == Values1->first &&
"Invertible Operand0 mismatch");
auto *NewSI = Builder.CreateSelect(Cond, Values0->second,
Values1->second, SI->getName());
return ICmpInst::Create(Instruction::ICmp, I.getPredicate(), NewSI,
Values0->first, I.getName());
}
}
}

return nullptr;
}

Expand Down
Loading
Loading