Skip to content

Commit 9b13dfd

Browse files
authored
[LV] Use vscale for tuning to improve branch weight estimates (#144733)
In addBranchWeightToMiddleTerminator we attempt to add branch weights to the middle block terminator. We pessimistically assume vscale=1, whereas we can improve the estimate by using the value of vscale used for tuning.
1 parent 15ab4bb commit 9b13dfd

File tree

4 files changed

+12
-6
lines changed

4 files changed

+12
-6
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7327,9 +7327,11 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
73277327
OrigLoop->getHeader()->getContext());
73287328
VPlanTransforms::runPass(VPlanTransforms::replicateByVF, BestVPlan, BestVF);
73297329
VPlanTransforms::runPass(VPlanTransforms::materializeBroadcasts, BestVPlan);
7330-
if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator()))
7330+
if (hasBranchWeightMD(*OrigLoop->getLoopLatch()->getTerminator())) {
7331+
std::optional<unsigned> VScale = CM.getVScaleForTuning();
73317332
VPlanTransforms::runPass(VPlanTransforms::addBranchWeightToMiddleTerminator,
7332-
BestVPlan, BestVF);
7333+
BestVPlan, BestVF, VScale);
7334+
}
73337335
VPlanTransforms::optimizeForVFAndUF(BestVPlan, BestVF, BestUF, PSE);
73347336
VPlanTransforms::simplifyRecipes(BestVPlan, *Legal->getWidestInductionType());
73357337
VPlanTransforms::narrowInterleaveGroups(

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3330,8 +3330,8 @@ void VPlanTransforms::narrowInterleaveGroups(VPlan &Plan, ElementCount VF,
33303330

33313331
/// Add branch weight metadata, if the \p Plan's middle block is terminated by a
33323332
/// BranchOnCond recipe.
3333-
void VPlanTransforms::addBranchWeightToMiddleTerminator(VPlan &Plan,
3334-
ElementCount VF) {
3333+
void VPlanTransforms::addBranchWeightToMiddleTerminator(
3334+
VPlan &Plan, ElementCount VF, std::optional<unsigned> VScaleForTuning) {
33353335
VPBasicBlock *MiddleVPBB = Plan.getMiddleBlock();
33363336
auto *MiddleTerm =
33373337
dyn_cast_or_null<VPInstruction>(MiddleVPBB->getTerminator());
@@ -3343,6 +3343,8 @@ void VPlanTransforms::addBranchWeightToMiddleTerminator(VPlan &Plan,
33433343
"must have a BranchOnCond");
33443344
// Assume that `TripCount % VectorStep ` is equally distributed.
33453345
unsigned VectorStep = Plan.getUF() * VF.getKnownMinValue();
3346+
if (VF.isScalable() && VScaleForTuning.has_value())
3347+
VectorStep *= *VScaleForTuning;
33463348
assert(VectorStep > 0 && "trip count should not be zero");
33473349
MDBuilder MDB(Plan.getScalarHeader()->getIRBasicBlock()->getContext());
33483350
MDNode *BranchWeights =

llvm/lib/Transforms/Vectorize/VPlanTransforms.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,9 @@ struct VPlanTransforms {
238238

239239
/// Add branch weight metadata, if the \p Plan's middle block is terminated by
240240
/// a BranchOnCond recipe.
241-
static void addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF);
241+
static void
242+
addBranchWeightToMiddleTerminator(VPlan &Plan, ElementCount VF,
243+
std::optional<unsigned> VScaleForTuning);
242244
};
243245

244246
} // namespace llvm

llvm/test/Transforms/LoopVectorize/AArch64/check-prof-info.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ for.cond.cleanup: ; preds = %for.body
9292
; CHECK-V1-IC1: [[LOOP1]] = distinct !{[[LOOP1]], [[META2:![0-9]+]], [[META3:![0-9]+]]}
9393
; CHECK-V1-IC1: [[META2]] = !{!"llvm.loop.isvectorized", i32 1}
9494
; CHECK-V1-IC1: [[META3]] = !{!"llvm.loop.unroll.runtime.disable"}
95-
; CHECK-V1-IC1: [[PROF4]] = !{!"branch_weights", i32 1, i32 3}
95+
; CHECK-V1-IC1: [[PROF4]] = !{!"branch_weights", i32 1, i32 7}
9696
; CHECK-V1-IC1: [[PROF5]] = !{!"branch_weights", i32 0, i32 0}
9797
; CHECK-V1-IC1: [[LOOP6]] = distinct !{[[LOOP6]], [[META3]], [[META2]]}
9898
;.

0 commit comments

Comments
 (0)