Skip to content

Commit d05634d

Browse files
vortex73 and RKSimon
authored
[VectorCombine] Fold bitwise operations of bitcasts into bitcast of bitwise operation (#137322)
Currently, LLVM fails to convert certain pblendvb intrinsics into select instructions when the blend mask is derived from complex boolean logic operations. This occurs even when the mask is ultimately based on sign-extended comparison results, preventing further optimization opportunities. Fixes #66513 --------- Co-authored-by: Simon Pilgrim <llvm-dev@redking.me.uk>
1 parent cfdc4c4 commit d05634d

File tree

3 files changed

+89
-36
lines changed

3 files changed

+89
-36
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ class VectorCombine {
113113
bool foldInsExtFNeg(Instruction &I);
114114
bool foldInsExtBinop(Instruction &I);
115115
bool foldInsExtVectorToShuffle(Instruction &I);
116+
bool foldBitOpOfBitcasts(Instruction &I);
116117
bool foldBitcastShuffle(Instruction &I);
117118
bool scalarizeOpOrCmp(Instruction &I);
118119
bool scalarizeVPIntrinsic(Instruction &I);
@@ -803,6 +804,66 @@ bool VectorCombine::foldInsExtBinop(Instruction &I) {
803804
return true;
804805
}
805806

807+
bool VectorCombine::foldBitOpOfBitcasts(Instruction &I) {
  // Match: bitop(bitcast(x), bitcast(y)) -> bitcast(bitop(x, y))
  Value *LHSSrc, *RHSSrc;
  if (!match(&I, m_BitwiseLogic(m_BitCast(m_Value(LHSSrc)),
                                m_BitCast(m_Value(RHSSrc)))))
    return false;

  // Both bitcast sources must share a single integer-element source type.
  Type *SrcTy = LHSSrc->getType();
  if (SrcTy != RHSSrc->getType() || !SrcTy->getScalarType()->isIntegerTy())
    return false;

  // Restrict the transform to fixed-width vectors on both sides.
  auto *SrcVecTy = dyn_cast<FixedVectorType>(SrcTy);
  auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
  if (!SrcVecTy || !DstVecTy)
    return false;

  // A bitcast never changes the total number of bits.
  assert(SrcVecTy->getPrimitiveSizeInBits() ==
             DstVecTy->getPrimitiveSizeInBits() &&
         "Bitcast should preserve total bit width");

  // Cost Check :
  // OldCost = bitlogic + 2*bitcasts
  // NewCost = bitlogic + bitcast
  // Both operand casts go SrcVecTy -> DstVecTy, so query that cost once.
  auto *BinOp = cast<BinaryOperator>(&I);
  Instruction::BinaryOps Opcode = BinOp->getOpcode();
  InstructionCost CastCost = TTI.getCastInstrCost(
      Instruction::BitCast, DstVecTy, SrcVecTy, TTI::CastContextHint::None);
  InstructionCost OldCost =
      TTI.getArithmeticInstrCost(Opcode, DstVecTy) + CastCost + CastCost;
  InstructionCost NewCost =
      TTI.getArithmeticInstrCost(Opcode, SrcVecTy) + CastCost;

  LLVM_DEBUG(dbgs() << "Found a bitwise logic op of bitcasted values: " << I
                    << "\n OldCost: " << OldCost << " vs NewCost: " << NewCost
                    << "\n");

  if (NewCost > OldCost)
    return false;

  // Create the operation on the source type, carrying over any IR flags.
  Value *NewOp = Builder.CreateBinOp(Opcode, LHSSrc, RHSSrc,
                                     BinOp->getName() + ".inner");
  if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
    NewBinOp->copyIRFlags(BinOp);

  Worklist.pushValue(NewOp);

  // Bitcast the result back to the original destination type.
  Value *Result = Builder.CreateBitCast(NewOp, I.getType());
  replaceValue(I, *Result);
  return true;
}
866+
806867
/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
807868
/// destination type followed by shuffle. This can enable further transforms by
808869
/// moving bitcasts or shuffles together.
@@ -3629,6 +3690,11 @@ bool VectorCombine::run() {
36293690
case Instruction::BitCast:
36303691
MadeChange |= foldBitcastShuffle(I);
36313692
break;
3693+
case Instruction::And:
3694+
case Instruction::Or:
3695+
case Instruction::Xor:
3696+
MadeChange |= foldBitOpOfBitcasts(I);
3697+
break;
36323698
default:
36333699
MadeChange |= shrinkType(I);
36343700
break;

llvm/test/Transforms/PhaseOrdering/X86/blendv-select.ll

Lines changed: 11 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -477,30 +477,22 @@ define <2 x i64> @PR66513(<2 x i64> %a, <2 x i64> %b, <2 x i64> %c, <2 x i64> %s
477477
; CHECK-LABEL: @PR66513(
478478
; CHECK-NEXT: [[I:%.*]] = bitcast <2 x i64> [[A:%.*]] to <4 x i32>
479479
; CHECK-NEXT: [[CMP_I23:%.*]] = icmp sgt <4 x i32> [[I]], zeroinitializer
480-
; CHECK-NEXT: [[SEXT_I24:%.*]] = sext <4 x i1> [[CMP_I23]] to <4 x i32>
481-
; CHECK-NEXT: [[I1:%.*]] = bitcast <4 x i32> [[SEXT_I24]] to <2 x i64>
482480
; CHECK-NEXT: [[I2:%.*]] = bitcast <2 x i64> [[B:%.*]] to <4 x i32>
483481
; CHECK-NEXT: [[CMP_I21:%.*]] = icmp sgt <4 x i32> [[I2]], zeroinitializer
484-
; CHECK-NEXT: [[SEXT_I22:%.*]] = sext <4 x i1> [[CMP_I21]] to <4 x i32>
485-
; CHECK-NEXT: [[I3:%.*]] = bitcast <4 x i32> [[SEXT_I22]] to <2 x i64>
486482
; CHECK-NEXT: [[I4:%.*]] = bitcast <2 x i64> [[C:%.*]] to <4 x i32>
487483
; CHECK-NEXT: [[CMP_I:%.*]] = icmp sgt <4 x i32> [[I4]], zeroinitializer
488-
; CHECK-NEXT: [[SEXT_I:%.*]] = sext <4 x i1> [[CMP_I]] to <4 x i32>
484+
; CHECK-NEXT: [[NARROW:%.*]] = select <4 x i1> [[CMP_I21]], <4 x i1> [[CMP_I23]], <4 x i1> zeroinitializer
485+
; CHECK-NEXT: [[XOR_I_INNER1:%.*]] = xor <4 x i1> [[NARROW]], [[CMP_I]]
486+
; CHECK-NEXT: [[NARROW3:%.*]] = select <4 x i1> [[CMP_I23]], <4 x i1> [[XOR_I_INNER1]], <4 x i1> zeroinitializer
487+
; CHECK-NEXT: [[AND_I25_INNER2:%.*]] = and <4 x i1> [[XOR_I_INNER1]], [[CMP_I21]]
488+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> [[SRC:%.*]] to <4 x i32>
489+
; CHECK-NEXT: [[TMP2:%.*]] = select <4 x i1> [[NARROW]], <4 x i32> [[TMP1]], <4 x i32> zeroinitializer
490+
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <2 x i64> [[A]] to <4 x i32>
491+
; CHECK-NEXT: [[TMP4:%.*]] = select <4 x i1> [[NARROW3]], <4 x i32> [[TMP3]], <4 x i32> [[TMP2]]
492+
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <2 x i64> [[B]] to <4 x i32>
493+
; CHECK-NEXT: [[SEXT_I:%.*]] = select <4 x i1> [[AND_I25_INNER2]], <4 x i32> [[TMP5]], <4 x i32> [[TMP4]]
489494
; CHECK-NEXT: [[I5:%.*]] = bitcast <4 x i32> [[SEXT_I]] to <2 x i64>
490-
; CHECK-NEXT: [[AND_I27:%.*]] = and <2 x i64> [[I3]], [[I1]]
491-
; CHECK-NEXT: [[XOR_I:%.*]] = xor <2 x i64> [[AND_I27]], [[I5]]
492-
; CHECK-NEXT: [[AND_I26:%.*]] = and <2 x i64> [[XOR_I]], [[I1]]
493-
; CHECK-NEXT: [[AND_I25:%.*]] = and <2 x i64> [[XOR_I]], [[I3]]
494-
; CHECK-NEXT: [[AND_I:%.*]] = and <2 x i64> [[AND_I27]], [[SRC:%.*]]
495-
; CHECK-NEXT: [[I6:%.*]] = bitcast <2 x i64> [[AND_I]] to <16 x i8>
496-
; CHECK-NEXT: [[I7:%.*]] = bitcast <2 x i64> [[A]] to <16 x i8>
497-
; CHECK-NEXT: [[I8:%.*]] = bitcast <2 x i64> [[AND_I26]] to <16 x i8>
498-
; CHECK-NEXT: [[I9:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[I6]], <16 x i8> [[I7]], <16 x i8> [[I8]])
499-
; CHECK-NEXT: [[I12:%.*]] = bitcast <2 x i64> [[B]] to <16 x i8>
500-
; CHECK-NEXT: [[I13:%.*]] = bitcast <2 x i64> [[AND_I25]] to <16 x i8>
501-
; CHECK-NEXT: [[I14:%.*]] = tail call <16 x i8> @llvm.x86.sse41.pblendvb(<16 x i8> [[I9]], <16 x i8> [[I12]], <16 x i8> [[I13]])
502-
; CHECK-NEXT: [[I15:%.*]] = bitcast <16 x i8> [[I14]] to <2 x i64>
503-
; CHECK-NEXT: ret <2 x i64> [[I15]]
495+
; CHECK-NEXT: ret <2 x i64> [[I5]]
504496
;
505497
%i = bitcast <2 x i64> %a to <4 x i32>
506498
%cmp.i23 = icmp sgt <4 x i32> %i, zeroinitializer

llvm/test/Transforms/VectorCombine/AArch64/shrink-types.ll

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,8 @@ define i32 @test_and(<16 x i32> %a, ptr %b) {
77
; CHECK-LABEL: @test_and(
88
; CHECK-NEXT: entry:
99
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
10-
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A:%.*]] to <16 x i8>
11-
; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD]], [[TMP0]]
12-
; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
10+
; CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
11+
; CHECK-NEXT: [[TMP2:%.*]] = and <16 x i32> [[TMP0]], [[A:%.*]]
1312
; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
1413
; CHECK-NEXT: ret i32 [[TMP3]]
1514
;
@@ -26,9 +25,8 @@ define i32 @test_mask_or(<16 x i32> %a, ptr %b) {
2625
; CHECK-NEXT: entry:
2726
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
2827
; CHECK-NEXT: [[A_MASKED:%.*]] = and <16 x i32> [[A:%.*]], splat (i32 16)
29-
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A_MASKED]] to <16 x i8>
30-
; CHECK-NEXT: [[TMP1:%.*]] = or <16 x i8> [[WIDE_LOAD]], [[TMP0]]
31-
; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
28+
; CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
29+
; CHECK-NEXT: [[TMP2:%.*]] = or <16 x i32> [[TMP0]], [[A_MASKED]]
3230
; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
3331
; CHECK-NEXT: ret i32 [[TMP3]]
3432
;
@@ -47,15 +45,13 @@ define i32 @multiuse(<16 x i32> %u, <16 x i32> %v, ptr %b) {
4745
; CHECK-NEXT: [[U_MASKED:%.*]] = and <16 x i32> [[U:%.*]], splat (i32 255)
4846
; CHECK-NEXT: [[V_MASKED:%.*]] = and <16 x i32> [[V:%.*]], splat (i32 255)
4947
; CHECK-NEXT: [[WIDE_LOAD:%.*]] = load <16 x i8>, ptr [[B:%.*]], align 1
50-
; CHECK-NEXT: [[TMP0:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], splat (i8 4)
51-
; CHECK-NEXT: [[TMP1:%.*]] = trunc <16 x i32> [[V_MASKED]] to <16 x i8>
52-
; CHECK-NEXT: [[TMP2:%.*]] = or <16 x i8> [[TMP0]], [[TMP1]]
53-
; CHECK-NEXT: [[TMP3:%.*]] = zext <16 x i8> [[TMP2]] to <16 x i32>
54-
; CHECK-NEXT: [[TMP4:%.*]] = and <16 x i8> [[WIDE_LOAD]], splat (i8 15)
55-
; CHECK-NEXT: [[TMP5:%.*]] = trunc <16 x i32> [[U_MASKED]] to <16 x i8>
56-
; CHECK-NEXT: [[TMP6:%.*]] = or <16 x i8> [[TMP4]], [[TMP5]]
48+
; CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD]] to <16 x i32>
49+
; CHECK-NEXT: [[TMP6:%.*]] = lshr <16 x i8> [[WIDE_LOAD]], splat (i8 4)
5750
; CHECK-NEXT: [[TMP7:%.*]] = zext <16 x i8> [[TMP6]] to <16 x i32>
58-
; CHECK-NEXT: [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP3]], [[TMP7]]
51+
; CHECK-NEXT: [[TMP3:%.*]] = or <16 x i32> [[TMP7]], [[V_MASKED]]
52+
; CHECK-NEXT: [[TMP4:%.*]] = and <16 x i32> [[TMP0]], splat (i32 15)
53+
; CHECK-NEXT: [[TMP5:%.*]] = or <16 x i32> [[TMP4]], [[U_MASKED]]
54+
; CHECK-NEXT: [[TMP8:%.*]] = add nuw nsw <16 x i32> [[TMP3]], [[TMP5]]
5955
; CHECK-NEXT: [[TMP9:%.*]] = call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP8]])
6056
; CHECK-NEXT: ret i32 [[TMP9]]
6157
;
@@ -81,9 +77,8 @@ define i32 @phi_bug(<16 x i32> %a, ptr %b) {
8177
; CHECK: vector.body:
8278
; CHECK-NEXT: [[A_PHI:%.*]] = phi <16 x i32> [ [[A:%.*]], [[ENTRY:%.*]] ]
8379
; CHECK-NEXT: [[WIDE_LOAD_PHI:%.*]] = phi <16 x i8> [ [[WIDE_LOAD]], [[ENTRY]] ]
84-
; CHECK-NEXT: [[TMP0:%.*]] = trunc <16 x i32> [[A_PHI]] to <16 x i8>
85-
; CHECK-NEXT: [[TMP1:%.*]] = and <16 x i8> [[WIDE_LOAD_PHI]], [[TMP0]]
86-
; CHECK-NEXT: [[TMP2:%.*]] = zext <16 x i8> [[TMP1]] to <16 x i32>
80+
; CHECK-NEXT: [[TMP0:%.*]] = zext <16 x i8> [[WIDE_LOAD_PHI]] to <16 x i32>
81+
; CHECK-NEXT: [[TMP2:%.*]] = and <16 x i32> [[TMP0]], [[A_PHI]]
8782
; CHECK-NEXT: [[TMP3:%.*]] = tail call i32 @llvm.vector.reduce.add.v16i32(<16 x i32> [[TMP2]])
8883
; CHECK-NEXT: ret i32 [[TMP3]]
8984
;

0 commit comments

Comments
 (0)