
Commit bc8dad1

[VPlan] Emit VPVectorEndPointerRecipe for reverse interleave pointer adjustment (#144864)
A reverse interleave access is essentially composed of multiple load/store operations with the same negative stride, whose addresses are all based on the last-lane address of member 0 of the interleave group. We already have VPVectorEndPointerRecipe for computing the last-lane address of consecutive reverse memory accesses. This patch extends VPVectorEndPointerRecipe to support a constant stride and extracts the reverse interleave group address adjustment out of VPInterleaveRecipe::execute, replacing it with a VPVectorEndPointerRecipe.

The final goal is to support interleaved accesses with EVL tail folding. VPInterleaveRecipe is large and tightly coupled: it combines both loads and stores and embeds operations such as the reverse pointer adjustment (GEP), the widened load/store, deinterleaving/interleaving, and reversal. Breaking it down into smaller, dedicated recipes may allow VPlanTransforms::tryAddExplicitVectorLength to lower them into EVL-aware form more effectively.

One foreseeable challenge is that VPlanTransforms::convertToConcreteRecipes currently runs after tryAddExplicitVectorLength, so decomposing VPInterleaveRecipe will likely need to happen earlier in the pipeline to be effective.
1 parent 6e1e89e commit bc8dad1
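To make the addressing concrete, here is a small standalone sketch (added for illustration; the helper name reverseEndOffset and the plain-integer model are assumptions, not part of the patch) of the element offset the extended recipe computes, following the NumElt/LastLane formulas in VPVectorEndPointerRecipe::execute below:

#include <cassert>
#include <cstdint>

// Hypothetical helper (named here for illustration) mirroring the element
// offset that VPVectorEndPointerRecipe computes for a reverse access:
// NumElt = Stride * Part * VF plus LastLane = Stride * (VF - 1), with Stride
// expressed in elements of the indexed type (-1 for a plain reverse access,
// -Factor for a reverse interleave group).
int64_t reverseEndOffset(int64_t Stride, int64_t VF, unsigned Part) {
  assert(Stride < 0 && "only negative strides are modeled here");
  int64_t NumElt = Stride * static_cast<int64_t>(Part) * VF;
  int64_t LastLane = Stride * (VF - 1);
  return NumElt + LastLane;
}

int main() {
  // Plain reverse load/store (stride -1), VF = 4, part 0: start at element -3,
  // i.e. the last lane of the reversed access.
  assert(reverseEndOffset(-1, 4, 0) == -3);
  // Factor-2 reverse interleave group, VF = 4, part 0: 2 * (4 - 1) = 6
  // elements before member 0's address.
  assert(reverseEndOffset(-2, 4, 0) == -6);
  // Same group, unroll part 1: a further 2 * 4 = 8 elements back.
  assert(reverseEndOffset(-2, 4, 1) == -14);
  return 0;
}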

File tree: 9 files changed, +137 / -112 lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 2 deletions
@@ -7769,8 +7769,9 @@ VPRecipeBuilder::tryToWidenMemory(Instruction *I, ArrayRef<VPValue *> Operands,
             (CM.foldTailByMasking() || !GEP || !GEP->isInBounds())
                 ? GEPNoWrapFlags::none()
                 : GEPNoWrapFlags::inBounds();
-    VectorPtr = new VPVectorEndPointerRecipe(
-        Ptr, &Plan.getVF(), getLoadStoreType(I), Flags, I->getDebugLoc());
+    VectorPtr =
+        new VPVectorEndPointerRecipe(Ptr, &Plan.getVF(), getLoadStoreType(I),
+                                     /*Stride*/ -1, Flags, I->getDebugLoc());
   } else {
     VectorPtr = new VPVectorPointerRecipe(Ptr, getLoadStoreType(I),
                                           GEP ? GEP->getNoWrapFlags()

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 11 additions & 4 deletions
@@ -1704,17 +1704,23 @@ class VPWidenGEPRecipe : public VPRecipeWithIRFlags {
 
 /// A recipe to compute a pointer to the last element of each part of a widened
 /// memory access for widened memory accesses of IndexedTy. Used for
-/// VPWidenMemoryRecipes that are reversed.
+/// VPWidenMemoryRecipes or VPInterleaveRecipes that are reversed.
 class VPVectorEndPointerRecipe : public VPRecipeWithIRFlags,
                                  public VPUnrollPartAccessor<2> {
   Type *IndexedTy;
 
+  /// The constant stride of the pointer computed by this recipe, expressed in
+  /// units of IndexedTy.
+  int64_t Stride;
+
 public:
   VPVectorEndPointerRecipe(VPValue *Ptr, VPValue *VF, Type *IndexedTy,
-                           GEPNoWrapFlags GEPFlags, DebugLoc DL)
+                           int64_t Stride, GEPNoWrapFlags GEPFlags, DebugLoc DL)
       : VPRecipeWithIRFlags(VPDef::VPVectorEndPointerSC,
                             ArrayRef<VPValue *>({Ptr, VF}), GEPFlags, DL),
-        IndexedTy(IndexedTy) {}
+        IndexedTy(IndexedTy), Stride(Stride) {
+    assert(Stride < 0 && "Stride must be negative");
+  }
 
   VP_CLASSOF_IMPL(VPDef::VPVectorEndPointerSC)
 
@@ -1746,7 +1752,8 @@ class VPVectorEndPointerRecipe : public VPRecipeWithIRFlags,
 
   VPVectorEndPointerRecipe *clone() override {
     return new VPVectorEndPointerRecipe(getOperand(0), getVFValue(), IndexedTy,
-                                        getGEPNoWrapFlags(), getDebugLoc());
+                                        Stride, getGEPNoWrapFlags(),
+                                        getDebugLoc());
   }
 
 #if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 11 additions & 27 deletions
@@ -2309,31 +2309,34 @@ void VPWidenGEPRecipe::print(raw_ostream &O, const Twine &Indent,
 }
 #endif
 
-static Type *getGEPIndexTy(bool IsScalable, bool IsReverse,
+static Type *getGEPIndexTy(bool IsScalable, bool IsReverse, bool IsUnitStride,
                            unsigned CurrentPart, IRBuilderBase &Builder) {
   // Use i32 for the gep index type when the value is constant,
   // or query DataLayout for a more suitable index type otherwise.
   const DataLayout &DL = Builder.GetInsertBlock()->getDataLayout();
-  return IsScalable && (IsReverse || CurrentPart > 0)
+  return !IsUnitStride || (IsScalable && (IsReverse || CurrentPart > 0))
              ? DL.getIndexType(Builder.getPtrTy(0))
              : Builder.getInt32Ty();
 }
 
 void VPVectorEndPointerRecipe::execute(VPTransformState &State) {
   auto &Builder = State.Builder;
   unsigned CurrentPart = getUnrollPart(*this);
+  bool IsUnitStride = Stride == 1 || Stride == -1;
   Type *IndexTy = getGEPIndexTy(State.VF.isScalable(), /*IsReverse*/ true,
-                                CurrentPart, Builder);
+                                IsUnitStride, CurrentPart, Builder);
 
   // The wide store needs to start at the last vector element.
   Value *RunTimeVF = State.get(getVFValue(), VPLane(0));
   if (IndexTy != RunTimeVF->getType())
     RunTimeVF = Builder.CreateZExtOrTrunc(RunTimeVF, IndexTy);
-  // NumElt = -CurrentPart * RunTimeVF
+  // NumElt = Stride * CurrentPart * RunTimeVF
   Value *NumElt = Builder.CreateMul(
-      ConstantInt::get(IndexTy, -(int64_t)CurrentPart), RunTimeVF);
-  // LastLane = 1 - RunTimeVF
-  Value *LastLane = Builder.CreateSub(ConstantInt::get(IndexTy, 1), RunTimeVF);
+      ConstantInt::get(IndexTy, Stride * (int64_t)CurrentPart), RunTimeVF);
+  // LastLane = Stride * (RunTimeVF - 1)
+  Value *LastLane = Builder.CreateSub(RunTimeVF, ConstantInt::get(IndexTy, 1));
+  if (Stride != 1)
+    LastLane = Builder.CreateMul(ConstantInt::get(IndexTy, Stride), LastLane);
   Value *Ptr = State.get(getOperand(0), VPLane(0));
   Value *ResultPtr =
       Builder.CreateGEP(IndexedTy, Ptr, NumElt, "", getGEPNoWrapFlags());
@@ -2358,7 +2361,7 @@ void VPVectorPointerRecipe::execute(VPTransformState &State) {
   auto &Builder = State.Builder;
   unsigned CurrentPart = getUnrollPart(*this);
   Type *IndexTy = getGEPIndexTy(State.VF.isScalable(), /*IsReverse*/ false,
-                                CurrentPart, Builder);
+                                /*IsUnitStride*/ true, CurrentPart, Builder);
   Value *Ptr = State.get(getOperand(0), VPLane(0));
 
   Value *Increment = createStepForVF(Builder, IndexTy, State.VF, CurrentPart);
@@ -3425,25 +3428,6 @@ void VPInterleaveRecipe::execute(VPTransformState &State) {
   if (auto *I = dyn_cast<Instruction>(ResAddr))
     State.setDebugLocFrom(I->getDebugLoc());
 
-  // If the group is reverse, adjust the index to refer to the last vector lane
-  // instead of the first. We adjust the index from the first vector lane,
-  // rather than directly getting the pointer for lane VF - 1, because the
-  // pointer operand of the interleaved access is supposed to be uniform.
-  if (Group->isReverse()) {
-    Value *RuntimeVF =
-        getRuntimeVF(State.Builder, State.Builder.getInt32Ty(), State.VF);
-    Value *Index =
-        State.Builder.CreateSub(RuntimeVF, State.Builder.getInt32(1));
-    Index = State.Builder.CreateMul(Index,
-                                    State.Builder.getInt32(Group->getFactor()));
-    Index = State.Builder.CreateNeg(Index);
-
-    bool InBounds = false;
-    if (auto *Gep = dyn_cast<GetElementPtrInst>(ResAddr->stripPointerCasts()))
-      InBounds = Gep->isInBounds();
-    ResAddr = State.Builder.CreateGEP(ScalarTy, ResAddr, Index, "", InBounds);
-  }
-
   State.setDebugLocFrom(getDebugLoc());
   Value *PoisonVec = PoisonValue::get(VecTy);

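As an aside on the getGEPIndexTy change above, the new predicate can be restated in isolation (an illustrative sketch, not LLVM code; the function name useDataLayoutIndexTy is invented here): a non-unit constant stride now always selects the DataLayout index type, in addition to the existing scalable reverse / later-part cases.

#include <cassert>

// Standalone restatement of the getGEPIndexTy decision above (illustration
// only): returns true when the DataLayout index type should be used for the
// GEP offset instead of i32.
bool useDataLayoutIndexTy(bool IsScalable, bool IsReverse, bool IsUnitStride,
                          unsigned CurrentPart) {
  return !IsUnitStride || (IsScalable && (IsReverse || CurrentPart > 0));
}

int main() {
  // A forward unit-stride access in part 0 keeps the cheap i32 index type.
  assert(!useDataLayoutIndexTy(/*IsScalable=*/false, /*IsReverse=*/false,
                               /*IsUnitStride=*/true, /*CurrentPart=*/0));
  // A reverse interleave group (stride -Factor, Factor > 1) now always takes
  // the wider index type, even for fixed-width vectors in part 0.
  assert(useDataLayoutIndexTy(/*IsScalable=*/false, /*IsReverse=*/true,
                              /*IsUnitStride=*/false, /*CurrentPart=*/0));
  return 0;
}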
llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 20 additions & 7 deletions
@@ -2482,23 +2482,23 @@ void VPlanTransforms::createInterleaveGroups(
     auto *InsertPos =
         cast<VPWidenMemoryRecipe>(RecipeBuilder.getRecipe(IRInsertPos));
 
+    bool InBounds = false;
+    if (auto *Gep = dyn_cast<GetElementPtrInst>(
+            getLoadStorePointerOperand(IRInsertPos)->stripPointerCasts()))
+      InBounds = Gep->isInBounds();
+
     // Get or create the start address for the interleave group.
     auto *Start =
         cast<VPWidenMemoryRecipe>(RecipeBuilder.getRecipe(IG->getMember(0)));
     VPValue *Addr = Start->getAddr();
     VPRecipeBase *AddrDef = Addr->getDefiningRecipe();
     if (AddrDef && !VPDT.properlyDominates(AddrDef, InsertPos)) {
-      // TODO: Hoist Addr's defining recipe (and any operands as needed) to
-      // InsertPos or sink loads above zero members to join it.
-      bool InBounds = false;
-      if (auto *Gep = dyn_cast<GetElementPtrInst>(
-              getLoadStorePointerOperand(IRInsertPos)->stripPointerCasts()))
-        InBounds = Gep->isInBounds();
-
       // We cannot re-use the address of member zero because it does not
       // dominate the insert position. Instead, use the address of the insert
       // position and create a PtrAdd adjusting it to the address of member
       // zero.
+      // TODO: Hoist Addr's defining recipe (and any operands as needed) to
+      // InsertPos or sink loads above zero members to join it.
       assert(IG->getIndex(IRInsertPos) != 0 &&
              "index of insert position shouldn't be zero");
       auto &DL = IRInsertPos->getDataLayout();
@@ -2512,6 +2512,19 @@ void VPlanTransforms::createInterleaveGroups(
       Addr = InBounds ? B.createInBoundsPtrAdd(InsertPos->getAddr(), OffsetVPV)
                       : B.createPtrAdd(InsertPos->getAddr(), OffsetVPV);
     }
+    // If the group is reverse, adjust the index to refer to the last vector
+    // lane instead of the first. We adjust the index from the first vector
+    // lane, rather than directly getting the pointer for lane VF - 1, because
+    // the pointer operand of the interleaved access is supposed to be uniform.
+    if (IG->isReverse()) {
+      auto *ReversePtr = new VPVectorEndPointerRecipe(
+          Addr, &Plan.getVF(), getLoadStoreType(IRInsertPos),
+          -(int64_t)IG->getFactor(),
+          InBounds ? GEPNoWrapFlags::inBounds() : GEPNoWrapFlags::none(),
+          InsertPos->getDebugLoc());
+      ReversePtr->insertBefore(InsertPos);
+      Addr = ReversePtr;
+    }
     auto *VPIG = new VPInterleaveRecipe(IG, Addr, StoredValues,
                                         InsertPos->getMask(), NeedsMaskForGaps, InsertPos->getDebugLoc());
     VPIG->insertBefore(InsertPos);
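As a quick sanity check on the stride passed here (-(int64_t)IG->getFactor()), the offset produced for the factor-2 group in @test_reversed_load2_store2 can be reproduced with plain arithmetic. This is an illustrative standalone calculation (the helper name groupStartOffset is invented here), assuming [[TMP0]] in the test below holds the i64 vscale value computed earlier in that function:

#include <cassert>
#include <cstdint>

// Offset in i32 elements from member 0's address to the first element of the
// reversed wide access, mirroring Stride * (RunTimeVF - 1) for unroll part 0
// with Stride = -2 (factor 2) and RunTimeVF = 4 * vscale.
int64_t groupStartOffset(int64_t VScale) {
  int64_t RunTimeVF = 4 * VScale;
  return -2 * (RunTimeVF - 1); // equivalent to 2 - 8 * vscale
}

int main() {
  assert(groupStartOffset(/*VScale=*/1) == -6);  // 2 - 8
  assert(groupStartOffset(/*VScale=*/2) == -14); // 2 - 16
  return 0;
}

The folded form 2 - 8 * vscale corresponds to the shl/sub pair feeding the getelementptr in the updated CHECK lines of the AArch64 test below.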

llvm/test/Transforms/LoopVectorize/AArch64/sve-interleaved-accesses.ll

Lines changed: 8 additions & 16 deletions
@@ -367,10 +367,8 @@ define void @test_reversed_load2_store2(ptr noalias nocapture readonly %A, ptr n
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <vscale x 4 x i32> [ [[INDUCTION]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = sub i64 1023, [[INDEX]]
 ; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT_ST2:%.*]], ptr [[A:%.*]], i64 [[OFFSET_IDX]], i32 0
-; CHECK-NEXT:    [[TMP5:%.*]] = call i32 @llvm.vscale.i32()
-; CHECK-NEXT:    [[TMP6:%.*]] = shl nuw nsw i32 [[TMP5]], 3
-; CHECK-NEXT:    [[TMP7:%.*]] = sub nsw i32 2, [[TMP6]]
-; CHECK-NEXT:    [[TMP8:%.*]] = sext i32 [[TMP7]] to i64
+; CHECK-NEXT:    [[TMP6:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+; CHECK-NEXT:    [[TMP8:%.*]] = sub nsw i64 2, [[TMP6]]
 ; CHECK-NEXT:    [[TMP9:%.*]] = getelementptr inbounds i32, ptr [[TMP4]], i64 [[TMP8]]
 ; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <vscale x 8 x i32>, ptr [[TMP9]], align 4
 ; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave2.nxv8i32(<vscale x 8 x i32> [[WIDE_VEC]])
@@ -381,10 +379,8 @@ define void @test_reversed_load2_store2(ptr noalias nocapture readonly %A, ptr n
 ; CHECK-NEXT:    [[TMP12:%.*]] = add nsw <vscale x 4 x i32> [[REVERSE]], [[VEC_IND]]
 ; CHECK-NEXT:    [[TMP13:%.*]] = sub nsw <vscale x 4 x i32> [[REVERSE1]], [[VEC_IND]]
 ; CHECK-NEXT:    [[TMP14:%.*]] = getelementptr inbounds [[STRUCT_ST2]], ptr [[B:%.*]], i64 [[OFFSET_IDX]], i32 0
-; CHECK-NEXT:    [[TMP15:%.*]] = call i32 @llvm.vscale.i32()
-; CHECK-NEXT:    [[TMP16:%.*]] = shl nuw nsw i32 [[TMP15]], 3
-; CHECK-NEXT:    [[TMP17:%.*]] = sub nsw i32 2, [[TMP16]]
-; CHECK-NEXT:    [[TMP18:%.*]] = sext i32 [[TMP17]] to i64
+; CHECK-NEXT:    [[TMP15:%.*]] = shl nuw nsw i64 [[TMP0]], 3
+; CHECK-NEXT:    [[TMP18:%.*]] = sub nsw i64 2, [[TMP15]]
 ; CHECK-NEXT:    [[TMP19:%.*]] = getelementptr inbounds i32, ptr [[TMP14]], i64 [[TMP18]]
 ; CHECK-NEXT:    [[REVERSE2:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP12]])
 ; CHECK-NEXT:    [[REVERSE3:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP13]])
@@ -1577,10 +1573,8 @@ define void @interleave_deinterleave_reverse(ptr noalias nocapture readonly %A,
 ; CHECK-NEXT:    [[VEC_IND:%.*]] = phi <vscale x 4 x i32> [ [[INDUCTION]], [[VECTOR_PH]] ], [ [[VEC_IND_NEXT:%.*]], [[VECTOR_BODY]] ]
 ; CHECK-NEXT:    [[OFFSET_IDX:%.*]] = sub i64 1023, [[INDEX]]
 ; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr inbounds [[STRUCT_XYZT:%.*]], ptr [[A:%.*]], i64 [[OFFSET_IDX]], i32 0
-; CHECK-NEXT:    [[TMP6:%.*]] = call i32 @llvm.vscale.i32()
-; CHECK-NEXT:    [[TMP7:%.*]] = shl nuw nsw i32 [[TMP6]], 4
-; CHECK-NEXT:    [[TMP8:%.*]] = sub nsw i32 4, [[TMP7]]
-; CHECK-NEXT:    [[TMP9:%.*]] = sext i32 [[TMP8]] to i64
+; CHECK-NEXT:    [[TMP6:%.*]] = shl nuw nsw i64 [[TMP0]], 4
+; CHECK-NEXT:    [[TMP9:%.*]] = sub nsw i64 4, [[TMP6]]
 ; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds i32, ptr [[TMP5]], i64 [[TMP9]]
 ; CHECK-NEXT:    [[WIDE_VEC:%.*]] = load <vscale x 16 x i32>, ptr [[TMP10]], align 4
 ; CHECK-NEXT:    [[STRIDED_VEC:%.*]] = call { <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32>, <vscale x 4 x i32> } @llvm.vector.deinterleave4.nxv16i32(<vscale x 16 x i32> [[WIDE_VEC]])
@@ -1597,10 +1591,8 @@ define void @interleave_deinterleave_reverse(ptr noalias nocapture readonly %A,
 ; CHECK-NEXT:    [[TMP19:%.*]] = mul nsw <vscale x 4 x i32> [[REVERSE4]], [[VEC_IND]]
 ; CHECK-NEXT:    [[TMP20:%.*]] = shl nuw nsw <vscale x 4 x i32> [[REVERSE5]], [[VEC_IND]]
 ; CHECK-NEXT:    [[TMP21:%.*]] = getelementptr inbounds [[STRUCT_XYZT]], ptr [[B:%.*]], i64 [[OFFSET_IDX]], i32 0
-; CHECK-NEXT:    [[TMP22:%.*]] = call i32 @llvm.vscale.i32()
-; CHECK-NEXT:    [[TMP23:%.*]] = shl nuw nsw i32 [[TMP22]], 4
-; CHECK-NEXT:    [[TMP24:%.*]] = sub nsw i32 4, [[TMP23]]
-; CHECK-NEXT:    [[TMP25:%.*]] = sext i32 [[TMP24]] to i64
+; CHECK-NEXT:    [[TMP22:%.*]] = shl nuw nsw i64 [[TMP0]], 4
+; CHECK-NEXT:    [[TMP25:%.*]] = sub nsw i64 4, [[TMP22]]
 ; CHECK-NEXT:    [[TMP26:%.*]] = getelementptr inbounds i32, ptr [[TMP21]], i64 [[TMP25]]
 ; CHECK-NEXT:    [[REVERSE6:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP17]])
 ; CHECK-NEXT:    [[REVERSE7:%.*]] = call <vscale x 4 x i32> @llvm.vector.reverse.nxv4i32(<vscale x 4 x i32> [[TMP18]])

