|
13 | 13 | #include "llvm/CodeGen/BasicTTIImpl.h"
|
14 | 14 | #include "llvm/CodeGen/CostTable.h"
|
15 | 15 | #include "llvm/CodeGen/TargetLowering.h"
|
| 16 | +#include "llvm/CodeGen/ValueTypes.h" |
16 | 17 | #include "llvm/IR/Instructions.h"
|
17 | 18 | #include "llvm/IR/PatternMatch.h"
|
18 | 19 | #include <cmath>
|
@@ -1035,21 +1036,66 @@ RISCVTTIImpl::getIntrinsicInstrCost(const IntrinsicCostAttributes &ICA,
|
1035 | 1036 | }
|
1036 | 1037 | break;
|
1037 | 1038 | }
|
1038 |
| - case Intrinsic::fabs: |
| 1039 | + case Intrinsic::fabs: { |
| 1040 | + auto LT = getTypeLegalizationCost(RetTy); |
| 1041 | + if (ST->hasVInstructions() && LT.second.isVector()) { |
| 1042 | + // lui a0, 8 |
| 1043 | + // addi a0, a0, -1 |
| 1044 | + // vsetvli a1, zero, e16, m1, ta, ma |
| 1045 | + // vand.vx v8, v8, a0 |
| 1046 | + // f16 with zvfhmin and bf16 with zvfbfmin |
| 1047 | + if (LT.second.getVectorElementType() == MVT::bf16 || |
| 1048 | + (LT.second.getVectorElementType() == MVT::f16 && |
| 1049 | + !ST->hasVInstructionsF16())) |
| 1050 | + return LT.first * getRISCVInstructionCost(RISCV::VAND_VX, LT.second, |
| 1051 | + CostKind) + |
| 1052 | + 2; |
| 1053 | + else |
| 1054 | + return LT.first * |
| 1055 | + getRISCVInstructionCost(RISCV::VFSGNJX_VV, LT.second, CostKind); |
| 1056 | + } |
| 1057 | + break; |
| 1058 | + } |
1039 | 1059 | case Intrinsic::sqrt: {
|
1040 | 1060 | auto LT = getTypeLegalizationCost(RetTy);
|
1041 |
| - // TODO: add f16/bf16, bf16 with zvfbfmin && f16 with zvfhmin |
1042 | 1061 | if (ST->hasVInstructions() && LT.second.isVector()) {
|
1043 |
| - unsigned Op; |
1044 |
| - switch (ICA.getID()) { |
1045 |
| - case Intrinsic::fabs: |
1046 |
| - Op = RISCV::VFSGNJX_VV; |
1047 |
| - break; |
1048 |
| - case Intrinsic::sqrt: |
1049 |
| - Op = RISCV::VFSQRT_V; |
1050 |
| - break; |
| 1062 | + SmallVector<unsigned, 4> ConvOp; |
| 1063 | + SmallVector<unsigned, 2> FsqrtOp; |
| 1064 | + MVT ConvType = LT.second; |
| 1065 | + MVT FsqrtType = LT.second; |
| 1066 | + // f16 with zvfhmin and bf16 with zvfbfmin and the type of nxv32[b]f16 |
| 1067 | + // will be split. |
| 1068 | + if (LT.second.getVectorElementType() == MVT::bf16) { |
| 1069 | + if (LT.second == MVT::nxv32bf16) { |
| 1070 | + ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFWCVTBF16_F_F_V, |
| 1071 | + RISCV::VFNCVTBF16_F_F_W, RISCV::VFNCVTBF16_F_F_W}; |
| 1072 | + FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V}; |
| 1073 | + ConvType = MVT::nxv16f16; |
| 1074 | + FsqrtType = MVT::nxv16f32; |
| 1075 | + } else { |
| 1076 | + ConvOp = {RISCV::VFWCVTBF16_F_F_V, RISCV::VFNCVTBF16_F_F_W}; |
| 1077 | + FsqrtOp = {RISCV::VFSQRT_V}; |
| 1078 | + FsqrtType = TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType); |
| 1079 | + } |
| 1080 | + } else if (LT.second.getVectorElementType() == MVT::f16 && |
| 1081 | + !ST->hasVInstructionsF16()) { |
| 1082 | + if (LT.second == MVT::nxv32f16) { |
| 1083 | + ConvOp = {RISCV::VFWCVT_F_F_V, RISCV::VFWCVT_F_F_V, |
| 1084 | + RISCV::VFNCVT_F_F_W, RISCV::VFNCVT_F_F_W}; |
| 1085 | + FsqrtOp = {RISCV::VFSQRT_V, RISCV::VFSQRT_V}; |
| 1086 | + ConvType = MVT::nxv16f16; |
| 1087 | + FsqrtType = MVT::nxv16f32; |
| 1088 | + } else { |
| 1089 | + ConvOp = {RISCV::VFWCVT_F_F_V, RISCV::VFNCVT_F_F_W}; |
| 1090 | + FsqrtOp = {RISCV::VFSQRT_V}; |
| 1091 | + FsqrtType = TLI->getTypeToPromoteTo(ISD::FSQRT, FsqrtType); |
| 1092 | + } |
| 1093 | + } else { |
| 1094 | + FsqrtOp = {RISCV::VFSQRT_V}; |
1051 | 1095 | }
|
1052 |
| - return LT.first * getRISCVInstructionCost(Op, LT.second, CostKind); |
| 1096 | + |
| 1097 | + return LT.first * (getRISCVInstructionCost(FsqrtOp, FsqrtType, CostKind) + |
| 1098 | + getRISCVInstructionCost(ConvOp, ConvType, CostKind)); |
1053 | 1099 | }
|
1054 | 1100 | break;
|
1055 | 1101 | }
|
|
0 commit comments