Skip to content

Commit 7ea3714

Browse files
[AArch64] Extend performActiveLaneMaskCombine for more than two extracts (#146725)
The combine was originally added to find a get.active.lane.mask whose only users are two subvector extracts and, where possible, to replace it with the paired whilelo instruction. This change extends the combine to cover cases where there are more than two extracts, emitting one paired whilelo for each pair of extracts.
1 parent 7d92756 commit 7ea3714

File tree

2 files changed

+153
-59
lines changed

2 files changed

+153
-59
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 41 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -18170,53 +18170,65 @@ performActiveLaneMaskCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI,
1817018170
(!ST->hasSVE2p1() && !(ST->hasSME2() && ST->isStreaming())))
1817118171
return SDValue();
1817218172

18173-
if (!N->hasNUsesOfValue(2, 0))
18173+
unsigned NumUses = N->use_size();
18174+
auto MaskEC = N->getValueType(0).getVectorElementCount();
18175+
if (!MaskEC.isKnownMultipleOf(NumUses))
1817418176
return SDValue();
1817518177

18176-
const uint64_t HalfSize = N->getValueType(0).getVectorMinNumElements() / 2;
18177-
if (HalfSize < 2)
18178+
ElementCount ExtMinEC = MaskEC.divideCoefficientBy(NumUses);
18179+
if (ExtMinEC.getKnownMinValue() < 2)
1817818180
return SDValue();
1817918181

18180-
auto It = N->user_begin();
18181-
SDNode *Lo = *It++;
18182-
SDNode *Hi = *It;
18182+
SmallVector<SDNode *> Extracts(NumUses, nullptr);
18183+
for (SDNode *Use : N->users()) {
18184+
if (Use->getOpcode() != ISD::EXTRACT_SUBVECTOR)
18185+
return SDValue();
1818318186

18184-
if (Lo->getOpcode() != ISD::EXTRACT_SUBVECTOR ||
18185-
Hi->getOpcode() != ISD::EXTRACT_SUBVECTOR)
18186-
return SDValue();
18187+
// Ensure the extract type is correct (e.g. if NumUses is 4 and
18188+
// the mask return type is nxv8i1, each extract should be nxv2i1).
18189+
if (Use->getValueType(0).getVectorElementCount() != ExtMinEC)
18190+
return SDValue();
1818718191

18188-
uint64_t OffLo = Lo->getConstantOperandVal(1);
18189-
uint64_t OffHi = Hi->getConstantOperandVal(1);
18192+
// There should be exactly one extract for each part of the mask.
18193+
unsigned Offset = Use->getConstantOperandVal(1);
18194+
unsigned Part = Offset / ExtMinEC.getKnownMinValue();
18195+
if (Extracts[Part] != nullptr)
18196+
return SDValue();
1819018197

18191-
if (OffLo > OffHi) {
18192-
std::swap(Lo, Hi);
18193-
std::swap(OffLo, OffHi);
18198+
Extracts[Part] = Use;
1819418199
}
1819518200

18196-
if (OffLo != 0 || OffHi != HalfSize)
18197-
return SDValue();
18198-
18199-
EVT HalfVec = Lo->getValueType(0);
18200-
if (HalfVec != Hi->getValueType(0) ||
18201-
HalfVec.getVectorElementCount() != ElementCount::getScalable(HalfSize))
18202-
return SDValue();
18203-
1820418201
SelectionDAG &DAG = DCI.DAG;
1820518202
SDLoc DL(N);
1820618203
SDValue ID =
1820718204
DAG.getTargetConstant(Intrinsic::aarch64_sve_whilelo_x2, DL, MVT::i64);
18205+
1820818206
SDValue Idx = N->getOperand(0);
1820918207
SDValue TC = N->getOperand(1);
18210-
if (Idx.getValueType() != MVT::i64) {
18211-
Idx = DAG.getZExtOrTrunc(Idx, DL, MVT::i64);
18212-
TC = DAG.getZExtOrTrunc(TC, DL, MVT::i64);
18208+
EVT OpVT = Idx.getValueType();
18209+
if (OpVT != MVT::i64) {
18210+
Idx = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, Idx);
18211+
TC = DAG.getNode(ISD::ZERO_EXTEND, DL, MVT::i64, TC);
1821318212
}
18213+
18214+
// Create one whilelo_x2 intrinsic for each pair of extracts.
18215+
EVT ExtVT = Extracts[0]->getValueType(0);
1821418216
auto R =
18215-
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL,
18216-
{Lo->getValueType(0), Hi->getValueType(0)}, {ID, Idx, TC});
18217+
DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
18218+
DCI.CombineTo(Extracts[0], R.getValue(0));
18219+
DCI.CombineTo(Extracts[1], R.getValue(1));
18220+
18221+
if (NumUses == 2)
18222+
return SDValue(N, 0);
1821718223

