@@ -828,7 +828,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
828
828
// We have some custom DAG combine patterns for these nodes
829
829
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
830
830
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
831
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
831
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
832
+ ISD::TRUNCATE});
832
833
833
834
// setcc for f16x2 and bf16x2 needs special handling to prevent
834
835
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5758,6 +5759,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
5758
5759
return SDValue ();
5759
5760
}
5760
5761
5762
+ static SDValue PerformTRUNCATECombine (SDNode *N,
5763
+ TargetLowering::DAGCombinerInfo &DCI) {
5764
+ SDLoc DL (N);
5765
+ SDValue Op = N->getOperand (0 );
5766
+ EVT FromVT = Op.getValueType ();
5767
+ EVT ResultVT = N->getValueType (0 );
5768
+
5769
+ if (FromVT == MVT::i64 && ResultVT == MVT::i32 ) {
5770
+ // i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5771
+ // -> i32 = bitcast (f32 A)
5772
+ if (Op.getOpcode () == ISD::BITCAST) {
5773
+ SDValue BV = Op.getOperand (0 );
5774
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5775
+ BV.getValueType () == MVT::v2f32) {
5776
+ // get lower
5777
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT, BV.getOperand (0 ));
5778
+ }
5779
+ }
5780
+
5781
+ // i32 = truncate (i64 = srl
5782
+ // (i64 = bitcast
5783
+ // (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5784
+ // -> i32 = bitcast (f32 B)
5785
+ if (Op.getOpcode () == ISD::SRL) {
5786
+ if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ));
5787
+ ShAmt && ShAmt->getAsAPIntVal () == 32 ) {
5788
+ SDValue Cast = Op.getOperand (0 );
5789
+ if (Cast.getOpcode () == ISD::BITCAST) {
5790
+ SDValue BV = Cast.getOperand (0 );
5791
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5792
+ BV.getValueType () == MVT::v2f32) {
5793
+ // get upper
5794
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT,
5795
+ BV.getOperand (1 ));
5796
+ }
5797
+ }
5798
+ }
5799
+ }
5800
+ }
5801
+
5802
+ return SDValue ();
5803
+ }
5804
+
5761
5805
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5762
5806
DAGCombinerInfo &DCI) const {
5763
5807
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5796,6 +5840,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5796
5840
return combineADDRSPACECAST (N, DCI);
5797
5841
case ISD::FP_ROUND:
5798
5842
return PerformFP_ROUNDCombine (N, DCI);
5843
+ case ISD::TRUNCATE:
5844
+ return PerformTRUNCATECombine (N, DCI);
5799
5845
}
5800
5846
return SDValue ();
5801
5847
}
0 commit comments