@@ -862,7 +862,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
862
862
// We have some custom DAG combine patterns for these nodes
863
863
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
864
864
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
865
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
865
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND });
866
866
867
867
// setcc for f16x2 and bf16x2 needs special handling to prevent
868
868
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5813,6 +5813,46 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5813
5813
return SDValue ();
5814
5814
}
5815
5815
5816
+ // Combiner rule for v2[b]f16 = fp_round v2f32:
5817
+ //
5818
+ // Now that v2f32 is a legal type for a register, this node will go straight to
5819
+ // instruction selection. Instead, we want to break it up into two nodes, which
5820
+ // can be combined in instruction selection to cvt.[b]f16x2.f32, which requires
5821
+ // two f32 registers.
5822
+ static SDValue PerformFP_ROUNDCombine (SDNode *N,
5823
+ TargetLowering::DAGCombinerInfo &DCI) {
5824
+ SDLoc DL (N);
5825
+ SDValue Op = N->getOperand (0 );
5826
+ SDValue Trunc = N->getOperand (1 );
5827
+ EVT NarrowVT = N->getValueType (0 );
5828
+ EVT WideVT = Op.getValueType ();
5829
+
5830
+ // v2[b]f16 = fp_round (v2f32 A)
5831
+ // -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5832
+ // ([b]f16 = fp_round (extractelt A, 1)))
5833
+ if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5834
+ WideVT == MVT::v2f32) {
5835
+ SDValue F32Op0, F32Op1;
5836
+ if (Op.getOpcode () == ISD::BUILD_VECTOR) {
5837
+ F32Op0 = Op.getOperand (0 );
5838
+ F32Op1 = Op.getOperand (1 );
5839
+ } else {
5840
+ F32Op0 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5841
+ DCI.DAG .getIntPtrConstant (0 , DL));
5842
+ F32Op1 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5843
+ DCI.DAG .getIntPtrConstant (1 , DL));
5844
+ }
5845
+ return DCI.DAG .getBuildVector (
5846
+ NarrowVT, DL,
5847
+ {DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op0,
5848
+ Trunc),
5849
+ DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op1,
5850
+ Trunc)});
5851
+ }
5852
+
5853
+ return SDValue ();
5854
+ }
5855
+
5816
5856
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5817
5857
DAGCombinerInfo &DCI) const {
5818
5858
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5849,6 +5889,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5849
5889
return PerformBUILD_VECTORCombine (N, DCI);
5850
5890
case ISD::ADDRSPACECAST:
5851
5891
return combineADDRSPACECAST (N, DCI);
5892
+ case ISD::FP_ROUND:
5893
+ return PerformFP_ROUNDCombine (N, DCI);
5852
5894
}
5853
5895
return SDValue ();
5854
5896
}
0 commit comments