Skip to content

Commit 6ca83cc

Browse files
committed
[NVPTX] add combiner rule for v2[b]f16 = fp_round v2f32
Now that v2f32 is legal, this node will go straight to instruction selection. Instead, we want to break it up into two nodes, which can be handled better in instruction selection, since the final instruction (cvt.[b]f16x2.f32) takes two f32 arguments.
1 parent e016449 commit 6ca83cc

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
825825
// We have some custom DAG combine patterns for these nodes
826826
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
827827
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
828-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
828+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
829829

830830
// setcc for f16x2 and bf16x2 needs special handling to prevent
831831
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5582,6 +5582,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
55825582
return SDValue();
55835583
}
55845584

5585+
/// Target DAG combine for ISD::FP_ROUND.
///
/// Rewrites a packed two-lane conversion
///   v2[b]f16 = fp_round (v2f32 A)
/// into two scalar conversions that are glued back together:
///   v2[b]f16 = build_vector ([b]f16 = fp_round (extractelt A, 0)),
///                           ([b]f16 = fp_round (extractelt A, 1))
/// When A is itself a BUILD_VECTOR, its two operands are used directly
/// instead of emitting EXTRACT_VECTOR_ELT nodes.
///
/// \param N   The FP_ROUND node being combined.
/// \param DCI Combiner state; its DAG is used to create the new nodes.
/// \returns The replacement BUILD_VECTOR, or an empty SDValue when the
///          result/source types are not (v2f16 | v2bf16) / v2f32.
static SDValue PerformFP_ROUNDCombine(SDNode *N,
                                      TargetLowering::DAGCombinerInfo &DCI) {
  SDLoc Loc(N);
  SDValue Src = N->getOperand(0);
  // Second operand of the original fp_round; forwarded unchanged to each
  // scalar fp_round created below.
  SDValue RoundFlag = N->getOperand(1);
  EVT ResultVT = N->getValueType(0);
  EVT SrcVT = Src.getValueType();

  // Only the packed-half-from-v2f32 case is handled here.
  const bool IsPackedHalfResult =
      ResultVT == MVT::v2bf16 || ResultVT == MVT::v2f16;
  if (!IsPackedHalfResult || !(SrcVT == MVT::v2f32))
    return SDValue();

  auto &DAG = DCI.DAG;

  // Obtain the two f32 lanes of the source vector.
  SDValue Lane0;
  SDValue Lane1;
  if (Src.getOpcode() != ISD::BUILD_VECTOR) {
    Lane0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, Loc, MVT::f32, Src,
                        DAG.getIntPtrConstant(0, Loc));
    Lane1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, Loc, MVT::f32, Src,
                        DAG.getIntPtrConstant(1, Loc));
  } else {
    // The lanes are already available as the build_vector's operands.
    Lane0 = Src.getOperand(0);
    Lane1 = Src.getOperand(1);
  }

  // Round each lane on its own, then reassemble the packed result.
  SDValue Rounded0 = DAG.getNode(ISD::FP_ROUND, Loc, ResultVT.getScalarType(),
                                 Lane0, RoundFlag);
  SDValue Rounded1 = DAG.getNode(ISD::FP_ROUND, Loc, ResultVT.getScalarType(),
                                 Lane1, RoundFlag);
  return DAG.getBuildVector(ResultVT, Loc, {Rounded0, Rounded1});
}
5618+
55855619
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
55865620
DAGCombinerInfo &DCI) const {
55875621
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5618,6 +5652,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56185652
return PerformBUILD_VECTORCombine(N, DCI);
56195653
case ISD::ADDRSPACECAST:
56205654
return combineADDRSPACECAST(N, DCI);
5655+
case ISD::FP_ROUND:
5656+
return PerformFP_ROUNDCombine(N, DCI);
56215657
}
56225658
return SDValue();
56235659
}

0 commit comments

Comments
 (0)