@@ -864,7 +864,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
864
864
// We have some custom DAG combine patterns for these nodes
865
865
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
866
866
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
867
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
867
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
868
+ ISD::TRUNCATE});
868
869
869
870
// setcc for f16x2 and bf16x2 needs special handling to prevent
870
871
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5858,6 +5859,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
5858
5859
return SDValue ();
5859
5860
}
5860
5861
5862
+ static SDValue PerformTRUNCATECombine (SDNode *N,
5863
+ TargetLowering::DAGCombinerInfo &DCI) {
5864
+ SDLoc DL (N);
5865
+ SDValue Op = N->getOperand (0 );
5866
+ EVT FromVT = Op.getValueType ();
5867
+ EVT ResultVT = N->getValueType (0 );
5868
+
5869
+ if (FromVT == MVT::i64 && ResultVT == MVT::i32 ) {
5870
+ // i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5871
+ // -> i32 = bitcast (f32 A)
5872
+ if (Op.getOpcode () == ISD::BITCAST) {
5873
+ SDValue BV = Op.getOperand (0 );
5874
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5875
+ BV.getValueType () == MVT::v2f32) {
5876
+ // get lower
5877
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT, BV.getOperand (0 ));
5878
+ }
5879
+ }
5880
+
5881
+ // i32 = truncate (i64 = srl
5882
+ // (i64 = bitcast
5883
+ // (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5884
+ // -> i32 = bitcast (f32 B)
5885
+ if (Op.getOpcode () == ISD::SRL) {
5886
+ if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ));
5887
+ ShAmt && ShAmt->getAsAPIntVal () == 32 ) {
5888
+ SDValue Cast = Op.getOperand (0 );
5889
+ if (Cast.getOpcode () == ISD::BITCAST) {
5890
+ SDValue BV = Cast.getOperand (0 );
5891
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5892
+ BV.getValueType () == MVT::v2f32) {
5893
+ // get upper
5894
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT,
5895
+ BV.getOperand (1 ));
5896
+ }
5897
+ }
5898
+ }
5899
+ }
5900
+ }
5901
+
5902
+ return SDValue ();
5903
+ }
5904
+
5861
5905
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5862
5906
DAGCombinerInfo &DCI) const {
5863
5907
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5896,6 +5940,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5896
5940
return combineADDRSPACECAST (N, DCI);
5897
5941
case ISD::FP_ROUND:
5898
5942
return PerformFP_ROUNDCombine (N, DCI);
5943
+ case ISD::TRUNCATE:
5944
+ return PerformTRUNCATECombine (N, DCI);
5899
5945
}
5900
5946
return SDValue ();
5901
5947
}
0 commit comments