Skip to content

Commit c36b7e2

Browse files
committed
[InstCombine] enhance vector bitwise select matching
(Cond & C) | (~bitcast(Cond) & D) --> bitcast (select Cond, (bc C), (bc D))

This is part of fixing: https://llvm.org/PR34047
That report shows a case where a bitcast is sitting between the select condition candidate and its 'not' value due to current cast canonicalization rules.

There's a bitcast type restriction that might be violated in existing matching, but I still need to investigate if that is possible - Alive2 shows we can only do this transform safely when the bitcast is from narrow to wide vector elements (otherwise poison could leak into elements that were safe in the original code): https://alive2.llvm.org/ce/z/Hf66qh

Differential Revision: https://reviews.llvm.org/D113035
1 parent 9c63adf commit c36b7e2

File tree

3 files changed

+52
-46
lines changed

3 files changed

+52
-46
lines changed

llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2298,22 +2298,30 @@ Value *InstCombinerImpl::getSelectCondition(Value *A, Value *B) {
22982298
if (!Ty->isIntOrIntVectorTy() || !B->getType()->isIntOrIntVectorTy())
22992299
return nullptr;
23002300

2301-
// We need 0 or all-1's bitmasks.
2302-
if (ComputeNumSignBits(A) != Ty->getScalarSizeInBits())
2303-
return nullptr;
2304-
2305-
// If B is the 'not' value of A, we have our answer.
2301+
// If A is the 'not' operand of B and has enough signbits, we have our answer.
23062302
if (match(B, m_Not(m_Specific(A)))) {
23072303
// If these are scalars or vectors of i1, A can be used directly.
23082304
if (Ty->isIntOrIntVectorTy(1))
23092305
return A;
2310-
return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(Ty));
2306+
2307+
// If we look through a vector bitcast, the caller will bitcast the operands
2308+
// to match the condition's number of bits (N x i1).
2309+
// To make this poison-safe, disallow bitcast from wide element to narrow
2310+
// element. That could allow poison in lanes where it was not present in the
2311+
// original code.
2312+
A = peekThroughBitcast(A);
2313+
unsigned NumSignBits = ComputeNumSignBits(A);
2314+
if (NumSignBits == A->getType()->getScalarSizeInBits() &&
2315+
NumSignBits <= Ty->getScalarSizeInBits())
2316+
return Builder.CreateTrunc(A, CmpInst::makeCmpResultType(A->getType()));
2317+
return nullptr;
23112318
}
23122319

23132320
// If both operands are constants, see if the constants are inverse bitmasks.
23142321
Constant *AConst, *BConst;
23152322
if (match(A, m_Constant(AConst)) && match(B, m_Constant(BConst)))
2316-
if (AConst == ConstantExpr::getNot(BConst))
2323+
if (AConst == ConstantExpr::getNot(BConst) &&
2324+
ComputeNumSignBits(A) == Ty->getScalarSizeInBits())
23172325
return Builder.CreateZExtOrTrunc(A, CmpInst::makeCmpResultType(Ty));
23182326

