Skip to content

Commit 4d64a2b

Browse files
authored
[LV] Refactor vector function variant selection to prepare for uniform args (#68879)
Parameters marked as uniform take a scalar value, on the assumption that the value is invariant in the scalar loop. To support this, we need to stop requesting a vector function variant with a default shape (one that assumes every argument will become a vector argument) and instead consider all available variants and their parameter types.
1 parent 1716c5b commit 4d64a2b

File tree

2 files changed

+52
-32
lines changed

2 files changed

+52
-32
lines changed

llvm/lib/Transforms/Vectorize/LoopVectorize.cpp

Lines changed: 42 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -7012,39 +7012,52 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
70127012

70137013
// Find the cost of vectorizing the call, if we can find a suitable
70147014
// vector variant of the function.
7015-
InstructionCost MaskCost = 0;
7016-
VFShape Shape = VFShape::get(*CI, VF, MaskRequired);
7017-
bool UsesMask = MaskRequired;
7018-
Function *VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
7019-
// If we want an unmasked vector function but can't find one matching the
7020-
// VF, maybe we can find vector function that does use a mask and
7021-
// synthesize an all-true mask.
7022-
if (!VecFunc && !MaskRequired) {
7023-
Shape = VFShape::get(*CI, VF, /*HasGlobalPred=*/true);
7024-
VecFunc = VFDatabase(*CI).getVectorizedFunction(Shape);
7025-
// If we found one, add in the cost of creating a mask
7026-
if (VecFunc) {
7027-
UsesMask = true;
7028-
MaskCost = TTI.getShuffleCost(
7029-
TargetTransformInfo::SK_Broadcast,
7030-
VectorType::get(IntegerType::getInt1Ty(
7031-
VecFunc->getFunctionType()->getContext()),
7032-
VF));
7033-
}
7034-
}
7015+
bool UsesMask = false;
7016+
VFInfo FuncInfo;
7017+
Function *VecFunc = nullptr;
7018+
// Search through any available variants for one we can use at this VF.
7019+
for (VFInfo &Info : VFDatabase::getMappings(*CI)) {
7020+
// Must match requested VF.
7021+
if (Info.Shape.VF != VF)
7022+
continue;
70357023

7036-
std::optional<unsigned> MaskPos = std::nullopt;
7037-
if (VecFunc && UsesMask) {
7038-
for (const VFInfo &Info : VFDatabase::getMappings(*CI))
7039-
if (Info.Shape == Shape) {
7040-
assert(Info.isMasked() && "Vector function info shape mismatch");
7041-
MaskPos = Info.getParamIndexForOptionalMask().value();
7024+
// Must take a mask argument if one is required
7025+
if (MaskRequired && !Info.isMasked())
7026+
continue;
7027+
7028+
// Check that all parameter kinds are supported
7029+
bool ParamsOk = true;
7030+
for (VFParameter Param : Info.Shape.Parameters) {
7031+
switch (Param.ParamKind) {
7032+
case VFParamKind::Vector:
7033+
break;
7034+
case VFParamKind::GlobalPredicate:
7035+
UsesMask = true;
7036+
break;
7037+
default:
7038+
ParamsOk = false;
70427039
break;
70437040
}
7041+
}
7042+
7043+
if (!ParamsOk)
7044+
continue;
70447045

7045-
assert(MaskPos.has_value() && "Unable to find mask parameter index");
7046+
// Found a suitable candidate, stop here.
7047+
VecFunc = CI->getModule()->getFunction(Info.VectorName);
7048+
FuncInfo = Info;
7049+
break;
70467050
}
70477051

7052+
// Add in the cost of synthesizing a mask if one wasn't required.
7053+
InstructionCost MaskCost = 0;
7054+
if (VecFunc && UsesMask && !MaskRequired)
7055+
MaskCost = TTI.getShuffleCost(
7056+
TargetTransformInfo::SK_Broadcast,
7057+
VectorType::get(IntegerType::getInt1Ty(
7058+
VecFunc->getFunctionType()->getContext()),
7059+
VF));
7060+
70487061
if (TLI && VecFunc && !CI->isNoBuiltin())
70497062
VectorCost =
70507063
TTI.getCallInstrCost(nullptr, RetTy, Tys, CostKind) + MaskCost;
@@ -7068,7 +7081,8 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
70687081
Decision = CM_IntrinsicCall;
70697082
}
70707083

7071-
setCallWideningDecision(CI, VF, Decision, VecFunc, IID, MaskPos, Cost);
7084+
setCallWideningDecision(CI, VF, Decision, VecFunc, IID,
7085+
FuncInfo.getParamIndexForOptionalMask(), Cost);
70727086
}
70737087
}
70747088
}

llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -504,6 +504,9 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
504504
"DbgInfoIntrinsic should have been dropped during VPlan construction");
505505
State.setDebugLocFrom(CI.getDebugLoc());
506506

507+
FunctionType *VFTy = nullptr;
508+
if (Variant)
509+
VFTy = Variant->getFunctionType();
507510
for (unsigned Part = 0; Part < State.UF; ++Part) {
508511
SmallVector<Type *, 2> TysForDecl;
509512
// Add return type if intrinsic is overloaded on it.
@@ -515,12 +518,15 @@ void VPWidenCallRecipe::execute(VPTransformState &State) {
515518
for (const auto &I : enumerate(operands())) {
516519
// Some intrinsics have a scalar argument - don't replace it with a
517520
// vector.
521+
// Some vectorized function variants may also take a scalar argument,
522+
// e.g. linear parameters for pointers.
518523
Value *Arg;
519-
if (VectorIntrinsicID == Intrinsic::not_intrinsic ||
520-
!isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index()))
521-
Arg = State.get(I.value(), Part);
522-
else
524+
if ((VFTy && !VFTy->getParamType(I.index())->isVectorTy()) ||
525+
(VectorIntrinsicID != Intrinsic::not_intrinsic &&
526+
isVectorIntrinsicWithScalarOpAtArg(VectorIntrinsicID, I.index())))
523527
Arg = State.get(I.value(), VPIteration(0, 0));
528+
else
529+
Arg = State.get(I.value(), Part);
524530
if (isVectorIntrinsicWithOverloadTypeAtArg(VectorIntrinsicID, I.index()))
525531
TysForDecl.push_back(Arg->getType());
526532
Args.push_back(Arg);

0 commit comments

Comments
 (0)