Skip to content

Commit 89372f1

Browse files
committed
Bundle partial reductions inside VPMulAccumulateReductionRecipe
1 parent c13d4b3 commit 89372f1

File tree

10 files changed

+686
-793
lines changed

10 files changed

+686
-793
lines changed

llvm/include/llvm/Analysis/TargetTransformInfo.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,8 @@ class TargetTransformInfo {
219219
/// Get the kind of extension that an instruction represents.
220220
static PartialReductionExtendKind
221221
getPartialReductionExtendKind(Instruction *I);
222+
static PartialReductionExtendKind
223+
getPartialReductionExtendKind(Instruction::CastOps ExtOpcode);
222224

223225
/// Construct a TTI object using a type implementing the \c Concept
224226
/// API below.

llvm/lib/Analysis/TargetTransformInfo.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -993,6 +993,19 @@ TargetTransformInfo::getPartialReductionExtendKind(Instruction *I) {
993993
return PR_None;
994994
}
995995

996+
TargetTransformInfo::PartialReductionExtendKind
997+
TargetTransformInfo::getPartialReductionExtendKind(
998+
Instruction::CastOps ExtOpcode) {
999+
switch (ExtOpcode) {
1000+
case Instruction::CastOps::ZExt:
1001+
return PR_ZeroExtend;
1002+
case Instruction::CastOps::SExt:
1003+
return PR_SignExtend;
1004+
default:
1005+
return PR_None;
1006+
}
1007+
}
1008+
9961009
TTI::CastContextHint
9971010
TargetTransformInfo::getCastContextHint(const Instruction *I) {
9981011
if (!I)

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8879,17 +8879,15 @@ VPRecipeBuilder::tryToCreatePartialReduction(Instruction *Reduction,
88798879
ReductionOpcode = Instruction::Add;
88808880
}
88818881

8882+
VPValue *Cond = nullptr;
88828883
if (CM.blockNeedsPredicationForAnyReason(Reduction->getParent())) {
88838884
assert((ReductionOpcode == Instruction::Add ||
88848885
ReductionOpcode == Instruction::Sub) &&
88858886
"Expected an ADD or SUB operation for predicated partial "
88868887
"reductions (because the neutral element in the mask is zero)!");
8887-
VPValue *Mask = getBlockInMask(Reduction->getParent());
8888-
VPValue *Zero =
8889-
Plan.getOrAddLiveIn(ConstantInt::get(Reduction->getType(), 0));
8890-
BinOp = Builder.createSelect(Mask, BinOp, Zero, Reduction->getDebugLoc());
8888+
Cond = getBlockInMask(Reduction->getParent());
88918889
}
8892-
return new VPPartialReductionRecipe(ReductionOpcode, BinOp, Accumulator,
8890+
return new VPPartialReductionRecipe(ReductionOpcode, Accumulator, BinOp, Cond,
88938891
Reduction);
88948892
}
88958893

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2331,21 +2331,21 @@ class VPReductionRecipe : public VPRecipeWithIRFlags {
23312331
/// vector operand are added together and passed to the next iteration as the
23322332
/// next accumulator. After the loop body, the accumulator is reduced to a
23332333
/// scalar value.
2334-
class VPPartialReductionRecipe : public VPSingleDefRecipe {
2334+
class VPPartialReductionRecipe : public VPReductionRecipe {
23352335
unsigned Opcode;
23362336

23372337
public:
23382338
VPPartialReductionRecipe(Instruction *ReductionInst, VPValue *Op0,
2339-
VPValue *Op1)
2340-
: VPPartialReductionRecipe(ReductionInst->getOpcode(), Op0, Op1,
2339+
VPValue *Op1, VPValue *Cond)
2340+
: VPPartialReductionRecipe(ReductionInst->getOpcode(), Op0, Op1, Cond,
23412341
ReductionInst) {}
23422342
VPPartialReductionRecipe(unsigned Opcode, VPValue *Op0, VPValue *Op1,
2343-
Instruction *ReductionInst = nullptr)
2344-
: VPSingleDefRecipe(VPDef::VPPartialReductionSC,
2345-
ArrayRef<VPValue *>({Op0, Op1}), ReductionInst),
2343+
VPValue *Cond, Instruction *ReductionInst = nullptr)
2344+
: VPReductionRecipe(VPDef::VPPartialReductionSC, RecurKind::Add,
2345+
FastMathFlags(), ReductionInst,
2346+
ArrayRef<VPValue *>({Op0, Op1}), Cond, false, {}),
23462347
Opcode(Opcode) {
2347-
[[maybe_unused]] auto *AccumulatorRecipe =
2348-
getOperand(1)->getDefiningRecipe();
2348+
[[maybe_unused]] auto *AccumulatorRecipe = getChainOp()->getDefiningRecipe();
23492349
assert((isa<VPReductionPHIRecipe>(AccumulatorRecipe) ||
23502350
isa<VPPartialReductionRecipe>(AccumulatorRecipe)) &&
23512351
"Unexpected operand order for partial reduction recipe");
@@ -2354,7 +2354,7 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
23542354

23552355
VPPartialReductionRecipe *clone() override {
23562356
return new VPPartialReductionRecipe(Opcode, getOperand(0), getOperand(1),
2357-
getUnderlyingInstr());
2357+
getCondOp(), getUnderlyingInstr());
23582358
}
23592359

23602360
VP_CLASSOF_IMPL(VPDef::VPPartialReductionSC)
@@ -2369,14 +2369,16 @@ class VPPartialReductionRecipe : public VPSingleDefRecipe {
23692369
/// Get the binary op's opcode.
23702370
unsigned getOpcode() const { return Opcode; }
23712371

2372+
/// Get the binary op this reduction is applied to.
2373+
VPValue *getBinOp() const { return getOperand(1); }
2374+
23722375
#if !defined(NDEBUG) || defined(LLVM_ENABLE_DUMP)
23732376
/// Print the recipe.
23742377
void print(raw_ostream &O, const Twine &Indent,
23752378
VPSlotTracker &SlotTracker) const override;
23762379
#endif
23772380
};
23782381

2379-
23802382
/// A recipe to represent inloop reduction operations with vector-predication
23812383
/// intrinsics, performing a reduction on a vector operand with the explicit
23822384
/// vector length (EVL) into a scalar value, and adding the result to a chain.
@@ -2497,6 +2499,9 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
24972499

24982500
Type *ResultTy;
24992501

2502+
/// True if the reduction this recipe is based on is a partial reduction.
2503+
bool IsPartialReduction = false;
2504+
25002505
/// For cloning VPMulAccumulateReductionRecipe.
25012506
VPMulAccumulateReductionRecipe(VPMulAccumulateReductionRecipe *MulAcc)
25022507
: VPReductionRecipe(
@@ -2506,7 +2511,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25062511
WrapFlagsTy(MulAcc->hasNoUnsignedWrap(), MulAcc->hasNoSignedWrap()),
25072512
MulAcc->getDebugLoc()),
25082513
ExtOp(MulAcc->getExtOpcode()), IsNonNeg(MulAcc->isNonNeg()),
2509-
ResultTy(MulAcc->getResultType()) {}
2514+
ResultTy(MulAcc->getResultType()),
2515+
IsPartialReduction(MulAcc->isPartialReduction()) {}
25102516

25112517
public:
25122518
VPMulAccumulateReductionRecipe(VPReductionRecipe *R, VPWidenRecipe *Mul,
@@ -2519,7 +2525,8 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25192525
WrapFlagsTy(Mul->hasNoUnsignedWrap(), Mul->hasNoSignedWrap()),
25202526
R->getDebugLoc()),
25212527
ExtOp(Ext0->getOpcode()), IsNonNeg(Ext0->isNonNeg()),
2522-
ResultTy(ResultTy) {
2528+
ResultTy(ResultTy),
2529+
IsPartialReduction(isa<VPPartialReductionRecipe>(R)) {
25232530
assert(RecurrenceDescriptor::getOpcode(getRecurrenceKind()) ==
25242531
Instruction::Add &&
25252532
"The reduction instruction in MulAccumulateteReductionRecipe must "
@@ -2590,6 +2597,9 @@ class VPMulAccumulateReductionRecipe : public VPReductionRecipe {
25902597

25912598
/// Return the non negative flag of the ext recipe.
25922599
bool isNonNeg() const { return IsNonNeg; }
2600+
2601+
/// Return true if the underlying reduction recipe is a partial reduction.
2602+
bool isPartialReduction() const { return IsPartialReduction; }
25932603
};
25942604

25952605
/// VPReplicateRecipe replicates a given instruction producing multiple scalar

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -287,14 +287,9 @@ InstructionCost
287287
VPPartialReductionRecipe::computeCost(ElementCount VF,
288288
VPCostContext &Ctx) const {
289289
std::optional<unsigned> Opcode = std::nullopt;
290-
VPValue *BinOp = getOperand(0);
290+
VPValue *BinOp = getBinOp();
291291

292-
// If the partial reduction is predicated, a select will be operand 0 rather
293-
// than the binary op
294292
using namespace llvm::VPlanPatternMatch;
295-
if (match(getOperand(0), m_Select(m_VPValue(), m_VPValue(), m_VPValue())))
296-
BinOp = BinOp->getDefiningRecipe()->getOperand(1);
297-
298293
// If BinOp is a negation, use the side effect of match to assign the actual
299294
// binary operation to BinOp
300295
match(BinOp, m_Binary<Instruction::Sub>(m_SpecificInt(0), m_VPValue(BinOp)));
@@ -338,12 +333,18 @@ void VPPartialReductionRecipe::execute(VPTransformState &State) {
338333
assert(getOpcode() == Instruction::Add &&
339334
"Unhandled partial reduction opcode");
340335

341-
Value *BinOpVal = State.get(getOperand(0));
342-
Value *PhiVal = State.get(getOperand(1));
336+
Value *BinOpVal = State.get(getBinOp());
337+
Value *PhiVal = State.get(getChainOp());
343338
assert(PhiVal && BinOpVal && "Phi and Mul must be set");
344339

345340
Type *RetTy = PhiVal->getType();
346341

342+
// Mask the bin op output: lanes where the condition is false contribute zero.
343+
if (VPValue *Cond = getCondOp()) {
344+
Value *Zero = ConstantInt::get(BinOpVal->getType(), 0);
345+
BinOpVal = Builder.CreateSelect(State.get(Cond), BinOpVal, Zero);
346+
}
347+
347348
CallInst *V = Builder.CreateIntrinsic(
348349
RetTy, Intrinsic::experimental_vector_partial_reduce_add,
349350
{PhiVal, BinOpVal}, nullptr, "partial.reduce");
@@ -2432,6 +2433,14 @@ VPExtendedReductionRecipe::computeCost(ElementCount VF,
24322433
InstructionCost
24332434
VPMulAccumulateReductionRecipe::computeCost(ElementCount VF,
24342435
VPCostContext &Ctx) const {
2436+
if (isPartialReduction()) {
2437+
return Ctx.TTI.getPartialReductionCost(
2438+
Instruction::Add, Ctx.Types.inferScalarType(getVecOp0()),
2439+
Ctx.Types.inferScalarType(getVecOp1()), getResultType(), VF,
2440+
TTI::getPartialReductionExtendKind(getExtOpcode()),
2441+
TTI::getPartialReductionExtendKind(getExtOpcode()), Instruction::Mul);
2442+
}
2443+
24352444
Type *RedTy = Ctx.Types.inferScalarType(this);
24362445
auto *SrcVecTy =
24372446
cast<VectorType>(toVectorTy(Ctx.Types.inferScalarType(getVecOp0()), VF));
@@ -2509,6 +2518,8 @@ void VPMulAccumulateReductionRecipe::print(raw_ostream &O, const Twine &Indent,
25092518
O << " = ";
25102519
getChainOp()->printAsOperand(O, SlotTracker);
25112520
O << " + ";
2521+
if (isPartialReduction())
2522+
O << "partial.";
25122523
O << "reduce."
25132524
<< Instruction::getOpcodeName(
25142525
RecurrenceDescriptor::getOpcode(getRecurrenceKind()))

llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2158,9 +2158,14 @@ expandVPMulAccumulateReduction(VPMulAccumulateReductionRecipe *MulAcc) {
21582158
Mul->insertBefore(MulAcc);
21592159

21602160
// Generate VPReductionRecipe.
2161-
auto *Red = new VPReductionRecipe(
2162-
MulAcc->getRecurrenceKind(), FastMathFlags(), MulAcc->getChainOp(), Mul,
2163-
MulAcc->getCondOp(), MulAcc->isOrdered(), MulAcc->getDebugLoc());
2161+
VPReductionRecipe *Red = nullptr;
2162+
if (MulAcc->isPartialReduction())
2163+
Red = new VPPartialReductionRecipe(Instruction::Add, MulAcc->getChainOp(),
2164+
Mul, MulAcc->getCondOp());
2165+
else
2166+
Red = new VPReductionRecipe(MulAcc->getRecurrenceKind(), FastMathFlags(),
2167+
MulAcc->getChainOp(), Mul, MulAcc->getCondOp(),
2168+
MulAcc->isOrdered(), MulAcc->getDebugLoc());
21642169
Red->insertBefore(MulAcc);
21652170

21662171
MulAcc->replaceAllUsesWith(Red);
@@ -2432,12 +2437,39 @@ static void tryToCreateAbstractReductionRecipe(VPReductionRecipe *Red,
24322437
Red->replaceAllUsesWith(AbstractR);
24332438
}
24342439

2440+
static void
2441+
tryToCreateAbstractPartialReductionRecipe(VPPartialReductionRecipe *PRed) {
2442+
if (PRed->getOpcode() != Instruction::Add)
2443+
return;
2444+
2445+
VPRecipeBase *BinOpR = PRed->getBinOp()->getDefiningRecipe();
2446+
auto *BinOp = dyn_cast<VPWidenRecipe>(BinOpR);
2447+
if (!BinOp || BinOp->getOpcode() != Instruction::Mul)
2448+
return;
2449+
2450+
auto *Ext0 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(0));
2451+
auto *Ext1 = dyn_cast<VPWidenCastRecipe>(BinOp->getOperand(1));
2452+
// TODO: Make this work with extends of different signedness.
2453+
if (!Ext0 || Ext0->hasMoreThanOneUniqueUser() || !Ext1 ||
2454+
Ext1->hasMoreThanOneUniqueUser() ||
2455+
Ext0->getOpcode() != Ext1->getOpcode())
2456+
return;
2457+
2458+
auto *AbstractR = new VPMulAccumulateReductionRecipe(PRed, BinOp, Ext0, Ext1,
2459+
Ext0->getResultType());
2460+
AbstractR->insertBefore(PRed);
2461+
PRed->replaceAllUsesWith(AbstractR);
2462+
PRed->eraseFromParent();
2463+
}
2464+
24352465
void VPlanTransforms::convertToAbstractRecipes(VPlan &Plan, VPCostContext &Ctx,
24362466
VFRange &Range) {
24372467
for (VPBasicBlock *VPBB : VPBlockUtils::blocksOnly<VPBasicBlock>(
24382468
vp_depth_first_deep(Plan.getVectorLoopRegion()))) {
24392469
for (VPRecipeBase &R : make_early_inc_range(*VPBB)) {
2440-
if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
2470+
if (auto *PRed = dyn_cast<VPPartialReductionRecipe>(&R))
2471+
tryToCreateAbstractPartialReductionRecipe(PRed);
2472+
else if (auto *Red = dyn_cast<VPReductionRecipe>(&R))
24412473
tryToCreateAbstractReductionRecipe(Red, Ctx, Range);
24422474
}
24432475
}

0 commit comments

Comments
 (0)