Skip to content

Commit 1a2fa58

Browse files
committed
[NVPTX] add combiner rule for v2[b]f16 = fp_round v2f32
Now that v2f32 is legal, this node will go straight to instruction selection. Instead, we want to break it up into two nodes, which can be handled better in instruction selection, since the final instruction (cvt.[b]f16x2.f32) takes two f32 arguments.
1 parent 04b51f8 commit 1a2fa58

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -826,7 +826,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
826826
// We have some custom DAG combine patterns for these nodes
827827
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
828828
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
829-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
829+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
830830

831831
// setcc for f16x2 and bf16x2 needs special handling to prevent
832832
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5713,6 +5713,46 @@ static SDValue combineADDRSPACECAST(SDNode *N,
57135713
return SDValue();
57145714
}
57155715

5716+
// Combiner rule for v2[b]f16 = fp_round v2f32:
5717+
//
5718+
// Now that v2f32 is a legal type for a register, this node will go straight to
5719+
// instruction selection. Instead, we want to break it up into two nodes, which
5720+
// can be combined in instruction selection to cvt.[b]f16x2.f32, which requires
5721+
// two f32 registers.
5722+
static SDValue PerformFP_ROUNDCombine(SDNode *N,
5723+
TargetLowering::DAGCombinerInfo &DCI) {
5724+
SDLoc DL(N);
5725+
SDValue Op = N->getOperand(0);
5726+
SDValue Trunc = N->getOperand(1);
5727+
EVT NarrowVT = N->getValueType(0);
5728+
EVT WideVT = Op.getValueType();
5729+
5730+
// v2[b]f16 = fp_round (v2f32 A)
5731+
// -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5732+
// ([b]f16 = fp_round (extractelt A, 1)))
5733+
if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5734+
WideVT == MVT::v2f32) {
5735+
SDValue F32Op0, F32Op1;
5736+
if (Op.getOpcode() == ISD::BUILD_VECTOR) {
5737+
F32Op0 = Op.getOperand(0);
5738+
F32Op1 = Op.getOperand(1);
5739+
} else {
5740+
F32Op0 = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
5741+
DCI.DAG.getIntPtrConstant(0, DL));
5742+
F32Op1 = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
5743+
DCI.DAG.getIntPtrConstant(1, DL));
5744+
}
5745+
return DCI.DAG.getBuildVector(
5746+
NarrowVT, DL,
5747+
{DCI.DAG.getNode(ISD::FP_ROUND, DL, NarrowVT.getScalarType(), F32Op0,
5748+
Trunc),
5749+
DCI.DAG.getNode(ISD::FP_ROUND, DL, NarrowVT.getScalarType(), F32Op1,
5750+
Trunc)});
5751+
}
5752+
5753+
return SDValue();
5754+
}
5755+
57165756
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57175757
DAGCombinerInfo &DCI) const {
57185758
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5749,6 +5789,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57495789
return PerformBUILD_VECTORCombine(N, DCI);
57505790
case ISD::ADDRSPACECAST:
57515791
return combineADDRSPACECAST(N, DCI);
5792+
case ISD::FP_ROUND:
5793+
return PerformFP_ROUNDCombine(N, DCI);
57525794
}
57535795
return SDValue();
57545796
}

0 commit comments

Comments
 (0)