@@ -827,7 +827,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
827
827
// We have some custom DAG combine patterns for these nodes
828
828
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
829
829
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
830
- ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
830
+ ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
831
+ ISD::TRUNCATE});
831
832
832
833
// setcc for f16x2 and bf16x2 needs special handling to prevent
833
834
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5621,6 +5622,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
5621
5622
return SDValue ();
5622
5623
}
5623
5624
5625
+ static SDValue PerformTRUNCATECombine (SDNode *N,
5626
+ TargetLowering::DAGCombinerInfo &DCI) {
5627
+ SDLoc DL (N);
5628
+ SDValue Op = N->getOperand (0 );
5629
+ EVT FromVT = Op.getValueType ();
5630
+ EVT ResultVT = N->getValueType (0 );
5631
+
5632
+ if (FromVT == MVT::i64 && ResultVT == MVT::i32 ) {
5633
+ // i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5634
+ // -> i32 = bitcast (f32 A)
5635
+ if (Op.getOpcode () == ISD::BITCAST) {
5636
+ SDValue BV = Op.getOperand (0 );
5637
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5638
+ BV.getValueType () == MVT::v2f32) {
5639
+ // get lower
5640
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT, BV.getOperand (0 ));
5641
+ }
5642
+ }
5643
+
5644
+ // i32 = truncate (i64 = srl
5645
+ // (i64 = bitcast
5646
+ // (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5647
+ // -> i32 = bitcast (f32 B)
5648
+ if (Op.getOpcode () == ISD::SRL) {
5649
+ if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand (1 ));
5650
+ ShAmt && ShAmt->getAsAPIntVal () == 32 ) {
5651
+ SDValue Cast = Op.getOperand (0 );
5652
+ if (Cast.getOpcode () == ISD::BITCAST) {
5653
+ SDValue BV = Cast.getOperand (0 );
5654
+ if (BV.getOpcode () == ISD::BUILD_VECTOR &&
5655
+ BV.getValueType () == MVT::v2f32) {
5656
+ // get upper
5657
+ return DCI.DAG .getNode (ISD::BITCAST, DL, ResultVT,
5658
+ BV.getOperand (1 ));
5659
+ }
5660
+ }
5661
+ }
5662
+ }
5663
+ }
5664
+
5665
+ return SDValue ();
5666
+ }
5667
+
5624
5668
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
5625
5669
DAGCombinerInfo &DCI) const {
5626
5670
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -5659,6 +5703,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
5659
5703
return combineADDRSPACECAST (N, DCI);
5660
5704
case ISD::FP_ROUND:
5661
5705
return PerformFP_ROUNDCombine (N, DCI);
5706
+ case ISD::TRUNCATE:
5707
+ return PerformTRUNCATECombine (N, DCI);
5662
5708
}
5663
5709
return SDValue ();
5664
5710
}
0 commit comments