Skip to content

Commit 61d3ad9

Browse files
authored
[SCEVPatternMatch] Introduce m_scev_AffineAddRec (llvm#140377)
Introduce m_scev_AffineAddRec to match affine AddRecs, a class_match for SCEVConstant, and demonstrate their utility in LSR and SCEV. While at it, rename m_Specific to m_scev_Specific for clarity.
1 parent c28d6c2 commit 61d3ad9

File tree

3 files changed

+72
-80
lines changed

3 files changed

+72
-80
lines changed

llvm/include/llvm/Analysis/ScalarEvolutionPatternMatch.h

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ template <typename Class> struct class_match {
6161
};
6262

6363
inline class_match<const SCEV> m_SCEV() { return class_match<const SCEV>(); }
64+
inline class_match<const SCEVConstant> m_SCEVConstant() {
65+
return class_match<const SCEVConstant>();
66+
}
6467

6568
template <typename Class> struct bind_ty {
6669
Class *&VR;
@@ -95,7 +98,7 @@ struct specificscev_ty {
9598
};
9699

97100
/// Match if we have a specific specified SCEV.
98-
inline specificscev_ty m_Specific(const SCEV *S) { return S; }
101+
inline specificscev_ty m_scev_Specific(const SCEV *S) { return S; }
99102

100103
struct is_specific_cst {
101104
uint64_t CV;
@@ -192,6 +195,12 @@ inline SCEVBinaryExpr_match<SCEVUDivExpr, Op0_t, Op1_t>
192195
m_scev_UDiv(const Op0_t &Op0, const Op1_t &Op1) {
193196
return m_scev_Binary<SCEVUDivExpr>(Op0, Op1);
194197
}
198+
199+
template <typename Op0_t, typename Op1_t>
200+
inline SCEVBinaryExpr_match<SCEVAddRecExpr, Op0_t, Op1_t>
201+
m_scev_AffineAddRec(const Op0_t &Op0, const Op1_t &Op1) {
202+
return m_scev_Binary<SCEVAddRecExpr>(Op0, Op1);
203+
}
195204
} // namespace SCEVPatternMatch
196205
} // namespace llvm
197206

llvm/lib/Analysis/ScalarEvolution.cpp

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -12480,26 +12480,21 @@ static bool IsKnownPredicateViaAddRecStart(ScalarEvolution &SE,
1248012480
if (!ICmpInst::isRelational(Pred))
1248112481
return false;
1248212482

12483-
const SCEVAddRecExpr *LAR = dyn_cast<SCEVAddRecExpr>(LHS);
12484-
if (!LAR)
12485-
return false;
12486-
const SCEVAddRecExpr *RAR = dyn_cast<SCEVAddRecExpr>(RHS);
12487-
if (!RAR)
12483+
const SCEV *LStart, *RStart, *Step;
12484+
if (!match(LHS, m_scev_AffineAddRec(m_SCEV(LStart), m_SCEV(Step))) ||
12485+
!match(RHS, m_scev_AffineAddRec(m_SCEV(RStart), m_scev_Specific(Step))))
1248812486
return false;
12487+
const SCEVAddRecExpr *LAR = cast<SCEVAddRecExpr>(LHS);
12488+
const SCEVAddRecExpr *RAR = cast<SCEVAddRecExpr>(RHS);
1248912489
if (LAR->getLoop() != RAR->getLoop())
1249012490
return false;
12491-
if (!LAR->isAffine() || !RAR->isAffine())
12492-
return false;
12493-
12494-
if (LAR->getStepRecurrence(SE) != RAR->getStepRecurrence(SE))
12495-
return false;
1249612491

1249712492
SCEV::NoWrapFlags NW = ICmpInst::isSigned(Pred) ?
1249812493
SCEV::FlagNSW : SCEV::FlagNUW;
1249912494
if (!LAR->getNoWrapFlags(NW) || !RAR->getNoWrapFlags(NW))
1250012495
return false;
1250112496

12502-
return SE.isKnownPredicate(Pred, LAR->getStart(), RAR->getStart());
12497+
return SE.isKnownPredicate(Pred, LStart, RStart);
1250312498
}
1250412499

1250512500
/// Is LHS `Pred` RHS true on the virtue of LHS or RHS being a Min or Max
@@ -12716,15 +12711,15 @@ static bool isKnownPredicateExtendIdiom(CmpPredicate Pred, const SCEV *LHS,
1271612711
case ICmpInst::ICMP_SLE: {
1271712712
// If operand >=s 0 then ZExt == SExt. If operand <s 0 then SExt <s ZExt.
1271812713
return match(LHS, m_scev_SExt(m_SCEV(Op))) &&
12719-
match(RHS, m_scev_ZExt(m_Specific(Op)));
12714+
match(RHS, m_scev_ZExt(m_scev_Specific(Op)));
1272012715
}
1272112716
case ICmpInst::ICMP_UGE:
1272212717
std::swap(LHS, RHS);
1272312718
[[fallthrough]];
1272412719
case ICmpInst::ICMP_ULE: {
1272512720
// If operand >=u 0 then ZExt == SExt. If operand <u 0 then ZExt <u SExt.
1272612721
return match(LHS, m_scev_ZExt(m_SCEV(Op))) &&
12727-
match(RHS, m_scev_SExt(m_Specific(Op)));
12722+
match(RHS, m_scev_SExt(m_scev_Specific(Op)));
1272812723
}
1272912724
default:
1273012725
return false;

llvm/lib/Transforms/Scalar/LoopStrengthReduce.cpp

Lines changed: 54 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,11 @@
7777
#include "llvm/Analysis/ScalarEvolution.h"
7878
#include "llvm/Analysis/ScalarEvolutionExpressions.h"
7979
#include "llvm/Analysis/ScalarEvolutionNormalization.h"
80+
#include "llvm/Analysis/ScalarEvolutionPatternMatch.h"
8081
#include "llvm/Analysis/TargetLibraryInfo.h"
8182
#include "llvm/Analysis/TargetTransformInfo.h"
8283
#include "llvm/Analysis/ValueTracking.h"
8384
#include "llvm/BinaryFormat/Dwarf.h"
84-
#include "llvm/Config/llvm-config.h"
8585
#include "llvm/IR/BasicBlock.h"
8686
#include "llvm/IR/Constant.h"
8787
#include "llvm/IR/Constants.h"
@@ -128,6 +128,7 @@
128128
#include <utility>
129129

130130
using namespace llvm;
131+
using namespace SCEVPatternMatch;
131132

132133
#define DEBUG_TYPE "loop-reduce"
133134

@@ -556,16 +557,17 @@ static void DoInitialMatch(const SCEV *S, Loop *L,
556557
}
557558

558559
// Look at addrec operands.
559-
if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S))
560-
if (!AR->getStart()->isZero() && AR->isAffine()) {
561-
DoInitialMatch(AR->getStart(), L, Good, Bad, SE);
562-
DoInitialMatch(SE.getAddRecExpr(SE.getConstant(AR->getType(), 0),
563-
AR->getStepRecurrence(SE),
564-
// FIXME: AR->getNoWrapFlags()
565-
AR->getLoop(), SCEV::FlagAnyWrap),
566-
L, Good, Bad, SE);
567-
return;
568-
}
560+
const SCEV *Start, *Step;
561+
if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step))) &&
562+
!Start->isZero()) {
563+
DoInitialMatch(Start, L, Good, Bad, SE);
564+
DoInitialMatch(SE.getAddRecExpr(SE.getConstant(S->getType(), 0), Step,
565+
// FIXME: AR->getNoWrapFlags()
566+
cast<SCEVAddRecExpr>(S)->getLoop(),
567+
SCEV::FlagAnyWrap),
568+
L, Good, Bad, SE);
569+
return;
570+
}
569571

