Skip to content

Commit 9c97b38

Browse files
authored
[ISel/RISCV] Custom-promote [b]f16 in [l]lrint (#146507)
Extend lowerVectorXRINT to also do a FP_EXTEND_VL when the source element type is [b]f16, and wire up this custom-promote. Updating the cost-model to not give these an invalid cost is left to a companion patch.
1 parent edaf656 commit 9c97b38

File tree

5 files changed

+1165
-134
lines changed

5 files changed

+1165
-134
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1150,6 +1150,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
11501150
setOperationAction({ISD::STRICT_FP_ROUND, ISD::STRICT_FP_EXTEND}, VT,
11511151
Custom);
11521152
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
1153+
setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
11531154
setOperationAction({ISD::VP_MERGE, ISD::VP_SELECT, ISD::SELECT}, VT,
11541155
Custom);
11551156
setOperationAction(ISD::SELECT_CC, VT, Expand);
@@ -1451,6 +1452,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
14511452
Custom);
14521453
setOperationAction({ISD::VP_SINT_TO_FP, ISD::VP_UINT_TO_FP}, VT,
14531454
Custom);
1455+
setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
14541456
if (Subtarget.hasStdExtZfhmin()) {
14551457
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
14561458
} else {
@@ -1475,6 +1477,7 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
14751477
if (VT.getVectorElementType() == MVT::bf16) {
14761478
setOperationAction(ISD::BITCAST, VT, Custom);
14771479
setOperationAction({ISD::VP_FP_ROUND, ISD::VP_FP_EXTEND}, VT, Custom);
1480+
setOperationAction({ISD::LRINT, ISD::LLRINT}, VT, Custom);
14781481
if (Subtarget.hasStdExtZfbfmin()) {
14791482
setOperationAction(ISD::BUILD_VECTOR, VT, Custom);
14801483
} else {
@@ -3487,6 +3490,14 @@ static SDValue lowerVectorXRINT(SDValue Op, SelectionDAG &DAG,
34873490
}
34883491

34893492
auto [Mask, VL] = getDefaultVLOps(SrcVT, SrcContainerVT, DL, DAG, Subtarget);
3493+
3494+
// [b]f16 -> f32
3495+
MVT SrcElemType = SrcVT.getVectorElementType();
3496+
if (SrcElemType == MVT::f16 || SrcElemType == MVT::bf16) {
3497+
MVT F32VT = SrcContainerVT.changeVectorElementType(MVT::f32);
3498+
Src = DAG.getNode(RISCVISD::FP_EXTEND_VL, DL, F32VT, Src, Mask, VL);
3499+
}
3500+
34903501
SDValue Res =
34913502
DAG.getNode(RISCVISD::VFCVT_RM_X_F_VL, DL, DstContainerVT, Src, Mask,
34923503
DAG.getTargetConstant(matchRoundingOp(Op.getOpcode()), DL,

0 commit comments

Comments
 (0)