Skip to content

Commit 417ab37

Browse files
authored
[ConstantFolding] Fold deinterleave2 of any splat vector not just zeroinitializer (#144144)
While there remove an unnecessary dyn_cast from Constant to Constant. Reverse a branch condition into an early out to reduce nesting.
1 parent 7f69cd5 commit 417ab37

File tree

2 files changed

+37
-22
lines changed

2 files changed

+37
-22
lines changed

llvm/lib/Analysis/ConstantFolding.cpp

Lines changed: 21 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -3990,31 +3990,30 @@ ConstantFoldStructCall(StringRef Name, Intrinsic::ID IntrinsicID,
39903990
return ConstantStruct::get(StTy, SinResult, CosResult);
39913991
}
39923992
case Intrinsic::vector_deinterleave2: {
3993-
auto *Vec = dyn_cast<Constant>(Operands[0]);
3994-
if (!Vec)
3995-
return nullptr;
3996-
3993+
auto *Vec = Operands[0];
39973994
auto *VecTy = cast<VectorType>(Vec->getType());
3998-
unsigned NumElements = VecTy->getElementCount().getKnownMinValue() / 2;
3999-
if (isa<ConstantAggregateZero>(Vec)) {
4000-
auto *HalfVecTy = VectorType::getHalfElementsVectorType(VecTy);
4001-
return ConstantStruct::get(StTy, ConstantAggregateZero::get(HalfVecTy),
4002-
ConstantAggregateZero::get(HalfVecTy));
3995+
3996+
if (auto *EltC = Vec->getSplatValue()) {
3997+
ElementCount HalfEC = VecTy->getElementCount().divideCoefficientBy(2);
3998+
auto *HalfVec = ConstantVector::getSplat(HalfEC, EltC);
3999+
return ConstantStruct::get(StTy, HalfVec, HalfVec);
40034000
}
4004-
if (isa<FixedVectorType>(Vec->getType())) {
4005-
SmallVector<Constant *, 4> Res0(NumElements), Res1(NumElements);
4006-
for (unsigned I = 0; I < NumElements; ++I) {
4007-
Constant *Elt0 = Vec->getAggregateElement(2 * I);
4008-
Constant *Elt1 = Vec->getAggregateElement(2 * I + 1);
4009-
if (!Elt0 || !Elt1)
4010-
return nullptr;
4011-
Res0[I] = Elt0;
4012-
Res1[I] = Elt1;
4013-
}
4014-
return ConstantStruct::get(StTy, ConstantVector::get(Res0),
4015-
ConstantVector::get(Res1));
4001+
4002+
if (!isa<FixedVectorType>(Vec->getType()))
4003+
return nullptr;
4004+
4005+
unsigned NumElements = VecTy->getElementCount().getFixedValue() / 2;
4006+
SmallVector<Constant *, 4> Res0(NumElements), Res1(NumElements);
4007+
for (unsigned I = 0; I < NumElements; ++I) {
4008+
Constant *Elt0 = Vec->getAggregateElement(2 * I);
4009+
Constant *Elt1 = Vec->getAggregateElement(2 * I + 1);
4010+
if (!Elt0 || !Elt1)
4011+
return nullptr;
4012+
Res0[I] = Elt0;
4013+
Res1[I] = Elt1;
40164014
}
4017-
return nullptr;
4015+
return ConstantStruct::get(StTy, ConstantVector::get(Res0),
4016+
ConstantVector::get(Res1));
40184017
}
40194018
default:
40204019
// TODO: Constant folding of vector intrinsics that fall through here does

llvm/test/Transforms/InstSimplify/ConstProp/vector-calls.ll

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,19 @@ define {<vscale x 4 x i32>, <vscale x 4 x i32>} @fold_scalable_vector_deinterlea
6666
%1 = call {<vscale x 4 x i32>, <vscale x 4 x i32>} @llvm.vector.deinterleave2.v4i32.v8i32(<vscale x 8 x i32> zeroinitializer)
6767
ret {<vscale x 4 x i32>, <vscale x 4 x i32>} %1
6868
}
69+
70+
define {<vscale x 4 x i32>, <vscale x 4 x i32>} @fold_scalable_vector_deinterleave2_splat() {
71+
; CHECK-LABEL: define { <vscale x 4 x i32>, <vscale x 4 x i32> } @fold_scalable_vector_deinterleave2_splat() {
72+
; CHECK-NEXT: ret { <vscale x 4 x i32>, <vscale x 4 x i32> } { <vscale x 4 x i32> splat (i32 1), <vscale x 4 x i32> splat (i32 1) }
73+
;
74+
%1 = call {<vscale x 4 x i32>, <vscale x 4 x i32>} @llvm.vector.deinterleave2.v4i32.v8i32(<vscale x 8 x i32> splat (i32 1))
75+
ret {<vscale x 4 x i32>, <vscale x 4 x i32>} %1
76+
}
77+
78+
define {<vscale x 4 x float>, <vscale x 4 x float>} @fold_scalable_vector_deinterleave2_splatfp() {
79+
; CHECK-LABEL: define { <vscale x 4 x float>, <vscale x 4 x float> } @fold_scalable_vector_deinterleave2_splatfp() {
80+
; CHECK-NEXT: ret { <vscale x 4 x float>, <vscale x 4 x float> } { <vscale x 4 x float> splat (float 1.000000e+00), <vscale x 4 x float> splat (float 1.000000e+00) }
81+
;
82+
%1 = call {<vscale x 4 x float>, <vscale x 4 x float>} @llvm.vector.deinterleave2.v4f32.v8f32(<vscale x 8 x float> splat (float 1.0))
83+
ret {<vscale x 4 x float>, <vscale x 4 x float>} %1
84+
}

0 commit comments

Comments
 (0)