@@ -825,7 +825,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
825
825
// We have some custom DAG combine patterns for these nodes
826
826
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
827
827
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
828
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST});
828
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND });
829
829
830
830
// setcc for f16x2 and bf16x2 needs special handling to prevent
831
831
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5582,6 +5582,40 @@ static SDValue combineADDRSPACECAST(SDNode *N,
5582
5582
return SDValue ();
5583
5583
}
5584
5584
5585
+ static SDValue PerformFP_ROUNDCombine (SDNode *N,
5586
+ TargetLowering::DAGCombinerInfo &DCI) {
5587
+ SDLoc DL (N);
5588
+ SDValue Op = N->getOperand (0 );
5589
+ SDValue Trunc = N->getOperand (1 );
5590
+ EVT NarrowVT = N->getValueType (0 );
5591
+ EVT WideVT = Op.getValueType ();
5592
+
5593
+ // v2[b]f16 = fp_round (v2f32 A)
5594
+ // -> v2[b]f16 = (build_vector ([b]f16 = fp_round (extractelt A, 0)),
5595
+ // ([b]f16 = fp_round (extractelt A, 1)))
5596
+ if ((NarrowVT == MVT::v2bf16 || NarrowVT == MVT::v2f16) &&
5597
+ WideVT == MVT::v2f32) {
5598
+ SDValue F32Op0, F32Op1;
5599
+ if (Op.getOpcode () == ISD::BUILD_VECTOR) {
5600
+ F32Op0 = Op.getOperand (0 );
5601
+ F32Op1 = Op.getOperand (1 );
5602
+ } else {
5603
+ F32Op0 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5604
+ DCI.DAG .getIntPtrConstant (0 , DL));
5605
+ F32Op1 = DCI.DAG .getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32 , Op,
5606
+ DCI.DAG .getIntPtrConstant (1 , DL));
5607
+ }
5608
+ return DCI.DAG .getBuildVector (
5609
+ NarrowVT, DL,
5610
+ {DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op0,
5611
+ Trunc),
5612
+ DCI.DAG .getNode (ISD::FP_ROUND, DL, NarrowVT.getScalarType (), F32Op1,
5613
+ Trunc)});
5614
+ }
5615
+
5616
+ return SDValue ();
5617
+ }
5618
+
5585
5619
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5586
5620
DAGCombinerInfo &DCI) const {
5587
5621
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5618,6 +5652,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5618
5652
return PerformBUILD_VECTORCombine (N, DCI);
5619
5653
case ISD::ADDRSPACECAST:
5620
5654
return combineADDRSPACECAST (N, DCI);
5655
+ case ISD::FP_ROUND:
5656
+ return PerformFP_ROUNDCombine (N, DCI);
5621
5657
}
5622
5658
return SDValue ();
5623
5659
}
0 commit comments