@@ -830,7 +830,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
830
830
// We have some custom DAG combine patterns for these nodes
831
831
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
832
832
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
833
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
833
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND });
834
834
835
835
// setcc for f16x2 and bf16x2 needs special handling to prevent
836
836
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5693,6 +5693,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5693
5693
return SDValue ();
5694
5694
}
5695
5695
5696
+ static SDValue PerformFP_ROUNDCombine (SDNode *N,
5697
+ TargetLowering::DAGCombinerInfo &DCI) {
5698
+ SDLoc DL (N);
5699
+ SDValue Op = N->getOperand (0 );
5700
+ SDValue Trunc = N->getOperand (1 );
5701
+ EVT NarrowVT = N->getValueType (0 );
5702
+ EVT WideVT = Op.getValueType ();
5703
+
5704
+ // v2[b]f16 = fp_round (v2f32 A)
5705
+ // -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5706
+ // ([b]f16 = fp_round (extractelt A, 1)))
5707
+ if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5708
+ WideVT == MVT::v2f32) {
5709
+ SDValue F32Op0, F32Op1;
5710
+ if (Op.getOpcode () == ISD::BUILD_VECTOR) {
5711
+ F32Op0 = Op.getOperand (0 );
5712
+ F32Op1 = Op.getOperand (1 );
5713
+ } else {
5714
+ F32Op0 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5715
+ DCI.DAG .getIntPtrConstant (0 , DL));
5716
+ F32Op1 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5717
+ DCI.DAG .getIntPtrConstant (1 , DL));
5718
+ }
5719
+ return DCI.DAG .getBuildVector (
5720
+ NarrowVT, DL,
5721
+ {DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op0,
5722
+ Trunc),
5723
+ DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op1,
5724
+ Trunc)});
5725
+ }
5726
+
5727
+ return SDValue ();
5728
+ }
5729
+
5696
5730
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5697
5731
DAGCombinerInfo &DCI) const {
5698
5732
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5729,6 +5763,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5729
5763
return PerformBUILD_VECTORCombine (N, DCI);
5730
5764
case ISD::ADDRSPACECAST:
5731
5765
return combineADDRSPACECAST (N, DCI);
5766
+ case ISD::FP_ROUND:
5767
+ return PerformFP_ROUNDCombine (N, DCI);
5732
5768
}
5733
5769
return SDValue ();
5734
5770
}
0 commit comments