Skip to content

Commit d544a89

Browse files
author
Sjoerd Meijer
committed
[LoopFlatten] Update MemorySSA state
I would like to move LoopFlatten from LoopPass Manager LPM2 to LPM1 (D116612), but that is a LPM that is using MemorySSA and so LoopFlatten needs to preserve MemorySSA and this adds that. More specifically, LoopFlatten restructures the CFG and with this change the MSSA state is updated accordingly, where we also update the DomTree. LoopFlatten doesn't rewrite/optimise/delete load or store instructions, so I have not added any MSSA updates for that. Differential Revision: https://reviews.llvm.org/D116660
1 parent 93e8cd2 commit d544a89

File tree

1 file changed

+55
-23
lines changed

1 file changed

+55
-23
lines changed

llvm/lib/Transforms/Scalar/LoopFlatten.cpp

Lines changed: 55 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
#include "llvm/ADT/Statistic.h"
3232
#include "llvm/Analysis/AssumptionCache.h"
3333
#include "llvm/Analysis/LoopInfo.h"
34+
#include "llvm/Analysis/MemorySSAUpdater.h"
3435
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
3536
#include "llvm/Analysis/ScalarEvolution.h"
3637
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -70,8 +71,7 @@ static cl::opt<bool>
7071
"trip counts will never overflow"));
7172

7273
static cl::opt<bool>
73-
WidenIV("loop-flatten-widen-iv", cl::Hidden,
74-
cl::init(true),
74+
WidenIV("loop-flatten-widen-iv", cl::Hidden, cl::init(true),
7575
cl::desc("Widen the loop induction variables, if possible, so "
7676
"overflow checks won't reject flattening"));
7777

@@ -100,7 +100,7 @@ struct FlattenInfo {
100100
PHINode *NarrowInnerInductionPHI = nullptr;
101101
PHINode *NarrowOuterInductionPHI = nullptr;
102102

103-
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
103+
FlattenInfo(Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};
104104

105105
bool isNarrowInductionPhi(PHINode *Phi) {
106106
// This can't be the narrow phi if we haven't widened the IV first.
@@ -207,7 +207,7 @@ static bool findLoopComponents(
207207
// nothing obvious in the surrounding code when handles the overflow case.
208208
// FIXME: audit code to establish whether there's a latent bug here.
209209
const SCEV *SCEVTripCount =
210-
SE->getTripCountFromExitCount(BackedgeTakenCount, false);
210+
SE->getTripCountFromExitCount(BackedgeTakenCount, false);
211211
const SCEV *SCEVRHS = SE->getSCEV(RHS);
212212
if (SCEVRHS == SCEVTripCount)
213213
return setLoopComponents(RHS, TripCount, Increment, IterationInstructions);
@@ -611,7 +611,8 @@ static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
611611

612612
static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
613613
ScalarEvolution *SE, AssumptionCache *AC,
614-
const TargetTransformInfo *TTI, LPMUpdater *U) {
614+
const TargetTransformInfo *TTI, LPMUpdater *U,
615+
MemorySSAUpdater *MSSAU) {
615616
Function *F = FI.OuterLoop->getHeader()->getParent();
616617
LLVM_DEBUG(dbgs() << "Checks all passed, doing the transformation\n");
617618
{
@@ -647,7 +648,11 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
647648
BasicBlock *InnerExitingBlock = FI.InnerLoop->getExitingBlock();
648649
InnerExitingBlock->getTerminator()->eraseFromParent();
649650
BranchInst::Create(InnerExitBlock, InnerExitingBlock);
651+
652+
// Update the DomTree and MemorySSA.
650653
DT->deleteEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
654+
if (MSSAU)
655+
MSSAU->removeEdge(InnerExitingBlock, FI.InnerLoop->getHeader());
651656

652657
// Replace all uses of the polynomial calculated from the two induction
653658
// variables with the one new one.
@@ -658,8 +663,8 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
658663
OuterValue = Builder.CreateTrunc(FI.OuterInductionPHI, V->getType(),
659664
"flatten.trunciv");
660665

661-
LLVM_DEBUG(dbgs() << "Replacing: "; V->dump();
662-
dbgs() << "with: "; OuterValue->dump());
666+
LLVM_DEBUG(dbgs() << "Replacing: "; V->dump(); dbgs() << "with: ";
667+
OuterValue->dump());
663668
V->replaceAllUsesWith(OuterValue);
664669
}
665670

@@ -698,7 +703,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
698703
// (OuterTripCount * InnerTripCount) as the new trip count is safe.
699704
if (InnerType != OuterType ||
700705
InnerType->getScalarSizeInBits() >= MaxLegalSize ||
701-
MaxLegalType->getScalarSizeInBits() < InnerType->getScalarSizeInBits() * 2) {
706+
MaxLegalType->getScalarSizeInBits() <
707+
InnerType->getScalarSizeInBits() * 2) {
702708
LLVM_DEBUG(dbgs() << "Can't widen the IV\n");
703709
return false;
704710
}
@@ -708,10 +714,10 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
708714
unsigned ElimExt = 0;
709715
unsigned Widened = 0;
710716

711-
auto CreateWideIV = [&] (WideIVInfo WideIV, bool &Deleted) -> bool {
712-
PHINode *WidePhi = createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts,
713-
ElimExt, Widened, true /* HasGuards */,
714-
true /* UsePostIncrementRanges */);
717+
auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool {
718+
PHINode *WidePhi =
719+
createWideIV(WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened,
720+
true /* HasGuards */, true /* UsePostIncrementRanges */);
715721
if (!WidePhi)
716722
return false;
717723
LLVM_DEBUG(dbgs() << "Created wide phi: "; WidePhi->dump());
@@ -721,14 +727,14 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
721727
};
722728

723729
bool Deleted;
724-
if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false }, Deleted))
730+
if (!CreateWideIV({FI.InnerInductionPHI, MaxLegalType, false}, Deleted))
725731
return false;
726732
// Add the narrow phi to list, so that it will be adjusted later when the
727733
// the transformation is performed.
728734
if (!Deleted)
729735
FI.InnerPHIsToTransform.insert(FI.InnerInductionPHI);
730736

731-
if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false }, Deleted))
737+
if (!CreateWideIV({FI.OuterInductionPHI, MaxLegalType, false}, Deleted))
732738
return false;
733739

