@@ -7012,39 +7012,52 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
7012
7012
7013
7013
// Find the cost of vectorizing the call, if we can find a suitable
7014
7014
// vector variant of the function.
7015
- InstructionCost MaskCost = 0 ;
7016
- VFShape Shape = VFShape::get (*CI, VF, MaskRequired);
7017
- bool UsesMask = MaskRequired;
7018
- Function *VecFunc = VFDatabase (*CI).getVectorizedFunction (Shape);
7019
- // If we want an unmasked vector function but can't find one matching the
7020
- // VF, maybe we can find vector function that does use a mask and
7021
- // synthesize an all-true mask.
7022
- if (!VecFunc && !MaskRequired) {
7023
- Shape = VFShape::get (*CI, VF, /* HasGlobalPred=*/ true );
7024
- VecFunc = VFDatabase (*CI).getVectorizedFunction (Shape);
7025
- // If we found one, add in the cost of creating a mask
7026
- if (VecFunc) {
7027
- UsesMask = true ;
7028
- MaskCost = TTI.getShuffleCost (
7029
- TargetTransformInfo::SK_Broadcast,
7030
- VectorType::get (IntegerType::getInt1Ty (
7031
- VecFunc->getFunctionType ()->getContext ()),
7032
- VF));
7033
- }
7034
- }
7015
+ bool UsesMask = false ;
7016
+ VFInfo FuncInfo;
7017
+ Function *VecFunc = nullptr ;
7018
+ // Search through any available variants for one we can use at this VF.
7019
+ for (VFInfo &Info : VFDatabase::getMappings (*CI)) {
7020
+ // Must match requested VF.
7021
+ if (Info.Shape .VF != VF)
7022
+ continue ;
7035
7023
7036
- std::optional<unsigned > MaskPos = std::nullopt;
7037
- if (VecFunc && UsesMask) {
7038
- for (const VFInfo &Info : VFDatabase::getMappings (*CI))
7039
- if (Info.Shape == Shape) {
7040
- assert (Info.isMasked () && " Vector function info shape mismatch" );
7041
- MaskPos = Info.getParamIndexForOptionalMask ().value ();
7024
+ // Must take a mask argument if one is required.
7025
+ if (MaskRequired && !Info.isMasked ())
7026
+ continue ;
7027
+
7028
+ // Check that all parameter kinds are supported.
7029
+ bool ParamsOk = true ;
7030
+ for (VFParameter Param : Info.Shape .Parameters ) {
7031
+ switch (Param.ParamKind ) {
7032
+ case VFParamKind::Vector:
7033
+ break ;
7034
+ case VFParamKind::GlobalPredicate:
7035
+ UsesMask = true ;
7036
+ break ;
7037
+ default :
7038
+ ParamsOk = false ;
7042
7039
break ;
7043
7040
}
7041
+ }
7042
+
7043
+ if (!ParamsOk)
7044
+ continue ;
7044
7045
7045
- assert (MaskPos.has_value () && " Unable to find mask parameter index" );
7046
+ // Found a suitable candidate, stop here.
7047
+ VecFunc = CI->getModule ()->getFunction (Info.VectorName );
7048
+ FuncInfo = Info;
7049
+ break ;
7046
7050
}
7047
7051
7052
+ // Add in the cost of synthesizing a mask if one wasn't required.
7053
+ InstructionCost MaskCost = 0 ;
7054
+ if (VecFunc && UsesMask && !MaskRequired)
7055
+ MaskCost = TTI.getShuffleCost (
7056
+ TargetTransformInfo::SK_Broadcast,
7057
+ VectorType::get (IntegerType::getInt1Ty (
7058
+ VecFunc->getFunctionType ()->getContext ()),
7059
+ VF));
7060
+
7048
7061
if (TLI && VecFunc && !CI->isNoBuiltin ())
7049
7062
VectorCost =
7050
7063
TTI.getCallInstrCost (nullptr , RetTy, Tys, CostKind) + MaskCost;
@@ -7068,7 +7081,8 @@ void LoopVectorizationCostModel::setVectorizedCallDecision(ElementCount VF) {
7068
7081
Decision = CM_IntrinsicCall;
7069
7082
}
7070
7083
7071
- setCallWideningDecision (CI, VF, Decision, VecFunc, IID, MaskPos, Cost);
7084
+ setCallWideningDecision (CI, VF, Decision, VecFunc, IID,
7085
+ FuncInfo.getParamIndexForOptionalMask (), Cost);
7072
7086
}
7073
7087
}
7074
7088
}
0 commit comments