Skip to content

Commit b0473c5

Browse files
[InstCombine] Pull extract through broadcast (#143380)
The change adds a new InstCombine pattern, and an associated test, for IR like this:
```
%3 = shufflevector <2 x float> %1, <2 x float> poison, <4 x i32> zeroinitializer
%4 = extractelement <4 x float> %3, i64 %idx
```
The shufflevector has a splat (broadcast) mask, so every lane of %3 holds the first element of %1. Whatever in-bounds value %idx takes, the extractelement must therefore yield that first element, so we transform this to:
```
%2 = extractelement <2 x float> %1, i64 0
```
1 parent cc6a864 commit b0473c5

File tree

6 files changed

+96
-40
lines changed

6 files changed

+96
-40
lines changed

llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -542,27 +542,39 @@ Instruction *InstCombinerImpl::visitExtractElementInst(ExtractElementInst &EI) {
542542
}
543543
}
544544
} else if (auto *SVI = dyn_cast<ShuffleVectorInst>(I)) {
545-
// If this is extracting an element from a shufflevector, figure out where
546-
// it came from and extract from the appropriate input element instead.
547-
// Restrict the following transformation to fixed-length vector.
548-
if (isa<FixedVectorType>(SVI->getType()) && isa<ConstantInt>(Index)) {
549-
int SrcIdx =
550-
SVI->getMaskValue(cast<ConstantInt>(Index)->getZExtValue());
551-
Value *Src;
552-
unsigned LHSWidth = cast<FixedVectorType>(SVI->getOperand(0)->getType())
553-
->getNumElements();
554-
555-
if (SrcIdx < 0)
556-
return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
557-
if (SrcIdx < (int)LHSWidth)
558-
Src = SVI->getOperand(0);
559-
else {
560-
SrcIdx -= LHSWidth;
561-
Src = SVI->getOperand(1);
545+
int SplatIndex = getSplatIndex(SVI->getShuffleMask());
546+
// We know the all-0 splat must be reading from the first operand, even
547+
// in the case of scalable vectors (vscale is always > 0).
548+
if (SplatIndex == 0)
549+
return ExtractElementInst::Create(SVI->getOperand(0),
550+
Builder.getInt64(0));
551+
552+
if (isa<FixedVectorType>(SVI->getType())) {
553+
std::optional<int> SrcIdx;
554+
// getSplatIndex returns -1 to mean not-found.
555+
if (SplatIndex != -1)
556+
SrcIdx = SplatIndex;
557+
else if (ConstantInt *CI = dyn_cast<ConstantInt>(Index))
558+
SrcIdx = SVI->getMaskValue(CI->getZExtValue());
559+
560+
if (SrcIdx) {
561+
Value *Src;
562+
unsigned LHSWidth =
563+
cast<FixedVectorType>(SVI->getOperand(0)->getType())
564+
->getNumElements();
565+
566+
if (*SrcIdx < 0)
567+
return replaceInstUsesWith(EI, PoisonValue::get(EI.getType()));
568+
if (*SrcIdx < (int)LHSWidth)
569+
Src = SVI->getOperand(0);
570+
else {
571+
*SrcIdx -= LHSWidth;
572+
Src = SVI->getOperand(1);
573+
}
574+
Type *Int64Ty = Type::getInt64Ty(EI.getContext());
575+
return ExtractElementInst::Create(
576+
Src, ConstantInt::get(Int64Ty, *SrcIdx, false));
562577
}
563-
Type *Int64Ty = Type::getInt64Ty(EI.getContext());
564-
return ExtractElementInst::Create(
565-
Src, ConstantInt::get(Int64Ty, SrcIdx, false));
566578
}
567579
} else if (auto *CI = dyn_cast<CastInst>(I)) {
568580
// Canonicalize extractelement(cast) -> cast(extractelement).
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
; RUN: opt -passes=instcombine -S < %s | FileCheck %s
2+
3+
define float @extract_from_zero_init_shuffle(<2 x float> %vec, i64 %idx) {
4+
; CHECK-LABEL: @extract_from_zero_init_shuffle(
5+
; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 0
6+
; CHECK-NEXT: ret float %extract
7+
;
8+
%shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> zeroinitializer
9+
%extract = extractelement <4 x float> %shuffle, i64 %idx
10+
ret float %extract
11+
}
12+
13+
14+
define float @extract_from_general_splat(<2 x float> %vec, i64 %idx) {
15+
; CHECK-LABEL: @extract_from_general_splat(
16+
; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1
17+
; CHECK-NEXT: ret float %extract
18+
;
19+
%shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 1, i32 1, i32 1, i32 1>
20+
%extract = extractelement <4 x float> %shuffle, i64 %idx
21+
ret float %extract
22+
}
23+
24+
define float @extract_from_general_scalable_splat(<vscale x 2 x float> %vec, i64 %idx) {
25+
; CHECK-LABEL: @extract_from_general_scalable_splat(
26+
; CHECK-NEXT: %extract = extractelement <vscale x 2 x float> %vec, i64 0
27+
; CHECK-NEXT: ret float %extract
28+
;
29+
%shuffle = shufflevector <vscale x 2 x float> %vec, <vscale x 2 x float> poison, <vscale x 4 x i32> zeroinitializer
30+
%extract = extractelement <vscale x 4 x float> %shuffle, i64 %idx
31+
ret float %extract
32+
}
33+
34+
define float @extract_from_splat_with_poison_0(<2 x float> %vec, i64 %idx) {
35+
; CHECK-LABEL: @extract_from_splat_with_poison_0(
36+
; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1
37+
; CHECK-NEXT: ret float %extract
38+
;
39+
%shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 poison, i32 1, i32 1, i32 1>
40+
%extract = extractelement <4 x float> %shuffle, i64 %idx
41+
ret float %extract
42+
}
43+
44+
define float @extract_from_splat_with_poison_1(<2 x float> %vec, i64 %idx) {
45+
; CHECK-LABEL: @extract_from_splat_with_poison_1(
46+
; CHECK-NEXT: %extract = extractelement <2 x float> %vec, i64 1
47+
; CHECK-NEXT: ret float %extract
48+
;
49+
%shuffle = shufflevector <2 x float> %vec, <2 x float> poison, <4 x i32> <i32 1, i32 poison, i32 1, i32 1>
50+
%extract = extractelement <4 x float> %shuffle, i64 %idx
51+
ret float %extract
52+
}

llvm/test/Transforms/InstCombine/vec_shuffle-inseltpoison.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,7 @@ define float @test6(<4 x float> %X) {
6161

6262
define float @testvscale6(<vscale x 4 x float> %X) {
6363
; CHECK-LABEL: @testvscale6(
64-
; CHECK-NEXT: [[T2:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
65-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[T2]], i64 0
64+
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[X:%.*]], i64 0
6665
; CHECK-NEXT: ret float [[R]]
6766
;
6867
%X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>

llvm/test/Transforms/InstCombine/vec_shuffle.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ define float @test6(<4 x float> %X) {
6767

6868
define float @testvscale6(<vscale x 4 x float> %X) {
6969
; CHECK-LABEL: @testvscale6(
70-
; CHECK-NEXT: [[T2:%.*]] = shufflevector <vscale x 4 x float> [[X:%.*]], <vscale x 4 x float> poison, <vscale x 4 x i32> zeroinitializer
71-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[T2]], i64 0
70+
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x float> [[X:%.*]], i64 0
7271
; CHECK-NEXT: ret float [[R]]
7372
;
7473
%X1 = bitcast <vscale x 4 x float> %X to <vscale x 4 x i32>

llvm/test/Transforms/InstCombine/vscale_extractelement-inseltpoison.ll

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -89,12 +89,12 @@ define i8 @extractelement_bitcast_insert_extra_use_bitcast(<vscale x 2 x i32> %a
8989
ret i8 %r
9090
}
9191

92+
; While the extract index may be out of bounds, every in-bounds index
93+
; yields %v, because the shuffle mask is all zeros.
94+
9295
define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
9396
; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range(
94-
; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[V:%.*]], i64 0
95-
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
96-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4
97-
; CHECK-NEXT: ret i32 [[R]]
97+
; CHECK-NEXT: ret i32 [[V:%.*]]
9898
;
9999
%in = insertelement <vscale x 4 x i32> poison, i32 %v, i32 0
100100
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
@@ -104,10 +104,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
104104

105105
define i32 @extractelement_shuffle_invalid_index(i32 %v) {
106106
; CHECK-LABEL: @extractelement_shuffle_invalid_index(
107-
; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> poison, i32 [[V:%.*]], i64 0
108-
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
109-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4294967295
110-
; CHECK-NEXT: ret i32 [[R]]
107+
; CHECK-NEXT: ret i32 [[V:%.*]]
111108
;
112109
%in = insertelement <vscale x 4 x i32> poison, i32 %v, i32 0
113110
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer

llvm/test/Transforms/InstCombine/vscale_extractelement.ll

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,12 @@ define i8 @extractelement_bitcast_useless_insert(<vscale x 2 x i32> %a, i32 %x)
5353
ret i8 %r
5454
}
5555

56+
; While the extract index in these tests may be out of bounds, every
57+
; in-bounds index yields %v, because the shuffle mask is all zeros.
58+
5659
define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
5760
; CHECK-LABEL: @extractelement_shuffle_maybe_out_of_range(
58-
; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> undef, i32 [[V:%.*]], i64 0
59-
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
60-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4
61-
; CHECK-NEXT: ret i32 [[R]]
61+
; CHECK-NEXT: ret i32 [[V:%.*]]
6262
;
6363
%in = insertelement <vscale x 4 x i32> undef, i32 %v, i32 0
6464
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer
@@ -68,10 +68,7 @@ define i32 @extractelement_shuffle_maybe_out_of_range(i32 %v) {
6868

6969
define i32 @extractelement_shuffle_invalid_index(i32 %v) {
7070
; CHECK-LABEL: @extractelement_shuffle_invalid_index(
71-
; CHECK-NEXT: [[IN:%.*]] = insertelement <vscale x 4 x i32> undef, i32 [[V:%.*]], i64 0
72-
; CHECK-NEXT: [[SPLAT:%.*]] = shufflevector <vscale x 4 x i32> [[IN]], <vscale x 4 x i32> poison, <vscale x 4 x i32> zeroinitializer
73-
; CHECK-NEXT: [[R:%.*]] = extractelement <vscale x 4 x i32> [[SPLAT]], i64 4294967295
74-
; CHECK-NEXT: ret i32 [[R]]
71+
; CHECK-NEXT: ret i32 [[V:%.*]]
7572
;
7673
%in = insertelement <vscale x 4 x i32> undef, i32 %v, i32 0
7774
%splat = shufflevector <vscale x 4 x i32> %in, <vscale x 4 x i32> undef, <vscale x 4 x i32> zeroinitializer

0 commit comments

Comments
 (0)