diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index 706cb828acc63..5a8a57acf544c 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -4864,6 +4864,83 @@ Instruction *InstCombinerImpl::foldNot(BinaryOperator &I) {
   return nullptr;
 }
 
+/// Try to rewrite a chained XOR as a disjoint OR of a sibling XOR so that a
+/// constant offset is exposed to GEP and other address optimizations.
+///
+/// Given
+///   Current = xor(X, K)
+/// and a sibling `Base = xor(X, C)` that dominates Current, where C is a
+/// proper subset of the bits of K, let B = K ^ C (so K = B | C and B & C = 0).
+/// Then Current == xor(Base, B), and if Base is known to be zero in every bit
+/// of B this is emitted as
+///   Result = or disjoint Base, B
+/// so that the `or` can be folded into an `add`.
+///
+/// Among the qualifying siblings the one with the smallest constant (signed
+/// order) is used as the base. Returns the replacement instruction, or
+/// nullptr if the fold does not apply.
+static Instruction *transformChainedXorToOrDisjoint(BinaryOperator &I,
+                                                    InstCombinerImpl &IC) {
+  Value *X;
+  ConstantInt *CurrentC;
+  if (!match(&I, m_Xor(m_Value(X), m_ConstantInt(CurrentC))))
+    return nullptr;
+
+  // The use list of a Constant is not confined to the current function, so do
+  // not go looking for siblings through it.
+  if (isa<Constant>(X))
+    return nullptr;
+
+  const APInt &K = CurrentC->getValue();
+
+  // Find the best dominating base XOR.
+  BinaryOperator *BestBaseXor = nullptr;
+  const APInt *BestC = nullptr;
+
+  for (User *U : X->users()) {
+    if (U == &I)
+      continue;
+
+    // Look for sibling instruction: xor(X, SiblingC)
+    BinaryOperator *SiblingXor;
+    ConstantInt *SiblingC;
+    if (!match(U, m_CombineAnd(m_BinOp(SiblingXor),
+                               m_Xor(m_Specific(X), m_ConstantInt(SiblingC)))))
+      continue;
+
+    // Only a constant whose bits are a proper subset of K can be split off as
+    // K = B | C with B & C = 0. Filtering here, rather than after the search,
+    // keeps an unusable small constant from hiding a usable larger one.
+    const APInt &C = SiblingC->getValue();
+    if (C == K || !C.isSubsetOf(K))
+      continue;
+
+    // Prefer the smallest constant (signed order, first one seen on ties),
+    // and only accept a sibling that dominates I.
+    if (BestC && !C.slt(*BestC))
+      continue;
+    if (!IC.getDominatorTree().dominates(SiblingXor, &I))
+      continue;
+
+    BestBaseXor = SiblingXor;
+    BestC = &C;
+  }
+
+  if (!BestBaseXor)
+    return nullptr;
+
+  // B = K ^ C. C is a subset of K, so B and C share no bits.
+  const APInt NewConstVal = K ^ *BestC;
+
+  // The base value (xor(X, C)) must be disjoint from the new offset (B).
+  if (!IC.MaskedValueIsZero(BestBaseXor, NewConstVal, &I))
+    return nullptr;
+
+  // All checks passed. Create the new 'or' instruction.
+  Constant *NewConst = ConstantInt::get(I.getType(), NewConstVal);
+  return BinaryOperator::CreateDisjointOr(BestBaseXor, NewConst);
+}
+
 // FIXME: We use commutative matchers (m_c_*) for some, but not all, matches
 // here. We should standardize that construct where it is needed or choose some
 // other way to ensure that commutated variants of patterns are not missed.
@@ -5201,5 +5278,8 @@ Instruction *InstCombinerImpl::visitXor(BinaryOperator &I) {
   if (Instruction *Res = foldBitwiseLogicWithIntrinsics(I, Builder))
     return Res;
 
+  if (Instruction *Transformed = transformChainedXorToOrDisjoint(I, *this))
+    return Transformed;
+
   return nullptr;
 }
diff --git a/llvm/test/Transforms/InstCombine/xor-to-or.ll b/llvm/test/Transforms/InstCombine/xor-to-or.ll
new file mode 100644
index 0000000000000..87b378338e3f9
--- /dev/null
+++ b/llvm/test/Transforms/InstCombine/xor-to-or.ll
@@ -0,0 +1,31 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py UTC_ARGS: --version 5
+; RUN: opt -mtriple=amdgcn-amd-amdhsa -passes=instcombine \
+; RUN: -S < %s | FileCheck %s
+
+define half @xor_to_or_disjoint(i1 %0, ptr %ptr) {
+; CHECK-LABEL: define half @xor_to_or_disjoint(
+; CHECK-SAME: i1 [[TMP0:%.*]], ptr [[PTR:%.*]]) {
+; CHECK-NEXT:  [[ENTRY:.*:]]
+; CHECK-NEXT:    [[ADDR1:%.*]] = select i1 [[TMP0]], i64 32, i64 256
+; CHECK-NEXT:    [[GEP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1]]
+; CHECK-NEXT:    [[TMP1:%.*]] = getelementptr i8, ptr [[PTR]], i64 [[ADDR1]]
+; CHECK-NEXT:    [[GEP2:%.*]] = getelementptr i8, ptr [[TMP1]], i64 2048
+; CHECK-NEXT:    [[VAL1:%.*]] = load half, ptr [[GEP1]], align 2
+; CHECK-NEXT:    [[VAL2:%.*]] = load half, ptr [[GEP2]], align 2
+; CHECK-NEXT:    [[RESULT_H:%.*]] = fadd half [[VAL1]], [[VAL2]]
+; CHECK-NEXT:    ret half [[RESULT_H]]
+;
+entry:
+  %base = select i1 %0, i64 0, i64 288
+  %addr1 = xor i64 %base, 32
+  %addr2 = xor i64 %base, 2080
+  %gep1 = getelementptr i8, ptr %ptr, i64 %addr1
+  %gep2 = getelementptr i8, ptr %ptr, i64 %addr2
+  %val1 = load half, ptr %gep1
+  %val2 = load half, ptr %gep2
+  %val1.f = fpext half %val1 to float
+  %val2.f = fpext half %val2 to float
+  %sum1.f = fadd float %val1.f, %val2.f
+  %result.h = fptrunc float %sum1.f to half
+  ret half %result.h
+}