31
31
#include " llvm/ADT/Statistic.h"
32
32
#include " llvm/Analysis/AssumptionCache.h"
33
33
#include " llvm/Analysis/LoopInfo.h"
34
+ #include " llvm/Analysis/MemorySSAUpdater.h"
34
35
#include " llvm/Analysis/OptimizationRemarkEmitter.h"
35
36
#include " llvm/Analysis/ScalarEvolution.h"
36
37
#include " llvm/Analysis/TargetTransformInfo.h"
@@ -70,8 +71,7 @@ static cl::opt<bool>
70
71
" trip counts will never overflow" ));
71
72
72
73
static cl::opt<bool >
73
- WidenIV (" loop-flatten-widen-iv" , cl::Hidden,
74
- cl::init (true ),
74
+ WidenIV (" loop-flatten-widen-iv" , cl::Hidden, cl::init(true ),
75
75
cl::desc(" Widen the loop induction variables, if possible, so "
76
76
" overflow checks won't reject flattening" ));
77
77
@@ -100,7 +100,7 @@ struct FlattenInfo {
100
100
PHINode *NarrowInnerInductionPHI = nullptr ;
101
101
PHINode *NarrowOuterInductionPHI = nullptr ;
102
102
103
- FlattenInfo (Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL) {};
103
+ FlattenInfo (Loop *OL, Loop *IL) : OuterLoop(OL), InnerLoop(IL){};
104
104
105
105
bool isNarrowInductionPhi (PHINode *Phi) {
106
106
// This can't be the narrow phi if we haven't widened the IV first.
@@ -207,7 +207,7 @@ static bool findLoopComponents(
207
207
// nothing obvious in the surrounding code when handles the overflow case.
208
208
// FIXME: audit code to establish whether there's a latent bug here.
209
209
const SCEV *SCEVTripCount =
210
- SE->getTripCountFromExitCount (BackedgeTakenCount, false );
210
+ SE->getTripCountFromExitCount (BackedgeTakenCount, false );
211
211
const SCEV *SCEVRHS = SE->getSCEV (RHS);
212
212
if (SCEVRHS == SCEVTripCount)
213
213
return setLoopComponents (RHS, TripCount, Increment, IterationInstructions);
@@ -611,7 +611,8 @@ static bool CanFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
611
611
612
612
static bool DoFlattenLoopPair (FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
613
613
ScalarEvolution *SE, AssumptionCache *AC,
614
- const TargetTransformInfo *TTI, LPMUpdater *U) {
614
+ const TargetTransformInfo *TTI, LPMUpdater *U,
615
+ MemorySSAUpdater *MSSAU) {
615
616
Function *F = FI.OuterLoop ->getHeader ()->getParent ();
616
617
LLVM_DEBUG (dbgs () << " Checks all passed, doing the transformation\n " );
617
618
{
@@ -647,7 +648,11 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
647
648
BasicBlock *InnerExitingBlock = FI.InnerLoop ->getExitingBlock ();
648
649
InnerExitingBlock->getTerminator ()->eraseFromParent ();
649
650
BranchInst::Create (InnerExitBlock, InnerExitingBlock);
651
+
652
+ // Update the DomTree and MemorySSA.
650
653
DT->deleteEdge (InnerExitingBlock, FI.InnerLoop ->getHeader ());
654
+ if (MSSAU)
655
+ MSSAU->removeEdge (InnerExitingBlock, FI.InnerLoop ->getHeader ());
651
656
652
657
// Replace all uses of the polynomial calculated from the two induction
653
658
// variables with the one new one.
@@ -658,8 +663,8 @@ static bool DoFlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
658
663
OuterValue = Builder.CreateTrunc (FI.OuterInductionPHI , V->getType (),
659
664
" flatten.trunciv" );
660
665
661
- LLVM_DEBUG (dbgs () << " Replacing: " ; V->dump ();
662
- dbgs () << " with: " ; OuterValue->dump ());
666
+ LLVM_DEBUG (dbgs () << " Replacing: " ; V->dump (); dbgs () << " with: " ;
667
+ OuterValue->dump ());
663
668
V->replaceAllUsesWith (OuterValue);
664
669
}
665
670
@@ -698,7 +703,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
698
703
// (OuterTripCount * InnerTripCount) as the new trip count is safe.
699
704
if (InnerType != OuterType ||
700
705
InnerType->getScalarSizeInBits () >= MaxLegalSize ||
701
- MaxLegalType->getScalarSizeInBits () < InnerType->getScalarSizeInBits () * 2 ) {
706
+ MaxLegalType->getScalarSizeInBits () <
707
+ InnerType->getScalarSizeInBits () * 2 ) {
702
708
LLVM_DEBUG (dbgs () << " Can't widen the IV\n " );
703
709
return false ;
704
710
}
@@ -708,10 +714,10 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
708
714
unsigned ElimExt = 0 ;
709
715
unsigned Widened = 0 ;
710
716
711
- auto CreateWideIV = [&] (WideIVInfo WideIV, bool &Deleted) -> bool {
712
- PHINode *WidePhi = createWideIV (WideIV, LI, SE, Rewriter, DT, DeadInsts,
713
- ElimExt, Widened, true /* HasGuards */ ,
714
- true /* UsePostIncrementRanges */ );
717
+ auto CreateWideIV = [&](WideIVInfo WideIV, bool &Deleted) -> bool {
718
+ PHINode *WidePhi =
719
+ createWideIV (WideIV, LI, SE, Rewriter, DT, DeadInsts, ElimExt, Widened,
720
+ true /* HasGuards */ , true /* UsePostIncrementRanges */ );
715
721
if (!WidePhi)
716
722
return false ;
717
723
LLVM_DEBUG (dbgs () << " Created wide phi: " ; WidePhi->dump ());
@@ -721,14 +727,14 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
721
727
};
722
728
723
729
bool Deleted;
724
- if (!CreateWideIV ({FI.InnerInductionPHI , MaxLegalType, false }, Deleted))
730
+ if (!CreateWideIV ({FI.InnerInductionPHI , MaxLegalType, false }, Deleted))
725
731
return false ;
726
732
// Add the narrow phi to list, so that it will be adjusted later when the
727
733
// the transformation is performed.
728
734
if (!Deleted)
729
735
FI.InnerPHIsToTransform .insert (FI.InnerInductionPHI );
730
736
731
- if (!CreateWideIV ({FI.OuterInductionPHI , MaxLegalType, false }, Deleted))
737
+ if (!CreateWideIV ({FI.OuterInductionPHI , MaxLegalType, false }, Deleted))
732
738
return false ;
733
739
734
740
assert (Widened && " Widened IV expected" );
@@ -744,7 +750,8 @@ static bool CanWidenIV(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
744
750
745
751
static bool FlattenLoopPair (FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
746
752
ScalarEvolution *SE, AssumptionCache *AC,
747
- const TargetTransformInfo *TTI, LPMUpdater *U) {
753
+ const TargetTransformInfo *TTI, LPMUpdater *U,
754
+ MemorySSAUpdater *MSSAU) {
748
755
LLVM_DEBUG (
749
756
dbgs () << " Loop flattening running on outer loop "
750
757
<< FI.OuterLoop ->getHeader ()->getName () << " and inner loop "
@@ -773,7 +780,7 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
773
780
774
781
// If we have widened and can perform the transformation, do that here.
775
782
if (CanFlatten)
776
- return DoFlattenLoopPair (FI, DT, LI, SE, AC, TTI, U);
783
+ return DoFlattenLoopPair (FI, DT, LI, SE, AC, TTI, U, MSSAU );
777
784
778
785
// Otherwise, if we haven't widened the IV, check if the new iteration
779
786
// variable might overflow. In this case, we need to version the loop, and
@@ -791,18 +798,19 @@ static bool FlattenLoopPair(FlattenInfo &FI, DominatorTree *DT, LoopInfo *LI,
791
798
}
792
799
793
800
LLVM_DEBUG (dbgs () << " Multiply cannot overflow, modifying loop in-place\n " );
794
- return DoFlattenLoopPair (FI, DT, LI, SE, AC, TTI, U);
801
+ return DoFlattenLoopPair (FI, DT, LI, SE, AC, TTI, U, MSSAU );
795
802
}
796
803
797
804
bool Flatten (LoopNest &LN, DominatorTree *DT, LoopInfo *LI, ScalarEvolution *SE,
798
- AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U) {
805
+ AssumptionCache *AC, TargetTransformInfo *TTI, LPMUpdater *U,
806
+ MemorySSAUpdater *MSSAU) {
799
807
bool Changed = false ;
800
808
for (Loop *InnerLoop : LN.getLoops ()) {
801
809
auto *OuterLoop = InnerLoop->getParentLoop ();
802
810
if (!OuterLoop)
803
811
continue ;
804
812
FlattenInfo FI (OuterLoop, InnerLoop);
805
- Changed |= FlattenLoopPair (FI, DT, LI, SE, AC, TTI, U);
813
+ Changed |= FlattenLoopPair (FI, DT, LI, SE, AC, TTI, U, MSSAU );
806
814
}
807
815
return Changed;
808
816
}
@@ -813,16 +821,30 @@ PreservedAnalyses LoopFlattenPass::run(LoopNest &LN, LoopAnalysisManager &LAM,
813
821
814
822
bool Changed = false ;
815
823
824
+ Optional<MemorySSAUpdater> MSSAU;
825
+ if (AR.MSSA ) {
826
+ MSSAU = MemorySSAUpdater (AR.MSSA );
827
+ if (VerifyMemorySSA)
828
+ AR.MSSA ->verifyMemorySSA ();
829
+ }
830
+
816
831
// The loop flattening pass requires loops to be
817
832
// in simplified form, and also needs LCSSA. Running
818
833
// this pass will simplify all loops that contain inner loops,
819
834
// regardless of whether anything ends up being flattened.
820
- Changed |= Flatten (LN, &AR.DT , &AR.LI , &AR.SE , &AR.AC , &AR.TTI , &U);
835
+ Changed |= Flatten (LN, &AR.DT , &AR.LI , &AR.SE , &AR.AC , &AR.TTI , &U,
836
+ MSSAU.hasValue () ? MSSAU.getPointer () : nullptr );
821
837
822
838
if (!Changed)
823
839
return PreservedAnalyses::all ();
824
840
825
- return getLoopPassPreservedAnalyses ();
841
+ if (AR.MSSA && VerifyMemorySSA)
842
+ AR.MSSA ->verifyMemorySSA ();
843
+
844
+ auto PA = getLoopPassPreservedAnalyses ();
845
+ if (AR.MSSA )
846
+ PA.preserve <MemorySSAAnalysis>();
847
+ return PA;
826
848
}
827
849
828
850
namespace {
@@ -842,6 +864,7 @@ class LoopFlattenLegacyPass : public FunctionPass {
842
864
AU.addPreserved <TargetTransformInfoWrapperPass>();
843
865
AU.addRequired <AssumptionCacheTracker>();
844
866
AU.addPreserved <AssumptionCacheTracker>();
867
+ AU.addPreserved <MemorySSAWrapperPass>();
845
868
}
846
869
};
847
870
} // namespace
@@ -854,7 +877,9 @@ INITIALIZE_PASS_DEPENDENCY(AssumptionCacheTracker)
854
877
INITIALIZE_PASS_END(LoopFlattenLegacyPass, " loop-flatten" , " Flattens loops" ,
855
878
false , false )
856
879
857
- FunctionPass *llvm::createLoopFlattenPass() { return new LoopFlattenLegacyPass (); }
880
+ FunctionPass *llvm::createLoopFlattenPass() {
881
+ return new LoopFlattenLegacyPass ();
882
+ }
858
883
859
884
bool LoopFlattenLegacyPass::runOnFunction (Function &F) {
860
885
ScalarEvolution *SE = &getAnalysis<ScalarEvolutionWrapperPass>().getSE ();
@@ -864,10 +889,17 @@ bool LoopFlattenLegacyPass::runOnFunction(Function &F) {
864
889
auto &TTIP = getAnalysis<TargetTransformInfoWrapperPass>();
865
890
auto *TTI = &TTIP.getTTI (F);
866
891
auto *AC = &getAnalysis<AssumptionCacheTracker>().getAssumptionCache (F);
892
+ auto *MSSA = getAnalysisIfAvailable<MemorySSAWrapperPass>();
893
+
894
+ Optional<MemorySSAUpdater> MSSAU;
895
+ if (MSSA)
896
+ MSSAU = MemorySSAUpdater (&MSSA->getMSSA ());
897
+
867
898
bool Changed = false ;
868
899
for (Loop *L : *LI) {
869
900
auto LN = LoopNest::getLoopNest (*L, *SE);
870
- Changed |= Flatten (*LN, DT, LI, SE, AC, TTI, nullptr );
901
+ Changed |= Flatten (*LN, DT, LI, SE, AC, TTI, nullptr ,
902
+ MSSAU.hasValue () ? MSSAU.getPointer () : nullptr );
871
903
}
872
904
return Changed;
873
905
}
0 commit comments