@@ -830,7 +830,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
830
830
// We have some custom DAG combine patterns for these nodes
831
831
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
832
832
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
833
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
833
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND });
834
834
835
835
// setcc for f16x2 and bf16x2 needs special handling to prevent
836
836
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5695,6 +5695,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5695
5695
return SDValue ();
5696
5696
}
5697
5697
5698
+ static SDValue PerformFP_ROUNDCombine (SDNode *N,
5699
+ TargetLowering::DAGCombinerInfo &DCI) {
5700
+ SDLoc DL (N);
5701
+ SDValue Op = N->getOperand (0 );
5702
+ SDValue Trunc = N->getOperand (1 );
5703
+ EVT NarrowVT = N->getValueType (0 );
5704
+ EVT WideVT = Op.getValueType ();
5705
+
5706
+ // v2[b]f16 = fp_round (v2f32 A)
5707
+ // -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5708
+ // ([b]f16 = fp_round (extractelt A, 1)))
5709
+ if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5710
+ WideVT == MVT::v2f32) {
5711
+ SDValue F32Op0, F32Op1;
5712
+ if (Op.getOpcode () == ISD::BUILD_VECTOR) {
5713
+ F32Op0 = Op.getOperand (0 );
5714
+ F32Op1 = Op.getOperand (1 );
5715
+ } else {
5716
+ F32Op0 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5717
+ DCI.DAG .getIntPtrConstant (0 , DL));
5718
+ F32Op1 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5719
+ DCI.DAG .getIntPtrConstant (1 , DL));
5720
+ }
5721
+ return DCI.DAG .getBuildVector (
5722
+ NarrowVT, DL,
5723
+ {DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op0,
5724
+ Trunc),
5725
+ DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op1,
5726
+ Trunc)});
5727
+ }
5728
+
5729
+ return SDValue ();
5730
+ }
5731
+
5698
5732
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5699
5733
DAGCombinerInfo &DCI) const {
5700
5734
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5731,6 +5765,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5731
5765
return PerformBUILD_VECTORCombine (N, DCI);
5732
5766
case ISD::ADDRSPACECAST:
5733
5767
return combineADDRSPACECAST (N, DCI);
5768
+ case ISD::FP_ROUND:
5769
+ return PerformFP_ROUNDCombine (N, DCI);
5734
5770
}
5735
5771
return SDValue ();
5736
5772
}
0 commit comments