Skip to content

Commit 92e8dbd

Browse files
committed
[LV] Bundle partial reductions inside VPExpressionRecipe
This PR bundles partial reductions inside the VPExpressionRecipe class. Depends on llvm#147255 .
1 parent 1a5f4e4 commit 92e8dbd

16 files changed

+617
-554
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,8 @@ class TargetTransformInfo {
223223
/// Get the kind of extension that an instruction represents.
224224
LLVM_ABI static PartialReductionExtendKind
225225
getPartialReductionExtendKind(Instruction *I);
226+
LLVM_ABI static PartialReductionExtendKind
227+
getPartialReductionExtendKind(Instruction::CastOps CastOpc);
226228

227229
/// Construct a TTI object using a type implementing the \c Concept
228230
/// API below.

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1001,13 +1001,24 @@ InstructionCost TargetTransformInfo::getShuffleCost(
10011001

10021002
TargetTransformInfo::PartialReductionExtendKind
10031003
TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
1004-
if (isa<SExtInst>(I))
1005-
return PR_SignExtend;
1006-
if (isa<ZExtInst>(I))
1007-
return PR_ZeroExtend;
1004+
if (auto *Cast = dyn_cast<CastInst>(I))
1005+
return getPartialReductionExtendKind(Cast->getOpcode());
10081006
return PR_None;
10091007
}
10101008

1009+
TargetTransformInfo::PartialReductionExtendKind
1010+
TargetTransformInfo::getPartialReductionExtendKind(
1011+
Instruction::CastOps CastOpc) {
1012+
switch (CastOpc) {
1013+
case Instruction::CastOps::ZExt:
1014+
return PR_ZeroExtend;
1015+
case Instruction::CastOps::SExt:
1016+
return PR_SignExtend;
1017+
default:
1018+
return PR_None;
1019+
}
1020+
}
1021+
10111022
TTI::CastContextHint
10121023
TargetTransformInfo::getCastContextHint(const Instruction *I) {
10131024
if (!I)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5294,7 +5294,7 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
52945294
EVT ResVT = TLI->getValueType(DL, ResTy);
52955295

52965296
if (Opcode == Instruction::Add && VecVT.isSimple() && ResVT.isSimple() &&
5297-
VecVT.getSizeInBits() >= 64) {
5297+
VecVT.isFixedLengthVector() && VecVT.getSizeInBits() >= 64) {
52985298
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);
52995299

53005300
// The legal cases are:

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2470,7 +2470,8 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
24702470

24712471
static inline bool classof(const VPRecipeBase *R) {
24722472
return R->getVPDefID() == VPRecipeBase::VPReductionSC ||
2473-
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC;
2473+
R->getVPDefID() == VPRecipeBase::VPReductionEVLSC ||
2474+
R->getVPDefID() == VPRecipeBase::VPPartialReductionSC;
24742475
}
24752476

24762477
static inline bool classof(const VPUser *U) {
@@ -2532,7 +2533,10 @@ class VPPartialReductionRecipe : public VPReductionRecipe {
25322533
Opcode(Opcode), VFScaleFactor(ScaleFactor) {
25332534
[[maybe_unused]] auto *AccumulatorRecipe =
25342535
getChainOp()->getDefiningRecipe();
2535-
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
2536+
// When cloning as part of a VPExpressionRecipe, the chain op could have
2537+
// been removed from the plan and so doesn't have a defining recipe.
2538+
assert((!AccumulatorRecipe ||
2539+
isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
25362540
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
25372541
"Unexpected operand order for partial reduction recipe");
25382542
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@ bool VPRecipeBase::mayHaveSideEffects() const {
164164
return cast<VPWidenIntrinsicRecipe>(this)->mayHaveSideEffects();
165165
case VPBlendSC:
166166
case VPReductionEVLSC:
167+
case VPPartialReductionSC:
167168
case VPReductionSC:
168169
case VPScalarIVStepsSC:
169170
case VPVectorPointerSC:
@@ -2678,6 +2679,23 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
26782679
case ExpressionTypes::ExtNegatedMulAccReduction:
26792680
case ExpressionTypes::ExtMulAccReduction: {
26802681
bool Negated = ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction;
2682+
if (isa<VPPartialReductionRecipe>(ExpressionRecipes.back())) {
2683+
auto *Ext0R = cast<VPWidenCastRecipe>(ExpressionRecipes[0]);
2684+
auto *Ext1R = cast<VPWidenCastRecipe>(ExpressionRecipes[1]);
2685+
auto *Mul = cast<VPWidenRecipe>(ExpressionRecipes[2]);
2686+
unsigned Opcode =
2687+
ExpressionType == ExpressionTypes::ExtNegatedMulAccReduction
2688+
? Instruction::Sub
2689+
: Instruction::Add;
2690+
return Ctx.TTI.getPartialReductionCost(
2691+
Opcode, Ctx.Types.inferScalarType(getOperand(0)),
2692+
Ctx.Types.inferScalarType(getOperand(1)), RedTy, VF,
2693+
TargetTransformInfo::getPartialReductionExtendKind(
2694+
Ext0R->getOpcode()),
2695+
TargetTransformInfo::getPartialReductionExtendKind(
2696+
Ext1R->getOpcode()),
2697+
Mul->getOpcode(), Ctx.CostKind);
2698+
}
26812699
return Ctx.TTI.getMulAccReductionCost(
26822700
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
26832701
Instruction::ZExt,
@@ -2710,6 +2728,7 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
27102728
O << " = ";
27112729
auto *Red = cast<VPReductionRecipe>(ExpressionRecipes.back());
27122730
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
2731+
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
27132732

27142733
switch (ExpressionType) {
27152734
case ExpressionTypes::ExtendedReduction: {
@@ -2732,6 +2751,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
27322751
case ExpressionTypes::ExtNegatedMulAccReduction: {
27332752
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
27342753
O << " + ";
2754+
if (IsPartialReduction)
2755+
O << "partial.";
27352756
O << "reduce."
27362757
<< Instruction::getOpcodeName(
27372758
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))
@@ -2758,6 +2779,8 @@ void VPExpressionRecipe::print(raw_ostream &O, const Twine &Indent,
27582779
case ExpressionTypes::ExtMulAccReduction: {
27592780
getOperand(getNumOperands() - 1)->printAsOperand(O, SlotTracker);
27602781
O << " + ";
2782+
if (IsPartialReduction)
2783+
O << "partial.";
27612784
O << "reduce."
27622785
<< Instruction::getOpcodeName(
27632786
RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind()))

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2899,6 +2899,7 @@ static VPExpressionRecipe *
28992899
tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
29002900
VPCostContext &Ctx, VFRange &Range) {
29012901
using namespace VPlanPatternMatch;
2902+
bool IsPartialReduction = isa<VPPartialReductionRecipe>(Red);
29022903

29032904
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
29042905
if (Opcode != Instruction::Add)
@@ -2955,12 +2956,14 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
29552956

29562957
// Match reduce.add(mul(ext, ext)).
29572958
if (RecipeA && RecipeB &&
2958-
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B) &&
2959+
(RecipeA->getOpcode() == RecipeB->getOpcode() || A == B ||
2960+
IsPartialReduction) &&
29592961
match(RecipeA, m_ZExtOrSExt(m_VPValue())) &&
29602962
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
2961-
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
2962-
Instruction::CastOps::ZExt,
2963-
MulR, RecipeA, RecipeB, nullptr, Sub)) {
2963+
(IsPartialReduction ||
2964+
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
2965+
Instruction::CastOps::ZExt,
2966+
MulR, RecipeA, RecipeB, nullptr, Sub))) {
29642967
if (Sub)
29652968
return new VPExpressionRecipe(
29662969
RecipeA, RecipeB, MulR,

0 commit comments

Comments
 (0)