Skip to content

Commit 40e44cf

Browse files
committed
Create VecOperandInfo
1 parent 10c4727 commit 40e44cf

File tree

3 files changed

+47
-50
lines changed

3 files changed

+47
-50
lines changed

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 22 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2493,13 +2493,6 @@ class VPExtendedReductionRecipe : public VPReductionRecipe {
24932493
/// recipe is abstract and needs to be lowered to concrete recipes before
24942494
/// codegen. The Operands are {ChainOp, VecOp1, VecOp2, [Condition]}.
24952495
class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
2496-
/// Opcodes of the extend recipes.
2497-
Instruction::CastOps ExtOp0;
2498-
Instruction::CastOps ExtOp1;
2499-
2500-
/// Non-neg flags of the extend recipe.
2501-
bool IsNonNeg0 = false;
2502-
bool IsNonNeg1 = false;
25032496

25042497
Type *ResultTy;
25052498

@@ -2514,10 +2507,11 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25142507
MulAcc->getCondOp(), MulAcc->isOrdered(),
25152508
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
25162509
MulAcc->getDebugLoc()),
2517-
ExtOp0(MulAcc->getExt0Opcode()), ExtOp1(MulAcc->getExt1Opcode()),
2518-
IsNonNeg0(MulAcc->isNonNeg0()), IsNonNeg1(MulAcc->isNonNeg1()),
25192510
ResultTy(MulAcc->getResultType()),
2520-
IsPartialReduction(MulAcc->isPartialReduction()) {}
2511+
IsPartialReduction(MulAcc->isPartialReduction()) {
2512+
VecOpInfo[0] = MulAcc->getVecOp0Info();
2513+
VecOpInfo[1] = MulAcc->getVecOp1Info();
2514+
}
25212515

25222516
public:
25232517
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2529,14 +2523,14 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25292523
R->getCondOp(), R->isOrdered(),
25302524
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
25312525
R->getDebugLoc()),
2532-
ExtOp0(Ext0->getOpcode()), ExtOp1(Ext1->getOpcode()),
2533-
IsNonNeg0(Ext0->isNonNeg()), IsNonNeg1(Ext1->isNonNeg()),
25342526
ResultTy(ResultTy),
25352527
IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
25362528
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
25372529
Instruction::Add &&
25382530
"The reduction instruction in MulAccumulateteReductionRecipe must "
25392531
"be Add");
2532+
VecOpInfo[0] = {Ext0->getOpcode(), Ext0->isNonNeg()};
2533+
VecOpInfo[1] = {Ext1->getOpcode(), Ext1->isNonNeg()};
25402534
}
25412535

25422536
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul)
@@ -2545,15 +2539,20 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25452539
{R->getChainOp(), Mul->getOperand(0), Mul->getOperand(1)},
25462540
R->getCondOp(), R->isOrdered(),
25472541
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
2548-
R->getDebugLoc()),
2549-
ExtOp0(Instruction::CastOps::CastOpsEnd),
2550-
ExtOp1(Instruction::CastOps::CastOpsEnd) {
2542+
R->getDebugLoc()) {
25512543
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
25522544
Instruction::Add &&
25532545
"The reduction instruction in MulAccumulateReductionRecipe must be "
25542546
"Add");
25552547
}
25562548

2549+
struct VecOperandInfo {
2550+
/// The operand's extend opcode.
2551+
Instruction::CastOps ExtOp{Instruction::CastOps::CastOpsEnd};
2552+
/// Non-neg portion of the operand's flags.
2553+
bool IsNonNeg = false;
2554+
};
2555+
25572556
~VPMulAccumulateReductionRecipe() override = default;
25582557

25592558
VPMulAccumulateReductionRecipe *clone() override {
@@ -2591,29 +2590,21 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25912590
VPValue *getVecOp1() const { return getOperand(2); }
25922591

25932592
/// Return if this MulAcc recipe contains extend instructions.
2594-
bool isExtended() const { return ExtOp0 != Instruction::CastOps::CastOpsEnd; }
2593+
bool isExtended() const {
2594+
return getVecOp0Info().ExtOp != Instruction::CastOps::CastOpsEnd;
2595+
}
25952596

25962597
/// Return if the operands of mul instruction come from same extend.
25972598
bool isSameExtendVal() const { return getVecOp0() == getVecOp1(); }
25982599

2599-
/// Return the opcode of the underlying extends.
2600-
Instruction::CastOps getExt0Opcode() const { return ExtOp0; }
2601-
Instruction::CastOps getExt1Opcode() const { return ExtOp1; }
2602-
2603-
/// Return if the first extend's opcode is ZExt.
2604-
bool isZExt0() const { return ExtOp0 == Instruction::CastOps::ZExt; }
2605-
2606-
/// Return if the second extend's opcode is ZExt.
2607-
bool isZExt1() const { return ExtOp1 == Instruction::CastOps::ZExt; }
2608-
2609-
/// Return the non negative flag of the first ext recipe.
2610-
bool isNonNeg0() const { return IsNonNeg0; }
2611-
2612-
/// Return the non negative flag of the second ext recipe.
2613-
bool isNonNeg1() const { return IsNonNeg1; }
2600+
VecOperandInfo getVecOp0Info() const { return VecOpInfo[0]; }
2601+
VecOperandInfo getVecOp1Info() const { return VecOpInfo[1]; }
26142602

26152603
/// Return if the underlying reduction recipe is a partial reduction.
26162604
bool isPartialReduction() const { return IsPartialReduction; }
2605+
2606+
protected:
2607+
VecOperandInfo VecOpInfo[2];
26172608
};
26182609

