diff --git a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h index a101151eed7cc..39fef921a9590 100644 --- a/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h +++ b/llvm/include/llvm/Transforms/Utils/ScalarEvolutionExpander.h @@ -530,6 +530,7 @@ class SCEVExpander : public SCEVVisitor { bool isExpandedAddRecExprPHI(PHINode *PN, Instruction *IncV, const Loop *L); + Value *tryToReuseLCSSAPhi(const SCEVAddRecExpr *S); Value *expandAddRecExprLiterally(const SCEVAddRecExpr *); PHINode *getAddRecExprPHILiterally(const SCEVAddRecExpr *Normalized, const Loop *L, Type *&TruncTy, diff --git a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp index 24fe08d6c3e4e..37df5b2b94b54 100644 --- a/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp +++ b/llvm/lib/Transforms/Utils/ScalarEvolutionExpander.cpp @@ -1223,6 +1223,24 @@ Value *SCEVExpander::expandAddRecExprLiterally(const SCEVAddRecExpr *S) { return Result; } +Value *SCEVExpander::tryToReuseLCSSAPhi(const SCEVAddRecExpr *S) { + const Loop *L = S->getLoop(); + BasicBlock *EB = L->getExitBlock(); + if (!EB || !EB->getSinglePredecessor() || + !SE.DT.dominates(EB, Builder.GetInsertBlock())) + return nullptr; + + for (auto &PN : EB->phis()) { + if (!SE.isSCEVable(PN.getType()) || PN.getType() != S->getType()) + continue; + auto *ExitV = SE.getSCEV(&PN); + if (S == ExitV) + return &PN; + } + + return nullptr; +} + Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { // In canonical mode we compute the addrec as an expression of a canonical IV // using evaluateAtIteration and expand the resulting SCEV expression. This @@ -1262,6 +1280,11 @@ Value *SCEVExpander::visitAddRecExpr(const SCEVAddRecExpr *S) { return V; } + // If S is expanded outside the defining loop, check if there is a + // matching LCSSA phi node for it. + if (Value *V = tryToReuseLCSSAPhi(S)) + return V; + // {X,+,F} --> X + {0,+,F} if (!S->getStart()->isZero()) { if (isa(S->getType())) { diff --git a/llvm/test/Transforms/LoopVectorize/reuse-lcssa-phi-scev-expansion.ll b/llvm/test/Transforms/LoopVectorize/reuse-lcssa-phi-scev-expansion.ll index 2747895f06a7b..ce4270dc4b7fa 100644 --- a/llvm/test/Transforms/LoopVectorize/reuse-lcssa-phi-scev-expansion.ll +++ b/llvm/test/Transforms/LoopVectorize/reuse-lcssa-phi-scev-expansion.ll @@ -18,11 +18,9 @@ define void @reuse_lcssa_phi_for_add_rec1(ptr %head) { ; CHECK-NEXT: [[IV_NEXT]] = add nuw i64 [[IV]], 1 ; CHECK-NEXT: br i1 [[EC_1]], label %[[PH:.*]], label %[[LOOP_1]] ; CHECK: [[PH]]: -; CHECK-NEXT: [[IV_2_LCSSA:%.*]] = phi i32 [ [[IV_2]], %[[LOOP_1]] ] ; CHECK-NEXT: [[IV_LCSSA:%.*]] = phi i64 [ [[IV]], %[[LOOP_1]] ] -; CHECK-NEXT: [[IV_2_NEXT_LCSSA:%.*]] = phi i32 [ [[IV_2_NEXT]], %[[LOOP_1]] ] +; CHECK-NEXT: [[TMP0:%.*]] = phi i32 [ [[IV_2_NEXT]], %[[LOOP_1]] ] ; CHECK-NEXT: [[SRC_2:%.*]] = tail call noalias noundef dereferenceable_or_null(8) ptr @calloc(i64 1, i64 8) -; CHECK-NEXT: [[TMP0:%.*]] = add i32 [[IV_2_LCSSA]], 1 ; CHECK-NEXT: [[SMIN:%.*]] = call i32 @llvm.smin.i32(i32 [[TMP0]], i32 1) ; CHECK-NEXT: [[TMP1:%.*]] = sub i32 [[TMP0]], [[SMIN]] ; CHECK-NEXT: [[TMP2:%.*]] = zext i32 [[TMP1]] to i64