570572
// Handle a multiplication by -1 (negation) if it didn't fold.
571573
if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S))
@@ -1411,22 +1413,16 @@ void Cost::RateRegister(const Formula &F, const SCEV *Reg,
14111413
unsigned LoopCost = 1;
14121414
if (TTI->isIndexedLoadLegal(TTI->MIM_PostInc, AR->getType()) ||
14131415
TTI->isIndexedStoreLegal(TTI->MIM_PostInc, AR->getType())) {
1414-
1415-
// If the step size matches the base offset, we could use pre-indexed
1416-
// addressing.
1417-
if (AMK == TTI::AMK_PreIndexed && F.BaseOffset.isFixed()) {
1418-
if (auto *Step = dyn_cast<SCEVConstant>(AR->getStepRecurrence(*SE)))
1419-
if (Step->getAPInt() == F.BaseOffset.getFixedValue())
1420-
LoopCost = 0;
1421-
} else if (AMK == TTI::AMK_PostIndexed) {
1422-
const SCEV *LoopStep = AR->getStepRecurrence(*SE);
1423-
if (isa<SCEVConstant>(LoopStep)) {
1424-
const SCEV *LoopStart = AR->getStart();
1425-
if (!isa<SCEVConstant>(LoopStart) &&
1426-
SE->isLoopInvariant(LoopStart, L))
1427-
LoopCost = 0;
1428-
}
1429-
}
1416+
const SCEV *Start;
1417+
const SCEVConstant *Step;
1418+
if (match(AR, m_scev_AffineAddRec(m_SCEV(Start), m_SCEVConstant(Step))))
1419+
// If the step size matches the base offset, we could use pre-indexed
1420+
// addressing.
1421+
if ((AMK == TTI::AMK_PreIndexed && F.BaseOffset.isFixed() &&
1422+
Step->getAPInt() == F.BaseOffset.getFixedValue()) ||
1423+
(AMK == TTI::AMK_PostIndexed && !isa<SCEVConstant>(Start) &&
1424+
SE->isLoopInvariant(Start, L)))
1425+
LoopCost = 0;
14301426
}
14311427
C.AddRecCost += LoopCost;
14321428

@@ -2519,13 +2515,11 @@ ICmpInst *LSRInstance::OptimizeMax(ICmpInst *Cond, IVStrideUse* &CondUse) {
25192515
// Check the relevant induction variable for conformance to
25202516
// the pattern.
25212517
const SCEV *IV = SE.getSCEV(Cond->getOperand(0));
2522-
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(IV);
2523-
if (!AR || !AR->isAffine() ||
2524-
AR->getStart() != One ||
2525-
AR->getStepRecurrence(SE) != One)
2518+
if (!match(IV,
2519+
m_scev_AffineAddRec(m_scev_SpecificInt(1), m_scev_SpecificInt(1))))
25262520
return Cond;
25272521

2528-
assert(AR->getLoop() == L &&
2522+
assert(cast<SCEVAddRecExpr>(IV)->getLoop() == L &&
25292523
"Loop condition operand is an addrec in a different loop!");
25302524

25312525
// Check the right operand of the select, and remember it, as it will
@@ -3320,7 +3314,7 @@ void LSRInstance::CollectChains() {
33203314
void LSRInstance::FinalizeChain(IVChain &Chain) {
33213315
assert(!Chain.Incs.empty() && "empty IV chains are not allowed");
33223316
LLVM_DEBUG(dbgs() << "Final Chain: " << *Chain.Incs[0].UserInst << "\n");
3323-
3317+
33243318
for (const IVInc &Inc : Chain) {
33253319
LLVM_DEBUG(dbgs() << " Inc: " << *Inc.UserInst << "\n");
33263320
auto UseI = find(Inc.UserInst->operands(), Inc.IVOperand);
@@ -3823,26 +3817,27 @@ static const SCEV *CollectSubexprs(const SCEV *S, const SCEVConstant *C,
38233817
Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
38243818
}
38253819
return nullptr;
3826-
} else if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S)) {
3820+
}
3821+
const SCEV *Start, *Step;
3822+
if (match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEV(Step)))) {
38273823
// Split a non-zero base out of an addrec.
3828-
if (AR->getStart()->isZero() || !AR->isAffine())
3824+
if (Start->isZero())
38293825
return S;
38303826

3831-
const SCEV *Remainder = CollectSubexprs(AR->getStart(),
3832-
C, Ops, L, SE, Depth+1);
3827+
const SCEV *Remainder = CollectSubexprs(Start, C, Ops, L, SE, Depth + 1);
38333828
// Split the non-zero AddRec unless it is part of a nested recurrence that
38343829
// does not pertain to this loop.
3835-
if (Remainder && (AR->getLoop() == L || !isa<SCEVAddRecExpr>(Remainder))) {
3830+
if (Remainder && (cast<SCEVAddRecExpr>(S)->getLoop() == L ||
3831+
!isa<SCEVAddRecExpr>(Remainder))) {
38363832
Ops.push_back(C ? SE.getMulExpr(C, Remainder) : Remainder);
38373833
Remainder = nullptr;
38383834
}
3839-
if (Remainder != AR->getStart()) {
3835+
if (Remainder != Start) {
38403836
if (!Remainder)
3841-
Remainder = SE.getConstant(AR->getType(), 0);
3842-
return SE.getAddRecExpr(Remainder,
3843-
AR->getStepRecurrence(SE),
3844-
AR->getLoop(),
3845-
//FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
3837+
Remainder = SE.getConstant(S->getType(), 0);
3838+
return SE.getAddRecExpr(Remainder, Step,
3839+
cast<SCEVAddRecExpr>(S)->getLoop(),
3840+
// FIXME: AR->getNoWrapFlags(SCEV::FlagNW)
38463841
SCEV::FlagAnyWrap);
38473842
}
38483843
} else if (const SCEVMulExpr *Mul = dyn_cast<SCEVMulExpr>(S)) {
@@ -3870,17 +3865,13 @@ static bool mayUsePostIncMode(const TargetTransformInfo &TTI,
38703865
if (LU.Kind != LSRUse::Address ||
38713866
!LU.AccessTy.getType()->isIntOrIntVectorTy())
38723867
return false;
3873-
const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(S);
3874-
if (!AR)
3875-
return false;
3876-
const SCEV *LoopStep = AR->getStepRecurrence(SE);
3877-
if (!isa<SCEVConstant>(LoopStep))
3868+
const SCEV *Start;
3869+
if (!match(S, m_scev_AffineAddRec(m_SCEV(Start), m_SCEVConstant())))
38783870
return false;
38793871
// Check if a post-indexed load/store can be used.
3880-
if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, AR->getType()) ||
3881-
TTI.isIndexedStoreLegal(TTI.MIM_PostInc, AR->getType())) {
3882-
const SCEV *LoopStart = AR->getStart();
3883-
if (!isa<SCEVConstant>(LoopStart) && SE.isLoopInvariant(LoopStart, L))
3872+
if (TTI.isIndexedLoadLegal(TTI.MIM_PostInc, S->getType()) ||
3873+
TTI.isIndexedStoreLegal(TTI.MIM_PostInc, S->getType())) {
3874+
if (!isa<SCEVConstant>(Start) && SE.isLoopInvariant(Start, L))
38843875
return true;
38853876
}
38863877
return false;
@@ -4139,18 +4130,15 @@ void LSRInstance::GenerateConstantOffsetsImpl(
41394130
// base pointer for each iteration of the loop, resulting in no extra add/sub
41404131
// instructions for pointer updating.
41414132
if (AMK == TTI::AMK_PreIndexed && LU.Kind == LSRUse::Address) {
4142-
if (auto *GAR = dyn_cast<SCEVAddRecExpr>(G)) {
4143-
if (auto *StepRec =
4144-
dyn_cast<SCEVConstant>(GAR->getStepRecurrence(SE))) {
4145-
const APInt &StepInt = StepRec->getAPInt();
4146-
int64_t Step = StepInt.isNegative() ?
4147-
StepInt.getSExtValue() : StepInt.getZExtValue();
4148-
4149-
for (Immediate Offset : Worklist) {
4150-
if (Offset.isFixed()) {
4151-
Offset = Immediate::getFixed(Offset.getFixedValue() - Step);
4152-
GenerateOffset(G, Offset);
4153-
}
4133+
const APInt *StepInt;
4134+
if (match(G, m_scev_AffineAddRec(m_SCEV(), m_scev_APInt(StepInt)))) {
4135+
int64_t Step = StepInt->isNegative() ? StepInt->getSExtValue()
4136+
: StepInt->getZExtValue();
4137+
4138+
for (Immediate Offset : Worklist) {
4139+
if (Offset.isFixed()) {
4140+
Offset = Immediate::getFixed(Offset.getFixedValue() - Step);
4141+
GenerateOffset(G, Offset);
41544142
}
41554143
}
41564144
}
@@ -6621,7 +6609,7 @@ struct SCEVDbgValueBuilder {
66216609
if (Op.getOp() != dwarf::DW_OP_LLVM_arg) {
66226610
Op.appendToVector(DestExpr);
66236611
continue;
6624-
}
6612+
}
66256613

66266614
DestExpr.push_back(dwarf::DW_OP_LLVM_arg);
66276615
// `DW_OP_LLVM_arg n` represents the nth LocationOp in this SCEV,

0 commit comments

Comments (0)