Skip to content

Commit 31cb807

Browse files
authored
[SanbdoxVec][BottomUpVec] Fix diamond shuffle with multiple vector inputs (llvm#126965)
When the operand comes from multiple inputs then we need additional packing code. When the operands are scalar then we can use a single InsertElementInst. But when the operands are vectors then we need a chain of ExtractElementInst and InsertElementInst instructions to insert the vector value into the destination vector. This is what this patch implements.
1 parent 3e02069 commit 31cb807

File tree

3 files changed

+63
-8
lines changed

3 files changed

+63
-8
lines changed

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Legality.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -203,13 +203,12 @@ LegalityAnalysis::getHowToCollectValues(ArrayRef<Value *> Bndl) const {
203203
SmallVector<CollectDescr::ExtractElementDescr, 4> Vec;
204204
Vec.reserve(Bndl.size());
205205
for (auto [Elm, V] : enumerate(Bndl)) {
206-
uint32_t VLanes = VecUtils::getNumLanes(V);
207206
if (auto *VecOp = IMaps.getVectorForOrig(V)) {
208207
// If there is a vector containing `V`, then get the lane it came from.
209208
std::optional<int> ExtractIdxOpt = IMaps.getOrigLane(VecOp, V);
210209
// This could be a vector, like <2 x float> in which case the mask needs
211210
// to enumerate all lanes.
212-
for (unsigned Ln = 0; Ln != VLanes; ++Ln)
211+
for (unsigned Ln = 0, Lanes = VecUtils::getNumLanes(V); Ln != Lanes; ++Ln)
213212
Vec.emplace_back(VecOp, ExtractIdxOpt ? *ExtractIdxOpt + Ln : -1);
214213
} else {
215214
Vec.emplace_back(V);

llvm/lib/Transforms/Vectorize/SandboxVectorizer/Passes/BottomUpVec.cpp

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -335,7 +335,7 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
335335
case LegalityResultID::DiamondReuseMultiInput: {
336336
const auto &Descr =
337337
cast<DiamondReuseMultiInput>(LegalityRes).getCollectDescr();
338-
Type *ResTy = FixedVectorType::get(Bndl[0]->getType(), Bndl.size());
338+
Type *ResTy = VecUtils::getWideType(Bndl[0]->getType(), Bndl.size());
339339

340340
// TODO: Try to get WhereIt without creating a vector.
341341
SmallVector<Value *, 4> DescrInstrs;
@@ -347,7 +347,8 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
347347
getInsertPointAfterInstrs(DescrInstrs, UserBB);
348348

349349
Value *LastV = PoisonValue::get(ResTy);
350-
for (auto [Lane, ElmDescr] : enumerate(Descr.getDescrs())) {
350+
unsigned Lane = 0;
351+
for (const auto &ElmDescr : Descr.getDescrs()) {
351352
Value *VecOp = ElmDescr.getValue();
352353
Context &Ctx = VecOp->getContext();
353354
Value *ValueToInsert;
@@ -359,10 +360,32 @@ Value *BottomUpVec::vectorizeRec(ArrayRef<Value *> Bndl,
359360
} else {
360361
ValueToInsert = VecOp;
361362
}
362-
ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
363-
Value *Ins = InsertElementInst::create(LastV, ValueToInsert, LaneC,
364-
WhereIt, Ctx, "VIns");
365-
LastV = Ins;
363+
auto NumLanesToInsert = VecUtils::getNumLanes(ValueToInsert);
364+
if (NumLanesToInsert == 1) {
365+
// If we are inserting a scalar element then we need a single insert.
366+
// %VIns = insert %DstVec, %SrcScalar, Lane
367+
ConstantInt *LaneC = ConstantInt::get(Type::getInt32Ty(Ctx), Lane);
368+
LastV = InsertElementInst::create(LastV, ValueToInsert, LaneC, WhereIt,
369+
Ctx, "VIns");
370+
} else {
371+
// If we are inserting a vector element then we need to extract and
372+
// insert each vector element one by one with a chain of extracts and
373+
// inserts, for example:
374+
// %VExt0 = extract %SrcVec, 0
375+
// %VIns0 = insert %DstVec, %Vect0, Lane + 0
376+
// %VExt1 = extract %SrcVec, 1
377+
// %VIns1 = insert %VIns0, %Vect0, Lane + 1
378+
for (unsigned LnCnt = 0; LnCnt != NumLanesToInsert; ++LnCnt) {
379+
auto *ExtrIdxC = ConstantInt::get(Type::getInt32Ty(Ctx), LnCnt);
380+
auto *ExtrI = ExtractElementInst::create(ValueToInsert, ExtrIdxC,
381+
WhereIt, Ctx, "VExt");
382+
unsigned InsLane = Lane + LnCnt;
383+
auto *InsLaneC = ConstantInt::get(Type::getInt32Ty(Ctx), InsLane);
384+
LastV = InsertElementInst::create(LastV, ExtrI, InsLaneC, WhereIt,
385+
Ctx, "VIns");
386+
}
387+
}
388+
Lane += NumLanesToInsert;
366389
}
367390
NewVec = LastV;
368391
break;

llvm/test/Transforms/SandboxVectorizer/bottomup_basic.ll

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,39 @@ define void @diamondMultiInput(ptr %ptr, ptr %ptrX) {
292292
ret void
293293
}
294294

295+
; Same but vectorizing <2 x float> vectors instead of scalars.
296+
define void @diamondMultiInputVector(ptr %ptr, ptr %ptrX) {
297+
; CHECK-LABEL: define void @diamondMultiInputVector(
298+
; CHECK-SAME: ptr [[PTR:%.*]], ptr [[PTRX:%.*]]) {
299+
; CHECK-NEXT: [[PTR0:%.*]] = getelementptr <2 x float>, ptr [[PTR]], i32 0
300+
; CHECK-NEXT: [[LDX:%.*]] = load <2 x float>, ptr [[PTRX]], align 8
301+
; CHECK-NEXT: [[VECL:%.*]] = load <4 x float>, ptr [[PTR0]], align 8
302+
; CHECK-NEXT: [[VEXT:%.*]] = extractelement <2 x float> [[LDX]], i32 0
303+
; CHECK-NEXT: [[INSI:%.*]] = insertelement <4 x float> poison, float [[VEXT]], i32 0
304+
; CHECK-NEXT: [[VEXT1:%.*]] = extractelement <2 x float> [[LDX]], i32 1
305+
; CHECK-NEXT: [[INSI2:%.*]] = insertelement <4 x float> [[INSI]], float [[VEXT1]], i32 1
306+
; CHECK-NEXT: [[VEXT3:%.*]] = extractelement <4 x float> [[VECL]], i32 0
307+
; CHECK-NEXT: [[VINS4:%.*]] = insertelement <4 x float> [[INSI2]], float [[VEXT3]], i32 2
308+
; CHECK-NEXT: [[VEXT4:%.*]] = extractelement <4 x float> [[VECL]], i32 1
309+
; CHECK-NEXT: [[VINS5:%.*]] = insertelement <4 x float> [[VINS4]], float [[VEXT4]], i32 3
310+
; CHECK-NEXT: [[VEC:%.*]] = fsub <4 x float> [[VECL]], [[VINS5]]
311+
; CHECK-NEXT: store <4 x float> [[VEC]], ptr [[PTR0]], align 8
312+
; CHECK-NEXT: ret void
313+
;
314+
%ptr0 = getelementptr <2 x float>, ptr %ptr, i32 0
315+
%ptr1 = getelementptr <2 x float>, ptr %ptr, i32 1
316+
%ld0 = load <2 x float>, ptr %ptr0
317+
%ld1 = load <2 x float>, ptr %ptr1
318+
319+
%ldX = load <2 x float>, ptr %ptrX
320+
321+
%sub0 = fsub <2 x float> %ld0, %ldX
322+
%sub1 = fsub <2 x float> %ld1, %ld0
323+
store <2 x float> %sub0, ptr %ptr0
324+
store <2 x float> %sub1, ptr %ptr1
325+
ret void
326+
}
327+
295328
define void @diamondWithConstantVector(ptr %ptr) {
296329
; CHECK-LABEL: define void @diamondWithConstantVector(
297330
; CHECK-SAME: ptr [[PTR:%.*]]) {

0 commit comments

Comments
 (0)