26192610
/// VPReplicateRecipe replicates a given instruction producing multiple scalar

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2434,19 +2434,22 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
24342434
InstructionCost
24352435
VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
24362436
VPCostContext &Ctx) const {
2437+
VecOperandInfo Op0Info = getVecOp0Info();
2438+
VecOperandInfo Op1Info = getVecOp1Info();
24372439
if (isPartialReduction()) {
24382440
return Ctx.TTI.getPartialReductionCost(
24392441
Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
24402442
Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
2441-
TTI::getPartialReductionExtendKind(getExt0Opcode()),
2442-
TTI::getPartialReductionExtendKind(getExt1Opcode()), Instruction::Mul);
2443+
TTI::getPartialReductionExtendKind(Op0Info.ExtOp),
2444+
TTI::getPartialReductionExtendKind(Op1Info.ExtOp), Instruction::Mul);
24432445
}
24442446

24452447
Type *RedTy = Ctx.Types.inferScalarType(this);
24462448
auto *SrcVecTy =
24472449
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
2448-
return Ctx.TTI.getMulAccReductionCost(isZExt0(), RedTy, SrcVecTy,
2449-
Ctx.CostKind);
2450+
return Ctx.TTI.getMulAccReductionCost(Op0Info.ExtOp ==
2451+
Instruction::CastOps::ZExt,
2452+
RedTy, SrcVecTy, Ctx.CostKind);
24502453
}
24512454

24522455
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
@@ -2514,6 +2517,8 @@ void VPExtendedReductionRecipe::print(raw_ostream &O, const Twine &Indent,
25142517

25152518
void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
25162519
VPSlotTracker &SlotTracker) const {
2520+
VecOperandInfo Op0Info = getVecOp0Info();
2521+
VecOperandInfo Op1Info = getVecOp1Info();
25172522
O << Indent << "MULACC-REDUCE ";
25182523
printAsOperand(O, SlotTracker);
25192524
O << " = ";
@@ -2532,7 +2537,7 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
25322537
getVecOp0()->printAsOperand(O, SlotTracker);
25332538
if (isExtended()) {
25342539
O << " ";
2535-
if (isZExt0())
2540+
if (Op0Info.ExtOp == Instruction::CastOps::ZExt)
25362541
O << "zero-";
25372542
else
25382543
O << "sign-";
@@ -2542,7 +2547,7 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
25422547
getVecOp1()->printAsOperand(O, SlotTracker);
25432548
if (isExtended()) {
25442549
O << " ";
2545-
if (isZExt1())
2550+
if (Op1Info.ExtOp == Instruction::CastOps::ZExt)
25462551
O << "zero-";
25472552
else
25482553
O << "sign-";

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2120,28 +2120,29 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
21202120
// reduce.add(ext(mul(ext, ext))) to reduce.add(mul(ext, ext)).
21212121
VPValue *Op0, *Op1;
21222122
if (MulAcc->isExtended()) {
2123+
VPMulAccumulateReductionRecipe::VecOperandInfo Op0Info =
2124+
MulAcc->getVecOp0Info();
2125+
VPMulAccumulateReductionRecipe::VecOperandInfo Op1Info =
2126+
MulAcc->getVecOp1Info();
21232127
Type *RedTy = MulAcc->getResultType();
2124-
if (MulAcc->isZExt0())
2125-
Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
2126-
RedTy, MulAcc->isNonNeg0(),
2127-
MulAcc->getDebugLoc());
2128+
if (Op0Info.ExtOp == Instruction::CastOps::ZExt)
2129+
Op0 = new VPWidenCastRecipe(Op0Info.ExtOp, MulAcc->getVecOp0(), RedTy,
2130+
Op0Info.IsNonNeg, MulAcc->getDebugLoc());
21282131
else
2129-
Op0 = new VPWidenCastRecipe(MulAcc->getExt0Opcode(), MulAcc->getVecOp0(),
2130-
RedTy, MulAcc->getDebugLoc());
2132+
Op0 = new VPWidenCastRecipe(Op0Info.ExtOp, MulAcc->getVecOp0(), RedTy,
2133+
MulAcc->getDebugLoc());
21312134
Op0->getDefiningRecipe()->insertBefore(MulAcc);
21322135
// Prevent reduce.add(mul(ext(A), ext(A))) generate duplicate
21332136
// VPWidenCastRecipe.
21342137
if (MulAcc->getVecOp0() == MulAcc->getVecOp1()) {
21352138
Op1 = Op0;
21362139
} else {
2137-
if (MulAcc->isZExt1())
2138-
Op1 = new VPWidenCastRecipe(MulAcc->getExt1Opcode(),
2139-
MulAcc->getVecOp1(), RedTy,
2140-
MulAcc->isNonNeg1(), MulAcc->getDebugLoc());
2140+
if (Op1Info.ExtOp == Instruction::CastOps::ZExt)
2141+
Op1 = new VPWidenCastRecipe(Op1Info.ExtOp, MulAcc->getVecOp1(), RedTy,
2142+
Op1Info.IsNonNeg, MulAcc->getDebugLoc());
21412143
else
2142-
Op1 =
2143-
new VPWidenCastRecipe(MulAcc->getExt1Opcode(), MulAcc->getVecOp1(),
2144-
RedTy, MulAcc->getDebugLoc());
2144+
Op1 = new VPWidenCastRecipe(Op1Info.ExtOp, MulAcc->getVecOp1(), RedTy,
2145+
MulAcc->getDebugLoc());
21452146
Op1->getDefiningRecipe()->insertBefore(MulAcc);
21462147
}
21472148
} else {

0 commit comments

Comments
 (0)