@@ -832,7 +832,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
832
832
// We have some custom DAG combine patterns for these nodes
833
833
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
834
834
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
835
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
835
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
836
+ ISD::TRUNCATE});
836
837
837
838
// setcc for f16x2 and bf16x2 needs special handling to prevent
838
839
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5734,6 +5735,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
5734
5735
return SDValue ();
5735
5736
}
5736
5737
5738
+ static SDValue PerformTRUNCATECombine (SDNode *N,
5739
+ TargetLowering::DAGCombinerInfo &DCI) {
5740
+ SDLoc DL (N);
5741
+ SDValue Op = N->getOperand (0 );
5742
+ EVT FromVT = Op.getValueType ();
5743
+ EVT ResultVT = N->getValueType (0 );
5744
+
5745
+ if (FromVT == MVT::i64 && ResultVT == MVT::i32 ) {
5746
+ // i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5747
+ // -> i32 = bitcast (f32 A)
5748
+ if (Op.getOpcode () == ISD::BITCAST) {
5749
+ SDValue BV = Op.getOperand (0 );
5750
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5751
+ BV.getValueType () == MVT::v2f32) {
5752
+ // get lower
5753
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT, BV.getOperand (0 ));
5754
+ }
5755
+ }
5756
+
5757
+ // i32 = truncate (i64 = srl
5758
+ // (i64 = bitcast
5759
+ // (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5760
+ // -> i32 = bitcast (f32 B)
5761
+ if (Op.getOpcode () == ISD::SRL) {
5762
+ if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ));
5763
+ ShAmt && ShAmt->getAsAPIntVal () == 32 ) {
5764
+ SDValue Cast = Op.getOperand (0 );
5765
+ if (Cast.getOpcode () == ISD::BITCAST) {
5766
+ SDValue BV = Cast.getOperand (0 );
5767
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5768
+ BV.getValueType () == MVT::v2f32) {
5769
+ // get upper
5770
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT,
5771
+ BV.getOperand (1 ));
5772
+ }
5773
+ }
5774
+ }
5775
+ }
5776
+ }
5777
+
5778
+ return SDValue ();
5779
+ }
5780
+
5737
5781
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5738
5782
DAGCombinerInfo &DCI) const {
5739
5783
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5772,6 +5816,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5772
5816
return combineADDRSPACECAST (N, DCI);
5773
5817
case ISD::FP_ROUND:
5774
5818
return PerformFP_ROUNDCombine (N, DCI);
5819
+ case ISD::TRUNCATE:
5820
+ return PerformTRUNCATECombine (N, DCI);
5775
5821
}
5776
5822
return SDValue ();
5777
5823
}
0 commit comments