23192327
// Look for more complex patterns. The 'not' op may be hidden behind various
@@ -2357,10 +2365,17 @@ Value *InstCombinerImpl::matchSelectFromAndOr(Value *A, Value *C, Value *B,
23572365
B = peekThroughBitcast(B, true);
23582366
if (Value *Cond = getSelectCondition(A, B)) {
23592367
// ((bc Cond) & C) | ((bc ~Cond) & D) --> bc (select Cond, (bc C), (bc D))
2368+
// If this is a vector, we may need to cast to match the condition's length.
23602369
// The bitcasts will either all exist or all not exist. The builder will
23612370
// not create unnecessary casts if the types already match.
2362-
Value *BitcastC = Builder.CreateBitCast(C, A->getType());
2363-
Value *BitcastD = Builder.CreateBitCast(D, A->getType());
2371+
Type *SelTy = A->getType();
2372+
if (auto *VecTy = dyn_cast<VectorType>(Cond->getType())) {
2373+
unsigned Elts = VecTy->getElementCount().getKnownMinValue();
2374+
Type *EltTy = Builder.getIntNTy(SelTy->getPrimitiveSizeInBits() / Elts);
2375+
SelTy = VectorType::get(EltTy, VecTy->getElementCount());
2376+
}
2377+
Value *BitcastC = Builder.CreateBitCast(C, SelTy);
2378+
Value *BitcastD = Builder.CreateBitCast(D, SelTy);
23642379
Value *Select = Builder.CreateSelect(Cond, BitcastC, BitcastD);
23652380
return Builder.CreateBitCast(Select, OrigType);
23662381
}

llvm/test/Transforms/InstCombine/logical-select.ll

Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -682,15 +682,15 @@ define <4 x i32> @computesignbits_through_two_input_shuffle(<4 x i32> %x, <4 x i
682682
ret <4 x i32> %sel
683683
}
684684

685+
; Bitcast of condition from narrow source element type can be converted to select.
686+
685687
define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d) {
686688
; CHECK-LABEL: @bitcast_vec_cond(
687-
; CHECK-NEXT: [[S:%.*]] = sext <16 x i1> [[COND:%.*]] to <16 x i8>
688-
; CHECK-NEXT: [[T9:%.*]] = bitcast <16 x i8> [[S]] to <2 x i64>
689-
; CHECK-NEXT: [[NOTT9:%.*]] = xor <2 x i64> [[T9]], <i64 -1, i64 -1>
690-
; CHECK-NEXT: [[T11:%.*]] = and <2 x i64> [[NOTT9]], [[C:%.*]]
691-
; CHECK-NEXT: [[T12:%.*]] = and <2 x i64> [[T9]], [[D:%.*]]
692-
; CHECK-NEXT: [[R:%.*]] = or <2 x i64> [[T11]], [[T12]]
693-
; CHECK-NEXT: ret <2 x i64> [[R]]
689+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i64> [[D:%.*]] to <16 x i8>
690+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x i64> [[C:%.*]] to <16 x i8>
691+
; CHECK-NEXT: [[TMP3:%.*]] = select <16 x i1> [[COND:%.*]], <16 x i8> [[TMP1]], <16 x i8> [[TMP2]]
692+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i8> [[TMP3]] to <2 x i64>
693+
; CHECK-NEXT: ret <2 x i64> [[TMP4]]
694694
;
695695
%s = sext <16 x i1> %cond to <16 x i8>
696696
%t9 = bitcast <16 x i8> %s to <2 x i64>
@@ -701,6 +701,8 @@ define <2 x i64> @bitcast_vec_cond(<16 x i1> %cond, <2 x i64> %c, <2 x i64> %d)
701701
ret <2 x i64> %r
702702
}
703703

704+
; Negative test - bitcast of condition from wide source element type cannot be converted to select.
705+
704706
define <8 x i3> @bitcast_vec_cond_commute1(<3 x i1> %cond, <8 x i3> %pc, <8 x i3> %d) {
705707
; CHECK-LABEL: @bitcast_vec_cond_commute1(
706708
; CHECK-NEXT: [[C:%.*]] = mul <8 x i3> [[PC:%.*]], [[PC]]
@@ -726,13 +728,11 @@ define <2 x i16> @bitcast_vec_cond_commute2(<4 x i1> %cond, <2 x i16> %pc, <2 x
726728
; CHECK-LABEL: @bitcast_vec_cond_commute2(
727729
; CHECK-NEXT: [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
728730
; CHECK-NEXT: [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
729-
; CHECK-NEXT: [[S:%.*]] = sext <4 x i1> [[COND:%.*]] to <4 x i8>
730-
; CHECK-NEXT: [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
731-
; CHECK-NEXT: [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
732-
; CHECK-NEXT: [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
733-
; CHECK-NEXT: [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
734-
; CHECK-NEXT: [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
735-
; CHECK-NEXT: ret <2 x i16> [[R]]
731+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
732+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
733+
; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[COND:%.*]], <4 x i8> [[TMP1]], <4 x i8> [[TMP2]]
734+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
735+
; CHECK-NEXT: ret <2 x i16> [[TMP4]]
736736
;
737737
%c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
738738
%d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization
@@ -745,17 +745,18 @@ define <2 x i16> @bitcast_vec_cond_commute2(<4 x i1> %cond, <2 x i16> %pc, <2 x
745745
ret <2 x i16> %r
746746
}
747747

748+
; Condition doesn't have to be a bool vec - just all signbits.
749+
748750
define <2 x i16> @bitcast_vec_cond_commute3(<4 x i8> %cond, <2 x i16> %pc, <2 x i16> %pd) {
749751
; CHECK-LABEL: @bitcast_vec_cond_commute3(
750752
; CHECK-NEXT: [[C:%.*]] = mul <2 x i16> [[PC:%.*]], [[PC]]
751753
; CHECK-NEXT: [[D:%.*]] = mul <2 x i16> [[PD:%.*]], [[PD]]
752-
; CHECK-NEXT: [[S:%.*]] = ashr <4 x i8> [[COND:%.*]], <i8 7, i8 7, i8 7, i8 7>
753-
; CHECK-NEXT: [[T9:%.*]] = bitcast <4 x i8> [[S]] to <2 x i16>
754-
; CHECK-NEXT: [[NOTT9:%.*]] = xor <2 x i16> [[T9]], <i16 -1, i16 -1>
755-
; CHECK-NEXT: [[T11:%.*]] = and <2 x i16> [[C]], [[NOTT9]]
756-
; CHECK-NEXT: [[T12:%.*]] = and <2 x i16> [[D]], [[T9]]
757-
; CHECK-NEXT: [[R:%.*]] = or <2 x i16> [[T11]], [[T12]]
758-
; CHECK-NEXT: ret <2 x i16> [[R]]
754+
; CHECK-NEXT: [[DOTNOT:%.*]] = icmp sgt <4 x i8> [[COND:%.*]], <i8 -1, i8 -1, i8 -1, i8 -1>
755+
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <2 x i16> [[D]] to <4 x i8>
756+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <2 x i16> [[C]] to <4 x i8>
757+
; CHECK-NEXT: [[TMP3:%.*]] = select <4 x i1> [[DOTNOT]], <4 x i8> [[TMP2]], <4 x i8> [[TMP1]]
758+
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <4 x i8> [[TMP3]] to <2 x i16>
759+
; CHECK-NEXT: ret <2 x i16> [[TMP4]]
759760
;
760761
%c = mul <2 x i16> %pc, %pc ; thwart complexity-based canonicalization
761762
%d = mul <2 x i16> %pd, %pd ; thwart complexity-based canonicalization

llvm/test/Transforms/PhaseOrdering/X86/vector-math.ll

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,15 +68,9 @@ define internal <2 x i64> @_mm_set_epi32(i32 %__i3, i32 %__i2, i32 %__i1, i32 %_
6868
define <2 x i64> @abs_v4i32(<2 x i64> %x) {
6969
; CHECK-LABEL: @abs_v4i32(
7070
; CHECK-NEXT: [[T1_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
71-
; CHECK-NEXT: [[SUB_I:%.*]] = sub <4 x i32> zeroinitializer, [[T1_I]]
72-
; CHECK-NEXT: [[T1_I_LOBIT:%.*]] = ashr <4 x i32> [[T1_I]], <i32 31, i32 31, i32 31, i32 31>
73-
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i32> [[T1_I_LOBIT]] to <2 x i64>
74-
; CHECK-NEXT: [[T2_I_I:%.*]] = xor <2 x i64> [[TMP1]], <i64 -1, i64 -1>
75-
; CHECK-NEXT: [[AND_I_I1:%.*]] = and <4 x i32> [[T1_I_LOBIT]], [[SUB_I]]
76-
; CHECK-NEXT: [[AND_I_I:%.*]] = bitcast <4 x i32> [[AND_I_I1]] to <2 x i64>
77-
; CHECK-NEXT: [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
78-
; CHECK-NEXT: [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
79-
; CHECK-NEXT: ret <2 x i64> [[OR_I_I]]
71+
; CHECK-NEXT: [[TMP1:%.*]] = tail call <4 x i32> @llvm.abs.v4i32(<4 x i32> [[T1_I]], i1 false)
72+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
73+
; CHECK-NEXT: ret <2 x i64> [[TMP2]]
8074
;
8175
%call = call <2 x i64> @_mm_set1_epi32(i32 -1)
8276
%call1 = call <2 x i64> @_mm_setzero_si128()
@@ -90,13 +84,9 @@ define <2 x i64> @max_v4i32(<2 x i64> %x, <2 x i64> %y) {
9084
; CHECK-NEXT: [[T0_I_I:%.*]] = bitcast <2 x i64> [[X:%.*]] to <4 x i32>
9185
; CHECK-NEXT: [[T1_I_I:%.*]] = bitcast <2 x i64> [[Y:%.*]] to <4 x i32>
9286
; CHECK-NEXT: [[CMP_I_I:%.*]] = icmp sgt <4 x i32> [[T0_I_I]], [[T1_I_I]]
93-
; CHECK-NEXT: [[SEXT_I_I:%.*]] = sext <4 x i1> [[CMP_I_I]] to <4 x i32>
94-
; CHECK-NEXT: [[T2_I_I:%.*]] = bitcast <4 x i32> [[SEXT_I_I]] to <2 x i64>
95-
; CHECK-NEXT: [[NEG_I_I:%.*]] = xor <2 x i64> [[T2_I_I]], <i64 -1, i64 -1>
96-
; CHECK-NEXT: [[AND_I_I:%.*]] = and <2 x i64> [[NEG_I_I]], [[Y]]
97-
; CHECK-NEXT: [[AND_I1_I:%.*]] = and <2 x i64> [[T2_I_I]], [[X]]
98-
; CHECK-NEXT: [[OR_I_I:%.*]] = or <2 x i64> [[AND_I1_I]], [[AND_I_I]]
99-
; CHECK-NEXT: ret <2 x i64> [[OR_I_I]]
87+
; CHECK-NEXT: [[TMP1:%.*]] = select <4 x i1> [[CMP_I_I]], <4 x i32> [[T0_I_I]], <4 x i32> [[T1_I_I]]
88+
; CHECK-NEXT: [[TMP2:%.*]] = bitcast <4 x i32> [[TMP1]] to <2 x i64>
89+
; CHECK-NEXT: ret <2 x i64> [[TMP2]]
10090
;
10191
%call = call <2 x i64> @cmpgt_i32_sel_m128i(<2 x i64> %x, <2 x i64> %y, <2 x i64> %y, <2 x i64> %x)
10292
ret <2 x i64> %call

0 commit comments

Comments
 (0)