Skip to content

Commit 1a5f4e4

Browse files
committed
[LV] Bundle sub reductions into VPExpressionRecipe
This PR bundles sub reductions into the VPExpressionRecipe class and adjusts the cost functions to take the negation into account.
1 parent 39f3dab commit 1a5f4e4

File tree

14 files changed

+236
-27
lines changed

14 files changed

+236
-27
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1645,8 +1645,10 @@ class TargetTransformInfo {
16451645
/// extensions. This is the cost of as:
16461646
/// ResTy vecreduce.add(mul (A, B)).
16471647
/// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
1648+
/// The multiply can optionally be negated, which signifies that it is a sub
1649+
/// reduction.
16481650
LLVM_ABI InstructionCost getMulAccReductionCost(
1649-
bool IsUnsigned, Type *ResTy, VectorType *Ty,
1651+
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
16501652
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;
16511653

16521654
/// Calculate the cost of an extended reduction pattern, similar to

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -960,7 +960,7 @@ class TargetTransformInfoImplBase {
960960

961961
virtual InstructionCost
962962
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
963-
TTI::TargetCostKind CostKind) const {
963+
bool Negated, TTI::TargetCostKind CostKind) const {
964964
return 1;
965965
}
966966

llvm/include/llvm/CodeGen/BasicTTIImpl.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3116,7 +3116,10 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
31163116

31173117
InstructionCost
31183118
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
3119+
bool Negated,
31193120
TTI::TargetCostKind CostKind) const override {
3121+
if (Negated)
3122+
return InstructionCost::getInvalid(CostKind);
31203123
// Without any native support, this is equivalent to the cost of
31213124
// vecreduce.add(mul(ext(Ty A), ext(Ty B))) or
31223125
// vecreduce.add(mul(A, B)).

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1274,9 +1274,10 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
12741274
}
12751275

12761276
InstructionCost TargetTransformInfo::getMulAccReductionCost(
1277-
bool IsUnsigned, Type *ResTy, VectorType *Ty,
1277+
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
12781278
TTI::TargetCostKind CostKind) const {
1279-
return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
1279+
return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, Negated,
1280+
CostKind);
12801281
}
12811282

12821283
InstructionCost

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5316,8 +5316,10 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
53165316

53175317
InstructionCost
53185318
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
5319-
VectorType *VecTy,
5319+
VectorType *VecTy, bool Negated,
53205320
TTI::TargetCostKind CostKind) const {
5321+
if (Negated)
5322+
return InstructionCost::getInvalid(CostKind);
53215323
EVT VecVT = TLI->getValueType(DL, VecTy);
53225324
EVT ResVT = TLI->getValueType(DL, ResTy);
53235325

@@ -5332,7 +5334,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
53325334
return LT.first + 2;
53335335
}
53345336

5335-
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
5337+
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, Negated,
5338+
CostKind);
53365339
}
53375340

53385341
InstructionCost

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
447447
TTI::TargetCostKind CostKind) const override;
448448

449449
InstructionCost getMulAccReductionCost(
450-
bool IsUnsigned, Type *ResTy, VectorType *Ty,
450+
bool IsUnsigned, Type *ResTy, VectorType *Ty, bool Negated,
451451
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;
452452

453453
InstructionCost

llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1884,8 +1884,10 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
18841884

18851885
InstructionCost
18861886
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
1887-
VectorType *ValTy,
1887+
VectorType *ValTy, bool Negated,
18881888
TTI::TargetCostKind CostKind) const {
1889+
if (Negated)
1890+
return InstructionCost::getInvalid(CostKind);
18891891
EVT ValVT = TLI->getValueType(DL, ValTy);
18901892
EVT ResVT = TLI->getValueType(DL, ResTy);
18911893

@@ -1906,7 +1908,8 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
19061908
return ST->getMVEVectorCostFactor(CostKind) * LT.first;
19071909
}
19081910

1909-
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
1911+
return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, Negated,
1912+
CostKind);
19101913
}
19111914

19121915
InstructionCost

llvm/lib/Target/ARM/ARMTargetTransformInfo.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
299299
TTI::TargetCostKind CostKind) const override;
300300
InstructionCost
301301
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
302+
bool Negated,
302303
TTI::TargetCostKind CostKind) const override;
303304

304305
InstructionCost

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5538,7 +5538,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
55385538
TTI::CastContextHint::None, CostKind, RedOp);
55395539

55405540
InstructionCost RedCost = TTI.getMulAccReductionCost(
5541-
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
5541+
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
55425542

55435543
if (RedCost.isValid() &&
55445544
RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
@@ -5583,7 +5583,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
55835583
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
55845584

55855585
InstructionCost RedCost = TTI.getMulAccReductionCost(
5586-
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
5586+
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, false, CostKind);
55875587
InstructionCost ExtraExtCost = 0;
55885588
if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
55895589
Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
@@ -5602,7 +5602,7 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
56025602
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);
56035603

56045604
InstructionCost RedCost = TTI.getMulAccReductionCost(
5605-
true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
5605+
true, RdxDesc.getRecurrenceType(), VectorTy, false, CostKind);
56065606

56075607
if (RedCost.isValid() && RedCost < MulCost + BaseCost)
56085608
return I == RetI ? RedCost : 0;

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2757,6 +2757,12 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
27572757
/// vector operands, performing a reduction.add on the result, and adding
27582758
/// the scalar result to a chain.
27592759
MulAccReduction,
2760+
/// Represent an inloop multiply-accumulate reduction, multiplying the
2761+
/// extended vector operands, negating the multiplication, performing a
2762+
/// reduction.add
2763+
/// on the result, and adding
2764+
/// the scalar result to a chain.
2765+
ExtNegatedMulAccReduction,
27602766
};
27612767

27622768
/// Type of the expression.
@@ -2780,6 +2786,11 @@ class VPExpressionRecipe : public VPSingleDefRecipe {
27802786
VPWidenRecipe *Mul, VPReductionRecipe *Red)
27812787
: VPExpressionRecipe(ExpressionTypes::ExtMulAccReduction,
27822788
{Ext0, Ext1, Mul, Red}) {}
2789+
VPExpressionRecipe(VPWidenCastRecipe *Ext0, VPWidenCastRecipe *Ext1,
2790+
VPWidenRecipe *Mul, VPWidenRecipe *Sub,
2791+
VPReductionRecipe *Red)
2792+
: VPExpressionRecipe(ExpressionTypes::ExtNegatedMulAccReduction,
2793+
{Ext0, Ext1, Mul, Sub, Red}) {}
27832794

27842795
~VPExpressionRecipe() override {
27852796
for (auto *R : reverse(ExpressionRecipes))

0 commit comments

Comments
 (0)