Skip to content

Commit b35a635

Browse files
committed
[NVPTX] add combiner rule for v2[b]f16 = fp_round v2f32
Now that v2f32 is legal, this node will go straight to instruction selection. Instead, we want to break it up into two nodes, which can be handled better in instruction selection, since the final instruction (cvt.[b]f16x2.f32) takes two f32 arguments.
1 parent 0f3fdc2 commit b35a635

File tree

1 file changed

+43
-1
lines changed

1 file changed

+43
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -862,7 +862,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
862862
// We have some custom DAG combine patterns for these nodes
863863
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
864864
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
865-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
865+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
866866

867867
// setcc for f16x2 and bf16x2 needs special handling to prevent
868868
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5813,6 +5813,46 @@ static SDValue combineADDRSPACECAST(SDNode *N,
58135813
return SDValue();
58145814
}
58155815

5816+
// Combiner rule for v2[b]f16 = fp_round v2f32:
5817+
//
5818+
// Now that v2f32 is a legal type for a register, this node will go straight to
5819+
// instruction selection. Instead, we want to break it up into two nodes, which
5820+
// can be combined in instruction selection to cvt.[b]f16x2.f32, which requires
5821+
// two f32 registers.
5822+
static SDValue PerformFP_ROUNDCombine(SDNode *N,
5823+
TargetLowering::DAGCombinerInfo &DCI) {
5824+
SDLoc DL(N);
5825+
SDValue Op = N->getOperand(0);
5826+
SDValue Trunc = N->getOperand(1);
5827+
EVT NarrowVT = N->getValueType(0);
5828+
EVT WideVT = Op.getValueType();
5829+
5830+
// v2[b]f16 = fp_round (v2f32 A)
5831+
// -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5832+
// ([b]f16 = fp_round (extractelt A, 1)))
5833+
if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5834+
WideVT == MVT::v2f32) {
5835+
SDValue F32Op0, F32Op1;
5836+
if (Op.getOpcode() == ISD::BUILD_VECTOR) {
5837+
F32Op0 = Op.getOperand(0);
5838+
F32Op1 = Op.getOperand(1);
5839+
} else {
5840+
F32Op0 = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
5841+
DCI.DAG.getIntPtrConstant(0, DL));
5842+
F32Op1 = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
5843+
DCI.DAG.getIntPtrConstant(1, DL));
5844+
}
5845+
return DCI.DAG.getBuildVector(
5846+
NarrowVT, DL,
5847+
{DCI.DAG.getNode(ISD::FP_ROUND, DL, NarrowVT.getScalarType(), F32Op0,
5848+
Trunc),
5849+
DCI.DAG.getNode(ISD::FP_ROUND, DL, NarrowVT.getScalarType(), F32Op1,
5850+
Trunc)});
5851+
}
5852+
5853+
return SDValue();
5854+
}
5855+
58165856
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58175857
DAGCombinerInfo &DCI) const {
58185858
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5849,6 +5889,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58495889
return PerformBUILD_VECTORCombine(N, DCI);
58505890
case ISD::ADDRSPACECAST:
58515891
return combineADDRSPACECAST(N, DCI);
5892+
case ISD::FP_ROUND:
5893+
return PerformFP_ROUNDCombine(N, DCI);
58525894
}
58535895
return SDValue();
58545896
}

0 commit comments

Comments
 (0)