18218-
DCI.CombineTo(Lo, R.getValue(0));
18219-
DCI.CombineTo(Hi, R.getValue(1));
18224+
auto Elts = DAG.getElementCount(DL, OpVT, ExtVT.getVectorElementCount() * 2);
18225+
for (unsigned I = 2; I < NumUses; I += 2) {
18226+
// After the first whilelo_x2, we need to increment the starting value.
18227+
Idx = DAG.getNode(ISD::UADDSAT, DL, OpVT, Idx, Elts);
18228+
R = DAG.getNode(ISD::INTRINSIC_WO_CHAIN, DL, {ExtVT, ExtVT}, {ID, Idx, TC});
18229+
DCI.CombineTo(Extracts[I], R.getValue(0));
18230+
DCI.CombineTo(Extracts[I + 1], R.getValue(1));
18231+
}
1822018232

1822118233
return SDValue(N, 0);
1822218234
}

llvm/test/CodeGen/AArch64/get-active-lane-mask-extract.ll

Lines changed: 112 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,65 @@ define void @test_boring_case_2x2bit_mask(i64 %i, i64 %n) #0 {
8686
ret void
8787
}
8888

89+
define void @test_legal_4x2bit_mask(i64 %i, i64 %n) #0 {
90+
; CHECK-SVE-LABEL: test_legal_4x2bit_mask:
91+
; CHECK-SVE: // %bb.0:
92+
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
93+
; CHECK-SVE-NEXT: punpkhi p1.h, p0.b
94+
; CHECK-SVE-NEXT: punpklo p4.h, p0.b
95+
; CHECK-SVE-NEXT: punpkhi p3.h, p1.b
96+
; CHECK-SVE-NEXT: punpklo p2.h, p1.b
97+
; CHECK-SVE-NEXT: punpklo p0.h, p4.b
98+
; CHECK-SVE-NEXT: punpkhi p1.h, p4.b
99+
; CHECK-SVE-NEXT: b use
100+
;
101+
; CHECK-SVE2p1-SME2-LABEL: test_legal_4x2bit_mask:
102+
; CHECK-SVE2p1-SME2: // %bb.0:
103+
; CHECK-SVE2p1-SME2-NEXT: cntw x8
104+
; CHECK-SVE2p1-SME2-NEXT: adds x8, x0, x8
105+
; CHECK-SVE2p1-SME2-NEXT: csinv x8, x8, xzr, lo
106+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p0.d, p1.d }, x0, x1
107+
; CHECK-SVE2p1-SME2-NEXT: whilelo { p2.d, p3.d }, x8, x1
108+
; CHECK-SVE2p1-SME2-NEXT: b use
109+
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
110+
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
111+
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
112+
%v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
113+
%v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
114+
tail call void @use(<vscale x 2 x i1> %v3, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v0)
115+
ret void
116+
}
117+
118+
; Negative test where the extract types are correct but we are not extracting all parts of the mask
119+
; Note: We could still create a whilelo_x2 for the first two extracts, but we do not expect this case to occur often yet.
120+
define void @test_partial_extract_correct_types(i64 %i, i64 %n) #0 {
121+
; CHECK-SVE-LABEL: test_partial_extract_correct_types:
122+
; CHECK-SVE: // %bb.0:
123+
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
124+
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
125+
; CHECK-SVE-NEXT: punpkhi p2.h, p0.b
126+
; CHECK-SVE-NEXT: punpklo p0.h, p1.b
127+
; CHECK-SVE-NEXT: punpkhi p1.h, p1.b
128+
; CHECK-SVE-NEXT: punpkhi p2.h, p2.b
129+
; CHECK-SVE-NEXT: b use
130+
;
131+
; CHECK-SVE2p1-SME2-LABEL: test_partial_extract_correct_types:
132+
; CHECK-SVE2p1-SME2: // %bb.0:
133+
; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
134+
; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
135+
; CHECK-SVE2p1-SME2-NEXT: punpkhi p2.h, p0.b
136+
; CHECK-SVE2p1-SME2-NEXT: punpklo p0.h, p1.b
137+
; CHECK-SVE2p1-SME2-NEXT: punpkhi p1.h, p1.b
138+
; CHECK-SVE2p1-SME2-NEXT: punpkhi p2.h, p2.b
139+
; CHECK-SVE2p1-SME2-NEXT: b use
140+
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
141+
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
142+
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
143+
%v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
144+
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2)
145+
ret void
146+
}
147+
89148
; Negative test for when the extracts are not exactly the two halves of the source vector
90149
define void @test_partial_extract(i64 %i, i64 %n) #0 {
91150
; CHECK-SVE-LABEL: test_partial_extract:
@@ -116,57 +175,80 @@ define void @test_partial_extract(i64 %i, i64 %n) #0 {
116175
define void @test_fixed_extract(i64 %i, i64 %n) #0 {
117176
; CHECK-SVE-LABEL: test_fixed_extract:
118177
; CHECK-SVE: // %bb.0:
119-
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
178+
; CHECK-SVE-NEXT: whilelo p0.s, x0, x1
120179
; CHECK-SVE-NEXT: cset w8, mi
121-
; CHECK-SVE-NEXT: mov z0.h, p0/z, #1 // =0x1
122-
; CHECK-SVE-NEXT: umov w9, v0.h[4]
123-
; CHECK-SVE-NEXT: umov w10, v0.h[1]
124-
; CHECK-SVE-NEXT: umov w11, v0.h[5]
180+
; CHECK-SVE-NEXT: mov z1.s, p0/z, #1 // =0x1
125181
; CHECK-SVE-NEXT: fmov s0, w8
126-
; CHECK-SVE-NEXT: fmov s1, w9
127-
; CHECK-SVE-NEXT: mov v0.s[1], w10
182+
; CHECK-SVE-NEXT: mov v0.s[1], v1.s[1]
183+
; CHECK-SVE-NEXT: ext z1.b, z1.b, z1.b, #8
128184
; CHECK-SVE-NEXT: // kill: def $d0 killed $d0 killed $q0
129-
; CHECK-SVE-NEXT: mov v1.s[1], w11
130-
; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 killed $q1
185+
; CHECK-SVE-NEXT: // kill: def $d1 killed $d1 killed $z1
131186
; CHECK-SVE-NEXT: b use
132187
;
133188
; CHECK-SVE2p1-LABEL: test_fixed_extract:
134189
; CHECK-SVE2p1: // %bb.0:
135-
; CHECK-SVE2p1-NEXT: whilelo p0.h, x0, x1
190+
; CHECK-SVE2p1-NEXT: whilelo p0.s, x0, x1
136191
; CHECK-SVE2p1-NEXT: cset w8, mi
137-
; CHECK-SVE2p1-NEXT: mov z0.h, p0/z, #1 // =0x1
138-
; CHECK-SVE2p1-NEXT: umov w9, v0.h[4]
139-
; CHECK-SVE2p1-NEXT: umov w10, v0.h[1]
140-
; CHECK-SVE2p1-NEXT: umov w11, v0.h[5]
192+
; CHECK-SVE2p1-NEXT: mov z1.s, p0/z, #1 // =0x1
141193
; CHECK-SVE2p1-NEXT: fmov s0, w8
142-
; CHECK-SVE2p1-NEXT: fmov s1, w9
143-
; CHECK-SVE2p1-NEXT: mov v0.s[1], w10
194+
; CHECK-SVE2p1-NEXT: mov v0.s[1], v1.s[1]
195+
; CHECK-SVE2p1-NEXT: ext z1.b, z1.b, z1.b, #8
144196
; CHECK-SVE2p1-NEXT: // kill: def $d0 killed $d0 killed $q0
145-
; CHECK-SVE2p1-NEXT: mov v1.s[1], w11
146-
; CHECK-SVE2p1-NEXT: // kill: def $d1 killed $d1 killed $q1
197+
; CHECK-SVE2p1-NEXT: // kill: def $d1 killed $d1 killed $z1
147198
; CHECK-SVE2p1-NEXT: b use
148199
;
149200
; CHECK-SME2-LABEL: test_fixed_extract:
150201
; CHECK-SME2: // %bb.0:
151-
; CHECK-SME2-NEXT: whilelo p0.h, x0, x1
202+
; CHECK-SME2-NEXT: whilelo p0.s, x0, x1
152203
; CHECK-SME2-NEXT: cset w8, mi
153-
; CHECK-SME2-NEXT: mov z0.h, p0/z, #1 // =0x1
154-
; CHECK-SME2-NEXT: mov z1.h, z0.h[1]
155-
; CHECK-SME2-NEXT: mov z2.h, z0.h[5]
156-
; CHECK-SME2-NEXT: mov z3.h, z0.h[4]
157-
; CHECK-SME2-NEXT: fmov s0, w8
158-
; CHECK-SME2-NEXT: zip1 z0.s, z0.s, z1.s
159-
; CHECK-SME2-NEXT: zip1 z1.s, z3.s, z2.s
160-
; CHECK-SME2-NEXT: // kill: def $d0 killed $d0 killed $z0
204+
; CHECK-SME2-NEXT: mov z1.s, p0/z, #1 // =0x1
205+
; CHECK-SME2-NEXT: fmov s2, w8
206+
; CHECK-SME2-NEXT: mov z0.s, z1.s[1]
207+
; CHECK-SME2-NEXT: ext z1.b, z1.b, z1.b, #8
161208
; CHECK-SME2-NEXT: // kill: def $d1 killed $d1 killed $z1
209+
; CHECK-SME2-NEXT: zip1 z0.s, z2.s, z0.s
210+
; CHECK-SME2-NEXT: // kill: def $d0 killed $d0 killed $z0
162211
; CHECK-SME2-NEXT: b use
163-
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
164-
%v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 0)
165-
%v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
212+
%r = call <vscale x 4 x i1> @llvm.get.active.lane.mask.nxv4i1.i64(i64 %i, i64 %n)
213+
%v0 = call <2 x i1> @llvm.vector.extract.v2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 0)
214+
%v1 = call <2 x i1> @llvm.vector.extract.v2i1.nxv4i1.i64(<vscale x 4 x i1> %r, i64 2)
166215
tail call void @use(<2 x i1> %v0, <2 x i1> %v1)
167216
ret void
168217
}
169218

