Skip to content

Commit ac4a38e

Browse files
[SLP] Emit reduction instead of 2 extracts + scalar op, when vectorizing operands (#147583)
Emit a 2-element reduction instead of 2 extracts + a scalar op when trying to vectorize the operands of an instruction, if doing so is more profitable.
1 parent 20daa73 commit ac4a38e

File tree

8 files changed

+245
-129
lines changed

8 files changed

+245
-129
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 77 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21722,6 +21722,8 @@ class HorizontalReduction {
2172221722
/// Checks if the optimization of original scalar identity operations on
2172321723
/// matched horizontal reductions is enabled and allowed.
2172421724
bool IsSupportedHorRdxIdentityOp = false;
21725+
/// The minimum number of the reduced values.
21726+
const unsigned ReductionLimit = VectorizeNonPowerOf2 ? 3 : 4;
2172521727
/// Contains vector values for reduction including their scale factor and
2172621728
/// signedness.
2172721729
SmallVector<std::tuple<Value *, unsigned, bool>> VectorValuesAndScales;
@@ -21740,7 +21742,8 @@ class HorizontalReduction {
2174021742
}
2174121743

2174221744
/// Checks if instruction is associative and can be vectorized.
21743-
static bool isVectorizable(RecurKind Kind, Instruction *I) {
21745+
static bool isVectorizable(RecurKind Kind, Instruction *I,
21746+
bool TwoElementReduction = false) {
2174421747
if (Kind == RecurKind::None)
2174521748
return false;
2174621749

@@ -21749,6 +21752,10 @@ class HorizontalReduction {
2174921752
isBoolLogicOp(I))
2175021753
return true;
2175121754

21755+
// No need to check for associativity, if 2 reduced values.
21756+
if (TwoElementReduction)
21757+
return true;
21758+
2175221759
if (Kind == RecurKind::FMax || Kind == RecurKind::FMin) {
2175321760
// FP min/max are associative except for NaN and -0.0. We do not
2175421761
// have to rule out -0.0 here because the intrinsic semantics do not
@@ -22020,6 +22027,27 @@ class HorizontalReduction {
2202022027

2202122028
public:
2202222029
HorizontalReduction() = default;
22030+
/// Builds a reduction description directly from the root instruction \p I
/// and an explicit list of reduced values \p Ops, instead of discovering
/// them by matching a reduction tree. The minimum number of reduced values
/// (ReductionLimit) is lowered to 2 for reductions created this way.
HorizontalReduction(Instruction *I, ArrayRef<Value *> Ops)
22031+
: ReductionRoot(I), ReductionLimit(2) {
22032+
// The recurrence kind is derived from the root instruction itself.
RdxKind = HorizontalReduction::getRdxKind(I);
22033+
// Append a new group of reduction operations holding only the root.
ReductionOps.emplace_back().push_back(I);
22034+
// Append a new group of reduced values holding a copy of Ops.
ReducedVals.emplace_back().assign(Ops.begin(), Ops.end());
22035+
// Record the root as the reduction operation for every reduced value.
for (Value *V : Ops)
22036+
ReducedValsToOps[V].push_back(I);
22037+
}
22038+
22039+
/// Checks whether the reduction described by ReductionRoot, RdxKind and
/// ReducedVals can be vectorized. Requires ReductionRoot to be set
/// (asserted). Returns the result of isVectorizable() for the root; the
/// TwoElementReduction argument is true only when every group in
/// ReducedVals holds exactly 2 values.
bool matchReductionForOperands() const {
22040+
// Analyze "regular" integer/FP types for reductions - no target-specific
22041+
// types or pointers.
// NOTE(review): no type check is performed in this body; presumably the
// caller is expected to filter the types - confirm.
22042+
assert(ReductionRoot && "Reduction root is not set!");
22043+
if (!isVectorizable(RdxKind, cast<Instruction>(ReductionRoot),
22044+
all_of(ReducedVals, [](ArrayRef<Value *> Ops) {
22045+
return Ops.size() == 2;
22046+
})))
22047+
return false;
22048+
22049+
return true;
22050+
}
2202322051

2202422052
/// Try to find a reduction tree.
2202522053
bool matchAssociativeReduction(BoUpSLP &R, Instruction *Root,
@@ -22187,7 +22215,6 @@ class HorizontalReduction {
2218722215
/// Attempt to vectorize the tree found by matchAssociativeReduction.
2218822216
Value *tryToReduce(BoUpSLP &V, const DataLayout &DL, TargetTransformInfo *TTI,
2218922217
const TargetLibraryInfo &TLI, AssumptionCache *AC) {
22190-
const unsigned ReductionLimit = VectorizeNonPowerOf2 ? 3 : 4;
2219122218
constexpr unsigned RegMaxNumber = 4;
2219222219
constexpr unsigned RedValsMaxNumber = 128;
2219322220
// If there are a sufficient number of reduction values, reduce
@@ -23736,15 +23763,60 @@ bool SLPVectorizerPass::tryToVectorize(Instruction *I, BoUpSLP &R) {
2373623763
Candidates.emplace_back(A1, B);
2373723764
}
2373823765

23766+
// Tries to replace Inst, whose operands are Ops, with a vector reduction
// over Ops. Returns true only if HorizontalReduction::tryToReduce produced
// a value; returns false if Inst is not a suitable candidate or the
// reduction is not strictly cheaper than the extracts plus the scalar
// instruction.
auto TryToReduce = [this, &R, &TTI = *TTI](Instruction *Inst,
23767+
ArrayRef<Value *> Ops) {
23768+
if (!isReductionCandidate(Inst))
23769+
return false;
23770+
Type *Ty = Inst->getType();
23771+
// Reject result types that are not valid element types, and pointers.
if (!isValidElementType(Ty) || Ty->isPointerTy())
23772+
return false;
23773+
HorizontalReduction HorRdx(Inst, Ops);
23774+
if (!HorRdx.matchReductionForOperands())
23775+
return false;
23776+
// Check the cost of operations.
23777+
// The vector type has one element of type Ty per operand in Ops.
VectorType *VecTy = getWidenedType(Ty, Ops.size());
23778+
constexpr TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
23779+
// Scalar cost: extracting every element of VecTy plus the cost of Inst.
InstructionCost ScalarCost =
23780+
TTI.getScalarizationOverhead(
23781+
VecTy, APInt::getAllOnes(getNumElements(VecTy)), /*Insert=*/false,
23782+
/*Extract=*/true, CostKind) +
23783+
TTI.getInstructionCost(Inst, CostKind);
23784+
InstructionCost RedCost;
23785+
// Only the arithmetic/logical kinds listed below are costed; any other
// recurrence kind makes the lambda return false.
switch (::getRdxKind(Inst)) {
23786+
case RecurKind::Add:
23787+
case RecurKind::Mul:
23788+
case RecurKind::Or:
23789+
case RecurKind::And:
23790+
case RecurKind::Xor:
23791+
case RecurKind::FAdd:
23792+
case RecurKind::FMul: {
23793+
// Use the fast-math flags of Inst when it is an FP math operator;
// otherwise FMF stays default-constructed.
FastMathFlags FMF;
23794+
if (auto *FPCI = dyn_cast<FPMathOperator>(Inst))
23795+
FMF = FPCI->getFastMathFlags();
23796+
RedCost = TTI.getArithmeticReductionCost(Inst->getOpcode(), VecTy, FMF,
23797+
CostKind);
23798+
break;
23799+
}
23800+
default:
23801+
return false;
23802+
}
23803+
// The reduction must be strictly cheaper; a tie keeps the scalar form.
if (RedCost >= ScalarCost)
23804+
return false;
23805+
23806+
return HorRdx.tryToReduce(R, *DL, &TTI, *TLI, AC) != nullptr;
23807+
};
2373923808
if (Candidates.size() == 1)
23740-
return tryToVectorizeList({Op0, Op1}, R);
23809+
return TryToReduce(I, {Op0, Op1}) || tryToVectorizeList({Op0, Op1}, R);
2374123810

2374223811
// We have multiple options. Try to pick the single best.
2374323812
std::optional<int> BestCandidate = R.findBestRootPair(Candidates);
2374423813
if (!BestCandidate)
2374523814
return false;
23746-
return tryToVectorizeList(
23747-
{Candidates[*BestCandidate].first, Candidates[*BestCandidate].second}, R);
23815+
return TryToReduce(I, {Candidates[*BestCandidate].first,
23816+
Candidates[*BestCandidate].second}) ||
23817+
tryToVectorizeList({Candidates[*BestCandidate].first,
23818+
Candidates[*BestCandidate].second},
23819+
R);
2374823820
}
2374923821

2375023822
bool SLPVectorizerPass::vectorizeRootInstruction(PHINode *P, Instruction *Root,

llvm/test/Transforms/SLPVectorizer/AArch64/commute.ll

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,7 @@ define void @test1(ptr nocapture readonly %J, i32 %xmin, i32 %ymin) {
1616
; CHECK-NEXT: [[TMP4:%.*]] = load <2 x float>, ptr [[J:%.*]], align 4
1717
; CHECK-NEXT: [[TMP5:%.*]] = fsub fast <2 x float> [[TMP2]], [[TMP4]]
1818
; CHECK-NEXT: [[TMP6:%.*]] = fmul fast <2 x float> [[TMP5]], [[TMP5]]
19-
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x float> [[TMP6]], i32 0
20-
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x float> [[TMP6]], i32 1
21-
; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[TMP7]], [[TMP8]]
19+
; CHECK-NEXT: [[ADD:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP6]])
2220
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq float [[ADD]], 0.000000e+00
2321
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY3_LR_PH]], label [[FOR_END27:%.*]]
2422
; CHECK: for.end27:
@@ -57,9 +55,7 @@ define void @test2(ptr nocapture readonly %J, i32 %xmin, i32 %ymin) {
5755
; CHECK-NEXT: [[TMP4:%.*]] = load <2 x float>, ptr [[J:%.*]], align 4
5856
; CHECK-NEXT: [[TMP5:%.*]] = fsub fast <2 x float> [[TMP2]], [[TMP4]]
5957
; CHECK-NEXT: [[TMP6:%.*]] = fmul fast <2 x float> [[TMP5]], [[TMP5]]
60-
; CHECK-NEXT: [[TMP7:%.*]] = extractelement <2 x float> [[TMP6]], i32 0
61-
; CHECK-NEXT: [[TMP8:%.*]] = extractelement <2 x float> [[TMP6]], i32 1
62-
; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[TMP8]], [[TMP7]]
58+
; CHECK-NEXT: [[ADD:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP6]])
6359
; CHECK-NEXT: [[CMP:%.*]] = fcmp oeq float [[ADD]], 0.000000e+00
6460
; CHECK-NEXT: br i1 [[CMP]], label [[FOR_BODY3_LR_PH]], label [[FOR_END27:%.*]]
6561
; CHECK: for.end27:

llvm/test/Transforms/SLPVectorizer/AArch64/reduce-fadd.ll

Lines changed: 16 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,19 @@
33
; RUN: opt < %s -S -passes=slp-vectorizer -mtriple=aarch64-unknown-linux -mattr=+fullfp16 | FileCheck %s --check-prefixes=CHECK,FULLFP16
44

55
define half @reduce_fast_half2(<2 x half> %vec2) {
6-
; CHECK-LABEL: define half @reduce_fast_half2(
7-
; CHECK-SAME: <2 x half> [[VEC2:%.*]]) #[[ATTR0:[0-9]+]] {
8-
; CHECK-NEXT: [[ENTRY:.*:]]
9-
; CHECK-NEXT: [[ELT0:%.*]] = extractelement <2 x half> [[VEC2]], i64 0
10-
; CHECK-NEXT: [[ELT1:%.*]] = extractelement <2 x half> [[VEC2]], i64 1
11-
; CHECK-NEXT: [[ADD1:%.*]] = fadd fast half [[ELT1]], [[ELT0]]
12-
; CHECK-NEXT: ret half [[ADD1]]
6+
; NOFP16-LABEL: define half @reduce_fast_half2(
7+
; NOFP16-SAME: <2 x half> [[VEC2:%.*]]) #[[ATTR0:[0-9]+]] {
8+
; NOFP16-NEXT: [[ENTRY:.*:]]
9+
; NOFP16-NEXT: [[ELT0:%.*]] = extractelement <2 x half> [[VEC2]], i64 0
10+
; NOFP16-NEXT: [[ELT1:%.*]] = extractelement <2 x half> [[VEC2]], i64 1
11+
; NOFP16-NEXT: [[ADD1:%.*]] = fadd fast half [[ELT1]], [[ELT0]]
12+
; NOFP16-NEXT: ret half [[ADD1]]
13+
;
14+
; FULLFP16-LABEL: define half @reduce_fast_half2(
15+
; FULLFP16-SAME: <2 x half> [[VEC2:%.*]]) #[[ATTR0:[0-9]+]] {
16+
; FULLFP16-NEXT: [[ENTRY:.*:]]
17+
; FULLFP16-NEXT: [[TMP0:%.*]] = call fast half @llvm.vector.reduce.fadd.v2f16(half 0xH0000, <2 x half> [[VEC2]])
18+
; FULLFP16-NEXT: ret half [[TMP0]]
1319
;
1420
entry:
1521
%elt0 = extractelement <2 x half> %vec2, i64 0
@@ -20,7 +26,7 @@ entry:
2026

2127
define half @reduce_half2(<2 x half> %vec2) {
2228
; CHECK-LABEL: define half @reduce_half2(
23-
; CHECK-SAME: <2 x half> [[VEC2:%.*]]) #[[ATTR0]] {
29+
; CHECK-SAME: <2 x half> [[VEC2:%.*]]) #[[ATTR0:[0-9]+]] {
2430
; CHECK-NEXT: [[ENTRY:.*:]]
2531
; CHECK-NEXT: [[ELT0:%.*]] = extractelement <2 x half> [[VEC2]], i64 0
2632
; CHECK-NEXT: [[ELT1:%.*]] = extractelement <2 x half> [[VEC2]], i64 1
@@ -269,9 +275,7 @@ define float @reduce_fast_float2(<2 x float> %vec2) {
269275
; CHECK-LABEL: define float @reduce_fast_float2(
270276
; CHECK-SAME: <2 x float> [[VEC2:%.*]]) #[[ATTR0]] {
271277
; CHECK-NEXT: [[ENTRY:.*:]]
272-
; CHECK-NEXT: [[ELT0:%.*]] = extractelement <2 x float> [[VEC2]], i64 0
273-
; CHECK-NEXT: [[ELT1:%.*]] = extractelement <2 x float> [[VEC2]], i64 1
274-
; CHECK-NEXT: [[ADD1:%.*]] = fadd fast float [[ELT1]], [[ELT0]]
278+
; CHECK-NEXT: [[ADD1:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[VEC2]])
275279
; CHECK-NEXT: ret float [[ADD1]]
276280
;
277281
entry:
@@ -409,9 +413,7 @@ define double @reduce_fast_double2(<2 x double> %vec2) {
409413
; CHECK-LABEL: define double @reduce_fast_double2(
410414
; CHECK-SAME: <2 x double> [[VEC2:%.*]]) #[[ATTR0]] {
411415
; CHECK-NEXT: [[ENTRY:.*:]]
412-
; CHECK-NEXT: [[ELT0:%.*]] = extractelement <2 x double> [[VEC2]], i64 0
413-
; CHECK-NEXT: [[ELT1:%.*]] = extractelement <2 x double> [[VEC2]], i64 1
414-
; CHECK-NEXT: [[ADD1:%.*]] = fadd fast double [[ELT1]], [[ELT0]]
416+
; CHECK-NEXT: [[ADD1:%.*]] = call fast double @llvm.vector.reduce.fadd.v2f64(double 0.000000e+00, <2 x double> [[VEC2]])
415417
; CHECK-NEXT: ret double [[ADD1]]
416418
;
417419
entry:

llvm/test/Transforms/SLPVectorizer/AArch64/slp-fma-loss.ll

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -216,22 +216,16 @@ define void @slp_profitable_missing_fmf_nnans_only(ptr %A, ptr %B) {
216216
define float @slp_not_profitable_in_loop(float %x, ptr %A) {
217217
; CHECK-LABEL: @slp_not_profitable_in_loop(
218218
; CHECK-NEXT: entry:
219-
; CHECK-NEXT: [[GEP_A_2:%.*]] = getelementptr inbounds float, ptr [[A:%.*]], i64 2
220-
; CHECK-NEXT: [[L_1:%.*]] = load float, ptr [[GEP_A_2]], align 4
221-
; CHECK-NEXT: [[TMP0:%.*]] = load <2 x float>, ptr [[A]], align 4
219+
; CHECK-NEXT: [[TMP0:%.*]] = load <2 x float>, ptr [[A:%.*]], align 4
222220
; CHECK-NEXT: [[L_3:%.*]] = load float, ptr [[A]], align 4
223221
; CHECK-NEXT: [[TMP1:%.*]] = insertelement <2 x float> <float poison, float 3.000000e+00>, float [[X:%.*]], i32 0
224222
; CHECK-NEXT: br label [[LOOP:%.*]]
225223
; CHECK: loop:
226224
; CHECK-NEXT: [[IV:%.*]] = phi i64 [ 0, [[ENTRY:%.*]] ], [ [[IV_NEXT:%.*]], [[LOOP]] ]
227225
; CHECK-NEXT: [[RED:%.*]] = phi float [ 0.000000e+00, [[ENTRY]] ], [ [[RED_NEXT:%.*]], [[LOOP]] ]
228226
; CHECK-NEXT: [[TMP2:%.*]] = fmul fast <2 x float> [[TMP1]], [[TMP0]]
229-
; CHECK-NEXT: [[MUL12:%.*]] = fmul fast float 3.000000e+00, [[L_1]]
230227
; CHECK-NEXT: [[MUL16:%.*]] = fmul fast float 3.000000e+00, [[L_3]]
231-
; CHECK-NEXT: [[TMP3:%.*]] = extractelement <2 x float> [[TMP2]], i32 1
232-
; CHECK-NEXT: [[ADD:%.*]] = fadd fast float [[MUL12]], [[TMP3]]
233-
; CHECK-NEXT: [[TMP4:%.*]] = extractelement <2 x float> [[TMP2]], i32 0
234-
; CHECK-NEXT: [[ADD13:%.*]] = fadd fast float [[ADD]], [[TMP4]]
228+
; CHECK-NEXT: [[ADD13:%.*]] = call fast float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP2]])
235229
; CHECK-NEXT: [[RED_NEXT]] = fadd fast float [[ADD13]], [[MUL16]]
236230
; CHECK-NEXT: [[IV_NEXT]] = add nuw nsw i64 [[IV]], 1
237231
; CHECK-NEXT: [[CMP:%.*]] = icmp eq i64 [[IV]], 10

llvm/test/Transforms/SLPVectorizer/RISCV/revec.ll

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -141,33 +141,21 @@ define ptr @test4() {
141141
; POWEROF2-NEXT: [[TMP1:%.*]] = fadd <8 x float> zeroinitializer, zeroinitializer
142142
; POWEROF2-NEXT: [[TMP2:%.*]] = shufflevector <8 x float> [[TMP1]], <8 x float> poison, <2 x i32> <i32 1, i32 2>
143143
; POWEROF2-NEXT: [[TMP3:%.*]] = shufflevector <8 x float> [[TMP1]], <8 x float> poison, <2 x i32> <i32 5, i32 6>
144-
; POWEROF2-NEXT: [[TMP4:%.*]] = shufflevector <8 x float> [[TMP1]], <8 x float> poison, <2 x i32> <i32 0, i32 4>
145144
; POWEROF2-NEXT: [[TMP5:%.*]] = call <4 x float> @llvm.vector.insert.v4f32.v2f32(<4 x float> poison, <2 x float> [[TMP2]], i64 0)
146145
; POWEROF2-NEXT: [[TMP6:%.*]] = call <4 x float> @llvm.vector.insert.v4f32.v2f32(<4 x float> [[TMP5]], <2 x float> [[TMP3]], i64 2)
147146
; POWEROF2-NEXT: br label [[TMP8:%.*]]
148-
; POWEROF2: 7:
147+
; POWEROF2: 6:
149148
; POWEROF2-NEXT: br label [[TMP8]]
150-
; POWEROF2: 8:
151-
; POWEROF2-NEXT: [[TMP9:%.*]] = phi <2 x float> [ poison, [[TMP7:%.*]] ], [ [[TMP4]], [[TMP0:%.*]] ]
152-
; POWEROF2-NEXT: [[TMP10:%.*]] = phi <4 x float> [ poison, [[TMP7]] ], [ [[TMP6]], [[TMP0]] ]
149+
; POWEROF2: 7:
150+
; POWEROF2-NEXT: [[TMP10:%.*]] = phi <4 x float> [ poison, [[TMP7:%.*]] ], [ [[TMP6]], [[TMP0:%.*]] ]
153151
; POWEROF2-NEXT: br label [[TMP11:%.*]]
154-
; POWEROF2: 11:
152+
; POWEROF2: 9:
155153
; POWEROF2-NEXT: [[TMP12:%.*]] = call <2 x float> @llvm.vector.extract.v2f32.v4f32(<4 x float> [[TMP10]], i64 0)
156154
; POWEROF2-NEXT: [[TMP13:%.*]] = fmul <2 x float> [[TMP12]], zeroinitializer
157155
; POWEROF2-NEXT: [[TMP14:%.*]] = call <2 x float> @llvm.vector.extract.v2f32.v4f32(<4 x float> [[TMP10]], i64 2)
158156
; POWEROF2-NEXT: [[TMP15:%.*]] = fmul <2 x float> zeroinitializer, [[TMP14]]
159-
; POWEROF2-NEXT: [[TMP18:%.*]] = extractelement <2 x float> [[TMP9]], i32 0
160-
; POWEROF2-NEXT: [[TMP17:%.*]] = fmul float 0.000000e+00, [[TMP18]]
161-
; POWEROF2-NEXT: [[TMP30:%.*]] = extractelement <2 x float> [[TMP9]], i32 1
162-
; POWEROF2-NEXT: [[TMP19:%.*]] = fmul float [[TMP30]], 0.000000e+00
163-
; POWEROF2-NEXT: [[TMP20:%.*]] = extractelement <2 x float> [[TMP13]], i32 0
164-
; POWEROF2-NEXT: [[TMP21:%.*]] = fadd reassoc nsz float [[TMP20]], [[TMP17]]
165-
; POWEROF2-NEXT: [[TMP22:%.*]] = extractelement <2 x float> [[TMP15]], i32 0
166-
; POWEROF2-NEXT: [[TMP23:%.*]] = fadd reassoc nsz float [[TMP22]], [[TMP19]]
167-
; POWEROF2-NEXT: [[TMP24:%.*]] = extractelement <2 x float> [[TMP13]], i32 1
168-
; POWEROF2-NEXT: [[TMP25:%.*]] = fadd reassoc nsz float [[TMP21]], [[TMP24]]
169-
; POWEROF2-NEXT: [[TMP26:%.*]] = extractelement <2 x float> [[TMP15]], i32 1
170-
; POWEROF2-NEXT: [[TMP27:%.*]] = fadd reassoc nsz float [[TMP23]], [[TMP26]]
157+
; POWEROF2-NEXT: [[TMP25:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP13]])
158+
; POWEROF2-NEXT: [[TMP27:%.*]] = call reassoc nsz float @llvm.vector.reduce.fadd.v2f32(float 0.000000e+00, <2 x float> [[TMP15]])
171159
; POWEROF2-NEXT: [[TMP28:%.*]] = tail call float @llvm.sqrt.f32(float [[TMP25]])
172160
; POWEROF2-NEXT: [[TMP29:%.*]] = tail call float @llvm.sqrt.f32(float [[TMP27]])
173161
; POWEROF2-NEXT: ret ptr null

0 commit comments

Comments
 (0)