@@ -826,7 +826,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
826
826
// We have some custom DAG combine patterns for these nodes
827
827
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
828
828
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
829
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
829
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND });
830
830
831
831
// setcc for f16x2 and bf16x2 needs special handling to prevent
832
832
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5713,6 +5713,46 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5713
5713
return SDValue ();
5714
5714
}
5715
5715
5716
+ // Combiner rule for v2[b]f16 = fp_round v2f32:
5717
+ //
5718
+ // Now that v2f32 is a legal type for a register, this node will go straight to
5719
+ // instruction selection. Instead, we want to break it up into two nodes, which
5720
+ // can be combined in instruction selection to cvt.[b]f16x2.f32, which requires
5721
+ // two f32 registers.
5722
+ static SDValue PerformFP_ROUNDCombine (SDNode *N,
5723
+ TargetLowering::DAGCombinerInfo &DCI) {
5724
+ SDLoc DL (N);
5725
+ SDValue Op = N->getOperand (0 );
5726
+ SDValue Trunc = N->getOperand (1 );
5727
+ EVT NarrowVT = N->getValueType (0 );
5728
+ EVT WideVT = Op.getValueType ();
5729
+
5730
+ // v2[b]f16 = fp_round (v2f32 A)
5731
+ // -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5732
+ // ([b]f16 = fp_round (extractelt A, 1)))
5733
+ if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5734
+ WideVT == MVT::v2f32) {
5735
+ SDValue F32Op0, F32Op1;
5736
+ if (Op.getOpcode () == ISD::BUILD_VECTOR) {
5737
+ F32Op0 = Op.getOperand (0 );
5738
+ F32Op1 = Op.getOperand (1 );
5739
+ } else {
5740
+ F32Op0 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5741
+ DCI.DAG .getIntPtrConstant (0 , DL));
5742
+ F32Op1 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5743
+ DCI.DAG .getIntPtrConstant (1 , DL));
5744
+ }
5745
+ return DCI.DAG .getBuildVector (
5746
+ NarrowVT, DL,
5747
+ {DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op0,
5748
+ Trunc),
5749
+ DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op1,
5750
+ Trunc)});
5751
+ }
5752
+
5753
+ return SDValue ();
5754
+ }
5755
+
5716
5756
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5717
5757
DAGCombinerInfo &DCI) const {
5718
5758
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5749,6 +5789,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5749
5789
return PerformBUILD_VECTORCombine (N, DCI);
5750
5790
case ISD::ADDRSPACECAST:
5751
5791
return combineADDRSPACECAST (N, DCI);
5792
+ case ISD::FP_ROUND:
5793
+ return PerformFP_ROUNDCombine (N, DCI);
5752
5794
}
5753
5795
return SDValue ();
5754
5796
}
0 commit comments