Skip to content

Commit 98e5962

Browse files
authored
[RISCV][CostModel] Add cost for fabs/fsqrt of type bf16/f16 (llvm#118608)
1 parent 46ca6df commit 98e5962

File tree

3 files changed

+243
-163
lines changed

3 files changed

+243
-163
lines changed

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "llvm/CodeGen/BasicTTIImpl.h"
1414
#include "llvm/CodeGen/CostTable.h"
1515
#include "llvm/CodeGen/TargetLowering.h"
16+
#include "llvm/CodeGen/ValueTypes.h"
1617
#include "llvm/IR/Instructions.h"
1718
#include "llvm/IR/PatternMatch.h"
1819
#include <cmath>
@@ -1035,21 +1036,66 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
10351036
}
10361037
break;
10371038
}
1038-
case Intrinsic::fabs:
1039+
case Intrinsic::fabs: {
1040+
auto LT = getTypeLegalizationCost(RetTy);
1041+
if (ST->hasVInstructions() && LT.second.isVector()) {
1042+
// lui a0, 8
1043+
// addi a0, a0, -1
1044+
// vsetvli a1, zero, e16, m1, ta, ma
1045+
// vand.vx v8, v8, a0
1046+
// f16 with zvfhmin and bf16 with zvfhbmin
1047+
if (LT.second.getVectorElementType() == MVT::bf16 ||
1048+
(LT.second.getVectorElementType() == MVT::f16 &&
1049+
!ST->hasVInstructionsF16()))
1050+
return LT.first * getRISCVInstructionCost(RISCV::VAND_VX, LT.second,
1051+
CostKind) +
1052+
2;
1053+
else
1054+
return LT.first *
1055+
getRISCVInstructionCost(RISCV::VFSGNJX_VV, LT.second, CostKind);
1056+
}
1057+
break;
1058+
}
10391059
case Intrinsic::sqrt: {
10401060
auto LT = getTypeLegalizationCost(RetTy);
1041-
// TODO: add f16/bf16, bf16 with zvfbfmin && f16 with zvfhmin
10421061
if (ST->hasVInstructions() && LT.second.isVector()) {
1043-
unsigned Op;
1044-
switch (ICA.getID()) {
1045-
case Intrinsic::fabs:
1046-
Op = RISCV::VFSGNJX_VV;
1047-
break;
1048-
case Intrinsic::sqrt:
1049-
Op = RISCV::VFSQRT_V;
1050-
break;
1062+
SmallVector<unsigned, 4> ConvOp;
1063+
SmallVector<unsigned, 2> FsqrtOp;
1064+
MVT ConvType = LT.second;
1065+
MVT FsqrtType = LT.second;
1066+
// f16 with zvfhmin and bf16 with zvfbfmin and the type of nxv32[b]f16
1067+
// will be spilt.
1068+
if (LT.second.getVectorElementType() == MVT::bf16) {
1069+
if (LT.second == MVT::nxv32bf16) {
1070+
ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFWCVTBF16_F_F_V,
1071+
RISCV::VFNCVTBF16_F_F_W, RISCV::VFNCVTBF16_F_F_W};
1072+
FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V};
1073+
ConvType = MVT::nxv16f16;
1074+
FsqrtType = MVT::nxv16f32;
1075+
} else {
1076+
ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFNCVTBF16_F_F_W};
1077+
FsqrtOp = {RISCV::VFSQRT_V};
1078+
FsqrtType = TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType);
1079+
}
1080+
} else if (LT.second.getVectorElementType() == MVT::f16 &&
1081+
!ST->hasVInstructionsF16()) {
1082+
if (LT.second == MVT::nxv32f16) {
1083+
ConvOp = {RISCV::VFWCVT_F_F_V, RISCV::VFWCVT_F_F_V,
1084+
RISCV::VFNCVT_F_F_W, RISCV::VFNCVT_F_F_W};
1085+
FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V};
1086+
ConvType = MVT::nxv16f16;
1087+
FsqrtType = MVT::nxv16f32;
1088+
} else {
1089+
ConvOp = {RISCV::VFWCVT_F_F_V, RISCV::VFNCVT_F_F_W};
1090+
FsqrtOp = {RISCV::VFSQRT_V};
1091+
FsqrtType = TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType);
1092+
}
1093+
} else {
1094+
FsqrtOp = {RISCV::VFSQRT_V};
10511095
}
1052-
return LT.first * getRISCVInstructionCost(Op, LT.second, CostKind);
1096+
1097+
return LT.first * (getRISCVInstructionCost(FsqrtOp, FsqrtType, CostKind) +
1098+
getRISCVInstructionCost(ConvOp, ConvType, CostKind));
10531099
}
10541100
break;
10551101
}

0 commit comments

Comments
 (0)