Skip to content

Commit 4ed8f9f

Browse files
committed
[NVPTX] add combiner rule for v2[b]f16 = fp_round v2f32
Now that v2f32 is legal, a `v2[b]f16 = fp_round v2f32` node is no longer split up during legalization and goes straight to instruction selection. Instead, we want to break it up into two scalar `[b]f16 = fp_round f32` nodes joined by a build_vector, which can be handled better in instruction selection, since the final instruction (cvt.[b]f16x2.f32) takes two f32 arguments.
1 parent baef6cd commit 4ed8f9f

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
830830
// We have some custom DAG combine patterns for these nodes
831831
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
832832
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
833-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
833+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
834834

835835
// setcc for f16x2 and bf16x2 needs special handling to prevent
836836
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5693,6 +5693,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
56935693
return SDValue();
56945694
}
56955695

5696+
/// Split a two-element f32 -> half-precision rounding into per-element
/// roundings recombined with a build_vector:
///
///   v2[b]f16 = fp_round (v2f32 A)
///     -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
///                                 ([b]f16 = fp_round (extractelt A, 1)))
///
/// \param N   The ISD::FP_ROUND node being combined.
/// \param DCI Combiner state; its DAG is used to create the replacement nodes.
/// \returns The replacement build_vector, or an empty SDValue when the node is
///          not a v2f32 -> v2f16 / v2bf16 rounding.
static SDValue PerformFP_ROUNDCombine(SDNode *N,
                                      TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Op = N->getOperand(0);
  EVT NarrowVT = N->getValueType(0);

  // Only the packed v2f32 -> v2f16 / v2bf16 case is rewritten; bail out
  // before creating any helper values for everything else.
  const bool IsPackedHalfResult =
      NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16;
  if (!IsPackedHalfResult || Op.getValueType() != MVT::v2f32)
    return SDValue();

  SelectionDAG &DAG = DCI.DAG;
  SDLoc DL(N);
  // The second operand of the original node is forwarded unchanged to both
  // scalar FP_ROUND nodes.
  SDValue Trunc = N->getOperand(1);
  EVT NarrowEltVT = NarrowVT.getScalarType();

  SDValue F32Op0, F32Op1;
  if (Op.getOpcode() == ISD::BUILD_VECTOR) {
    // The source is already assembled from scalars: reuse them directly
    // instead of extracting them back out of the vector.
    F32Op0 = Op.getOperand(0);
    F32Op1 = Op.getOperand(1);
  } else {
    // Use the dedicated vector-index constant helper so the index operand
    // has the type the target expects for EXTRACT_VECTOR_ELT.
    F32Op0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
                         DAG.getVectorIdxConstant(0, DL));
    F32Op1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
                         DAG.getVectorIdxConstant(1, DL));
  }

  return DAG.getBuildVector(
      NarrowVT, DL,
      {DAG.getNode(ISD::FP_ROUND, DL, NarrowEltVT, F32Op0, Trunc),
       DAG.getNode(ISD::FP_ROUND, DL, NarrowEltVT, F32Op1, Trunc)});
}
5729+
56965730
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56975731
DAGCombinerInfo &DCI) const {
56985732
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5729,6 +5763,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57295763
return PerformBUILD_VECTORCombine(N, DCI);
57305764
case ISD::ADDRSPACECAST:
57315765
return combineADDRSPACECAST(N, DCI);
5766+
case ISD::FP_ROUND:
5767+
return PerformFP_ROUNDCombine(N, DCI);
57325768
}
57335769
return SDValue();
57345770
}

0 commit comments

Comments
 (0)