Skip to content

Commit 2fbdc7c

Browse files
committed
!fixup, Address comments and fix VPReductionRecipe::computeCost
Note that we should use std::nullopt when querying the cost of non-FPMathOperator instructions.
1 parent ca5db10 commit 2fbdc7c

File tree

3 files changed

+30
-33
lines changed

3 files changed

+30
-33
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 22 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -2437,14 +2437,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
24372437
/// Opcode of the extend recipe will be lowered to.
24382438
Instruction::CastOps ExtOp;
24392439

2440-
public:
2441-
VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2442-
: VPReductionRecipe(VPDef::VPExtendedReductionSC,
2443-
R->getRecurrenceDescriptor(),
2444-
{R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
2445-
R->isOrdered(), Ext->isNonNeg(), Ext->getDebugLoc()),
2446-
ExtOp(Ext->getOpcode()) {}
2447-
24482440
/// For cloning VPExtendedReductionRecipe.
24492441
VPExtendedReductionRecipe(VPExtendedReductionRecipe *ExtRed)
24502442
: VPReductionRecipe(
@@ -2453,6 +2445,14 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
24532445
ExtRed->isOrdered(), ExtRed->isNonNeg(), ExtRed->getDebugLoc()),
24542446
ExtOp(ExtRed->getExtOpcode()) {}
24552447

2448+
public:
2449+
VPExtendedReductionRecipe(VPReductionRecipe *R, VPWidenCastRecipe *Ext)
2450+
: VPReductionRecipe(VPDef::VPExtendedReductionSC,
2451+
R->getRecurrenceDescriptor(),
2452+
{R->getChainOp(), Ext->getOperand(0)}, R->getCondOp(),
2453+
R->isOrdered(), Ext->isNonNeg(), Ext->getDebugLoc()),
2454+
ExtOp(Ext->getOpcode()) {}
2455+
24562456
~VPExtendedReductionRecipe() override = default;
24572457

24582458
VPExtendedReductionRecipe *clone() override {
@@ -2500,8 +2500,16 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25002500
/// Non-neg flag of the extend recipe.
25012501
bool IsNonNeg = false;
25022502

2503-
/// Is this multiply-accumulate-reduction recipe contains extend?
2504-
bool IsExtended = false;
2503+
/// For cloning VPMulAccumulateReductionRecipe.
2504+
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
2505+
: VPReductionRecipe(
2506+
VPDef::VPMulAccumulateReductionSC,
2507+
MulAcc->getRecurrenceDescriptor(),
2508+
{MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
2509+
MulAcc->getCondOp(), MulAcc->isOrdered(),
2510+
MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
2511+
MulAcc->getDebugLoc()),
2512+
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()) {}
25052513

25062514
public:
25072515
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2516,32 +2524,20 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25162524
assert(getRecurrenceDescriptor().getOpcode() == Instruction::Add &&
25172525
"The reduction instruction in MulAccumulateteReductionRecipe must "
25182526
"be Add");
2519-
IsExtended = true;
25202527
}
25212528

25222529
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
25232530
: VPReductionRecipe(
25242531
VPDef::VPMulAccumulateReductionSC, R->getRecurrenceDescriptor(),
25252532
{R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
25262533
R->getCondOp(), R->isOrdered(), Mul->hasNoUnsignedWrap(),
2527-
Mul->hasNoSignedWrap(), R->getDebugLoc()) {
2534+
Mul->hasNoSignedWrap(), R->getDebugLoc()),
2535+
ExtOp(Instruction::CastOps::CastOpsEnd) {
25282536
assert(getRecurrenceDescriptor().getOpcode() == Instruction::Add &&
25292537
"The reduction instruction in MulAccumulateReductionRecipe must be "
25302538
"Add");
25312539
}
25322540

2533-
/// For cloning VPMulAccumulateReductionRecipe.
2534-
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
2535-
: VPReductionRecipe(
2536-
VPDef::VPMulAccumulateReductionSC,
2537-
MulAcc->getRecurrenceDescriptor(),
2538-
{MulAcc->getChainOp(), MulAcc->getVecOp0(), MulAcc->getVecOp1()},
2539-
MulAcc->getCondOp(), MulAcc->isOrdered(),
2540-
MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),
2541-
MulAcc->getDebugLoc()),
2542-
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2543-
IsExtended(MulAcc->isExtended()) {}
2544-
25452541
~VPMulAccumulateReductionRecipe() override = default;
25462542

25472543
VPMulAccumulateReductionRecipe *clone() override {
@@ -2571,7 +2567,7 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25712567
VPValue *getVecOp1() const { return getOperand(2); }
25722568

25732569
/// Return if this MulAcc recipe contains extend instructions.
2574-
bool isExtended() const { return IsExtended; }
2570+
bool isExtended() const { return ExtOp != Instruction::CastOps::CastOpsEnd; }
25752571

25762572
/// Return if the operands of mul instruction come from same extend.
25772573
bool isSameExtend() const { return getVecOp0() == getVecOp1(); }
@@ -2580,11 +2576,7 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25802576
Instruction::CastOps getExtOpcode() const { return ExtOp; }
25812577

25822578
/// Return if the extend opcode is ZExt.
2583-
bool isZExt() const {
2584-
if (!isExtended())
2585-
return true;
2586-
return ExtOp == Instruction::CastOps::ZExt;
2587-
}
2579+
bool isZExt() const { return ExtOp == Instruction::CastOps::ZExt; }
25882580

25892581
/// Return the non negative flag of the ext recipe.
25902582
bool isNonNeg() const { return IsNonNeg; }

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2399,8 +2399,13 @@ InstructionCost VPReductionRecipe::computeCost(ElementCount VF,
23992399
Ctx.TTI.getMinMaxReductionCost(Id, VectorTy, FMFs, Ctx.CostKind);
24002400
}
24012401

2402-
return Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
2403-
Ctx.CostKind);
2402+
if (ElementTy->isFloatingPointTy())
2403+
return Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, FMFs,
2404+
Ctx.CostKind);
2405+
// Cannot get the correct cost when querying TTI with FMFs that do not contain `reassoc`
2406+
// for non-FP reductions.
2407+
return Ctx.TTI.getArithmeticReductionCost(Opcode, VectorTy, std::nullopt,
2408+
Ctx.CostKind);
24042409
}
24052410

24062411
InstructionCost

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2146,7 +2146,7 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
21462146
}
21472147

21482148
// Generate VPWidenRecipe.
2149-
std::array<VPValue *, 2> MulOps = {Op0, Op1};
2149+
ArrayRef<VPValue *> MulOps = {Op0, Op1};
21502150
auto *Mul = new VPWidenRecipe(
21512151
Instruction::Mul, make_range(MulOps.begin(), MulOps.end()),
21522152
MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap(),

0 commit comments

Comments
 (0)