734740
assert(Widened && "Widened IV expected");
@@ -744,7 +750,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
744750

745751
static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
746752
ScalarEvolution *SE, AssumptionCache *AC,
747-
const TargetTransformInfo *TTI, LPMUpdater *U) {
753+
const TargetTransformInfo *TTI, LPMUpdater *U,
754+
MemorySSAUpdater *MSSAU) {
748755
LLVM_DEBUG(
749756
dbgs() << "Loop flattening running on outer loop "
750757
<< FI.OuterLoop->getHeader()->getName() << " and inner loop "
@@ -773,7 +780,7 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
773780

774781
// If we have widened and can perform the transformation, do that here.
775782
if (CanFlatten)
776-
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U);
783+
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
777784

778785
// Otherwise, if we haven't widened the IV, check if the new iteration
779786
// variable might overflow. In this case, we need to version the loop, and
@@ -791,18 +798,19 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
791798
}
792799

793800
LLVM_DEBUG(dbgs() << "Multiply cannot overflow, modifying loop in-place\n");
794-
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U);
801+
return DoFlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
795802
}
796803

797804
bool Flatten(LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
798-
AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U) {
805+
AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U,
806+
MemorySSAUpdater *MSSAU) {
799807
bool Changed = false;
800808
for (Loop *InnerLoop : LN.getLoops()) {
801809
auto *OuterLoop = InnerLoop->getParentLoop();
802810
if (!OuterLoop)
803811
continue;
804812
FlattenInfo FI(OuterLoop, InnerLoop);
805-
Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U);
813+
Changed |= FlattenLoopPair(FI, DT, LI, SE, AC, TTI, U, MSSAU);
806814
}
807815
return Changed;
808816
}
@@ -813,16 +821,30 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
813821

814822
bool Changed = false;
815823

824+
Optional<MemorySSAUpdater> MSSAU;
825+
if (AR.MSSA) {
826+
MSSAU = MemorySSAUpdater(AR.MSSA);
827+
if (VerifyMemorySSA)
828+
AR.MSSA->verifyMemorySSA();
829+
}
830+
816831
// The loop flattening pass requires loops to be
817832
// in simplified form, and also needs LCSSA. Running
818833
// this pass will simplify all loops that contain inner loops,
819834
// regardless of whether anything ends up being flattened.
820-
Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U);
835+
Changed |= Flatten(LN, &AR.DT, &AR.LI, &AR.SE, &AR.AC, &AR.TTI, &U,
836+
MSSAU.hasValue() ? MSSAU.getPointer() : nullptr);
821837

822838
if (!Changed)
823839
return PreservedAnalyses::all();
824840

825-
return getLoopPassPreservedAnalyses();
841+
if (AR.MSSA && VerifyMemorySSA)
842+
AR.MSSA->verifyMemorySSA();
843+
844+
auto PA = getLoopPassPreservedAnalyses();
845+
if (AR.MSSA)
846+
PA.preserve<MemorySSAAnalysis>();
847+
return PA;
826848
}
827849

828850
namespace {
@@ -842,6 +864,7 @@ class LoopFlattenLegacyPass : public FunctionPass {
842864
AU.addPreserved<TargetTransformInfoWrapperPass>();
843865
AU.addRequired<AssumptionCacheTracker>();
844866
AU.addPreserved<AssumptionCacheTracker>();
867+
AU.addPreserved<MemorySSAWrapperPass>();
845868
}
846869
};
847870
} // namespace
@@ -854,7 +877,9 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
854877
INITIALIZE_PASS_END(LoopFlattenLegacyPass, "loop-flatten", "Flattens loops",
855878
false, false)
856879

857-
FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass(); }
880+
FunctionPass *llvm::createLoopFlattenPass() {
881+
return new LoopFlattenLegacyPass();
882+
}
858883

859884
bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
860885
ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE();
@@ -864,10 +889,17 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
864889
auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
865890
auto *TTI = &TTIP.getTTI(F);
866891
auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
892+
auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>();
893+
894+
Optional<MemorySSAUpdater> MSSAU;
895+
if (MSSA)
896+
MSSAU = MemorySSAUpdater(&MSSA->getMSSA());
897+
867898
bool Changed = false;
868899
for (Loop *L : *LI) {
869900
auto LN = LoopNest::getLoopNest(*L, *SE);
870-
Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr);
901+
Changed |= Flatten(*LN, DT, LI, SE, AC, TTI, nullptr,
902+
MSSAU.hasValue() ? MSSAU.getPointer() : nullptr);
871903
}
872904
return Changed;
873905
}

0 commit comments

Comments
 (0)