Skip to content

Commit b963701

Browse files
authored
[TTI] Plumb CostKind through getPartialReductionCost (#144953)
Purely for the sake of being idiomatic with other TTI costing routines, no direct motivation beyond that.
1 parent dfb5cad commit b963701

11 files changed

+40
-37
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1332,8 +1332,8 @@ class TargetTransformInfo {
13321332
LLVM_ABI InstructionCost getPartialReductionCost(
13331333
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
13341334
ElementCount VF, PartialReductionExtendKind OpAExtend,
1335-
PartialReductionExtendKind OpBExtend,
1336-
std::optional<unsigned> BinOp = std::nullopt) const;
1335+
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
1336+
TTI::TargetCostKind CostKind) const;
13371337

13381338
/// \return The maximum interleave factor that any transform should try to
13391339
/// perform for this target. This number depends on the level of parallelism

llvm/include/llvm/Analysis/TargetTransformInfoImpl.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -652,12 +652,11 @@ class TargetTransformInfoImplBase {
652652
virtual bool enableWritePrefetching() const { return false; }
653653
virtual bool shouldPrefetchAddressSpace(unsigned AS) const { return !AS; }
654654

655-
virtual InstructionCost
656-
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
657-
Type *AccumType, ElementCount VF,
658-
TTI::PartialReductionExtendKind OpAExtend,
659-
TTI::PartialReductionExtendKind OpBExtend,
660-
std::optional<unsigned> BinOp = std::nullopt) const {
655+
virtual InstructionCost getPartialReductionCost(
656+
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
657+
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
658+
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
659+
TTI::TargetCostKind CostKind) const {
661660
return InstructionCost::getInvalid();
662661
}
663662

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -871,10 +871,11 @@ bool TargetTransformInfo::shouldPrefetchAddressSpace(unsigned AS) const {
871871
InstructionCost TargetTransformInfo::getPartialReductionCost(
872872
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
873873
ElementCount VF, PartialReductionExtendKind OpAExtend,
874-
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp) const {
874+
PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
875+
TTI::TargetCostKind CostKind) const {
875876
return TTIImpl->getPartialReductionCost(Opcode, InputTypeA, InputTypeB,
876877
AccumType, VF, OpAExtend, OpBExtend,
877-
BinOp);
878+
BinOp, CostKind);
878879
}
879880

880881
unsigned TargetTransformInfo::getMaxInterleaveFactor(ElementCount VF) const {

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5395,11 +5395,14 @@ AArch64TTIImpl::getSpliceCost(VectorType *Tp, int Index,
53955395
InstructionCost AArch64TTIImpl::getPartialReductionCost(
53965396
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
53975397
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
5398-
TTI::PartialReductionExtendKind OpBExtend,
5399-
std::optional<unsigned> BinOp) const {
5398+
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
5399+
TTI::TargetCostKind CostKind) const {
54005400
InstructionCost Invalid = InstructionCost::getInvalid();
54015401
InstructionCost Cost(TTI::TCC_Basic);
54025402

5403+
if (CostKind != TTI::TCK_RecipThroughput)
5404+
return Invalid;
5405+
54035406
// Sub opcodes currently only occur in chained cases.
54045407
// Independent partial reduction subtractions are still costed as an add
54055408
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)

llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -382,12 +382,11 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
382382
return BaseT::isLegalNTLoad(DataType, Alignment);
383383
}
384384

385-
InstructionCost
386-
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
387-
Type *AccumType, ElementCount VF,
388-
TTI::PartialReductionExtendKind OpAExtend,
389-
TTI::PartialReductionExtendKind OpBExtend,
390-
std::optional<unsigned> BinOp) const override;
385+
InstructionCost getPartialReductionCost(
386+
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
387+
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
388+
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
389+
TTI::TargetCostKind CostKind) const override;
391390

392391
bool enableOrderedReductions() const override { return true; }
393392

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -297,8 +297,8 @@ RISCVTTIImpl::getPopcntSupport(unsigned TyWidth) const {
297297
InstructionCost RISCVTTIImpl::getPartialReductionCost(
298298
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
299299
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
300-
TTI::PartialReductionExtendKind OpBExtend,
301-
std::optional<unsigned> BinOp) const {
300+
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
301+
TTI::TargetCostKind CostKind) const {
302302

303303
// zve32x is broken for partial_reduce_umla, but let's make sure we
304304
// don't generate them.
@@ -311,9 +311,8 @@ InstructionCost RISCVTTIImpl::getPartialReductionCost(
311311
Type *Tp = VectorType::get(AccumType, VF.divideCoefficientBy(4));
312312
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(Tp);
313313
// Note: Asuming all vqdot* variants are equal cost
314-
// TODO: Thread CostKind through this API
315-
return LT.first * getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second,
316-
TTI::TCK_RecipThroughput);
314+
return LT.first *
315+
getRISCVInstructionCost(RISCV::VQDOT_VV, LT.second, CostKind);
317316
}
318317

319318
bool RISCVTTIImpl::shouldExpandReduction(const IntrinsicInst *II) const {

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.h

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,12 +100,11 @@ class RISCVTTIImpl final : public BasicTTIImplBase<RISCVTTIImpl> {
100100
TargetTransformInfo::PopcntSupportKind
101101
getPopcntSupport(unsigned TyWidth) const override;
102102

103-
InstructionCost
104-
getPartialReductionCost(unsigned Opcode, Type *InputTypeA, Type *InputTypeB,
105-
Type *AccumType, ElementCount VF,
106-
TTI::PartialReductionExtendKind OpAExtend,
107-
TTI::PartialReductionExtendKind OpBExtend,
108-
std::optional<unsigned> BinOp) const override;
103+
InstructionCost getPartialReductionCost(
104+
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
105+
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
106+
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
107+
TTI::TargetCostKind CostKind) const override;
109108

110109
bool shouldExpandReduction(const IntrinsicInst *II) const override;
111110
bool supportsScalableVectors() const override {

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,12 +198,15 @@ InstructionCost WebAssemblyTTIImpl::getVectorInstrCost(
198198
InstructionCost WebAssemblyTTIImpl::getPartialReductionCost(
199199
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
200200
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
201-
TTI::PartialReductionExtendKind OpBExtend,
202-
std::optional<unsigned> BinOp) const {
201+
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
202+
TTI::TargetCostKind CostKind) const {
203203
InstructionCost Invalid = InstructionCost::getInvalid();
204204
if (!VF.isFixed() || !ST->hasSIMD128())
205205
return Invalid;
206206

207+
if (CostKind != TTI::TCK_RecipThroughput)
208+
return Invalid;
209+
207210
InstructionCost Cost(TTI::TCC_Basic);
208211

209212
// Possible options:

llvm/lib/Target/WebAssembly/WebAssemblyTargetTransformInfo.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,8 +86,8 @@ class WebAssemblyTTIImpl final : public BasicTTIImplBase<WebAssemblyTTIImpl> {
8686
InstructionCost getPartialReductionCost(
8787
unsigned Opcode, Type *InputTypeA, Type *InputTypeB, Type *AccumType,
8888
ElementCount VF, TTI::PartialReductionExtendKind OpAExtend,
89-
TTI::PartialReductionExtendKind OpBExtend,
90-
std::optional<unsigned> BinOp = std::nullopt) const override;
89+
TTI::PartialReductionExtendKind OpBExtend, std::optional<unsigned> BinOp,
90+
TTI::TargetCostKind CostKind) const override;
9191
TTI::ReductionShuffle
9292
getPreferredExpandedReductionShuffle(const IntrinsicInst *II) const override;
9393

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8240,7 +8240,7 @@ bool VPRecipeBuilder::getScaledReductions(
82408240
[&](ElementCount VF) {
82418241
InstructionCost Cost = TTI->getPartialReductionCost(
82428242
Update->getOpcode(), A->getType(), B->getType(), PHI->getType(),
8243-
VF, OpAExtend, OpBExtend, BinOp->getOpcode());
8243+
VF, OpAExtend, OpBExtend, BinOp->getOpcode(), CM.CostKind);
82448244
return Cost.isValid();
82458245
},
82468246
Range)) {

0 commit comments

Comments
 (0)