219+
; Negative test where the number of extracts is right, but they cannot be combined because
220+
; there is not an extract for each part (offset 2 is extracted twice and offset 0 is never extracted)
221+
define void @test_4x2bit_duplicate_mask(i64 %i, i64 %n) #0 {
222+
; CHECK-SVE-LABEL: test_4x2bit_duplicate_mask:
223+
; CHECK-SVE: // %bb.0:
224+
; CHECK-SVE-NEXT: whilelo p0.h, x0, x1
225+
; CHECK-SVE-NEXT: punpklo p1.h, p0.b
226+
; CHECK-SVE-NEXT: punpkhi p3.h, p0.b
227+
; CHECK-SVE-NEXT: punpkhi p0.h, p1.b
228+
; CHECK-SVE-NEXT: punpklo p2.h, p3.b
229+
; CHECK-SVE-NEXT: punpkhi p3.h, p3.b
230+
; CHECK-SVE-NEXT: mov p1.b, p0.b
231+
; CHECK-SVE-NEXT: b use
232+
;
233+
; CHECK-SVE2p1-SME2-LABEL: test_4x2bit_duplicate_mask:
234+
; CHECK-SVE2p1-SME2: // %bb.0:
235+
; CHECK-SVE2p1-SME2-NEXT: whilelo p0.h, x0, x1
236+
; CHECK-SVE2p1-SME2-NEXT: punpklo p1.h, p0.b
237+
; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p0.b
238+
; CHECK-SVE2p1-SME2-NEXT: punpkhi p0.h, p1.b
239+
; CHECK-SVE2p1-SME2-NEXT: punpklo p2.h, p3.b
240+
; CHECK-SVE2p1-SME2-NEXT: punpkhi p3.h, p3.b
241+
; CHECK-SVE2p1-SME2-NEXT: mov p1.b, p0.b
242+
; CHECK-SVE2p1-SME2-NEXT: b use
243+
%r = call <vscale x 8 x i1> @llvm.get.active.lane.mask.nxv8i1.i64(i64 %i, i64 %n)
244+
%v0 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
245+
%v1 = call <vscale x 2 x i1> @llvm.vector.extract.nxv2i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 2)
246+
%v2 = call <vscale x 2 x i1> @llvm.vector.extract.nxv4i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 4)
247+
%v3 = call <vscale x 2 x i1> @llvm.vector.extract.nxv4i1.nxv8i1.i64(<vscale x 8 x i1> %r, i64 6)
248+
tail call void @use(<vscale x 2 x i1> %v0, <vscale x 2 x i1> %v1, <vscale x 2 x i1> %v2, <vscale x 2 x i1> %v3)
249+
ret void
250+
}
251+
170252
; Illegal Types
171253

172254
define void @test_2x16bit_mask_with_32bit_index_and_trip_count(i32 %i, i32 %n) #0 {

0 commit comments

Comments
 (0)