Skip to content

Commit 3f02980

Browse files
committed
[NVPTX] add combiner rule for v2[b]f16 = fp_round v2f32
Now that v2f32 is legal, this node will go straight to instruction selection. Instead, we want to break it up into two nodes, which can be handled better in instruction selection, since the final instruction (cvt.[b]f16x2.f32) takes two f32 arguments.
1 parent 6396d62 commit 3f02980

File tree

1 file changed

+37
-1
lines changed

1 file changed

+37
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -830,7 +830,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
830830
// We have some custom DAG combine patterns for these nodes
831831
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
832832
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
833-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
833+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
834834

835835
// setcc for f16x2 and bf16x2 needs special handling to prevent
836836
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5695,6 +5695,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
56955695
return SDValue();
56965696
}
56975697

5698+
static SDValue PerformFP_ROUNDCombine(SDNode *N,
5699+
TargetLowering::DAGCombinerInfo &DCI) {
5700+
SDLoc DL(N);
5701+
SDValue Op = N->getOperand(0);
5702+
SDValue Trunc = N->getOperand(1);
5703+
EVT NarrowVT = N->getValueType(0);
5704+
EVT WideVT = Op.getValueType();
5705+
5706+
// v2[b]f16 = fp_round (v2f32 A)
5707+
// -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5708+
// ([b]f16 = fp_round (extractelt A, 1)))
5709+
if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5710+
WideVT == MVT::v2f32) {
5711+
SDValue F32Op0, F32Op1;
5712+
if (Op.getOpcode() == ISD::BUILD_VECTOR) {
5713+
F32Op0 = Op.getOperand(0);
5714+
F32Op1 = Op.getOperand(1);
5715+
} else {
5716+
F32Op0 = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
5717+
DCI.DAG.getIntPtrConstant(0, DL));
5718+
F32Op1 = DCI.DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32, Op,
5719+
DCI.DAG.getIntPtrConstant(1, DL));
5720+
}
5721+
return DCI.DAG.getBuildVector(
5722+
NarrowVT, DL,
5723+
{DCI.DAG.getNode(ISD::FP_ROUND, DL, NarrowVT.getScalarType(), F32Op0,
5724+
Trunc),
5725+
DCI.DAG.getNode(ISD::FP_ROUND, DL, NarrowVT.getScalarType(), F32Op1,
5726+
Trunc)});
5727+
}
5728+
5729+
return SDValue();
5730+
}
5731+
56985732
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56995733
DAGCombinerInfo &DCI) const {
57005734
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5731,6 +5765,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57315765
return PerformBUILD_VECTORCombine(N, DCI);
57325766
case ISD::ADDRSPACECAST:
57335767
return combineADDRSPACECAST(N, DCI);
5768+
case ISD::FP_ROUND:
5769+
return PerformFP_ROUNDCombine(N, DCI);
57345770
}
57355771
return SDValue();
57365772
}

0 commit comments

Comments
 (0)