Skip to content

Commit db55ccb

Browse files
committed
[NVPTX] add combiner rule to peek through bitcast of BUILD_VECTOR
1 parent 85dba20 commit db55ccb

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
827827
// We have some custom DAG combine patterns for these nodes
828828
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
829829
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
830-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
830+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
831+
ISD::TRUNCATE});
831832

832833
// setcc for f16x2 and bf16x2 needs special handling to prevent
833834
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5621,6 +5622,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
56215622
return SDValue();
56225623
}
56235624

5625+
static SDValue PerformTRUNCATECombine(SDNode *N,
5626+
TargetLowering::DAGCombinerInfo &DCI) {
5627+
SDLoc DL(N);
5628+
SDValue Op = N->getOperand(0);
5629+
EVT FromVT = Op.getValueType();
5630+
EVT ResultVT = N->getValueType(0);
5631+
5632+
if (FromVT == MVT::i64 && ResultVT == MVT::i32) {
5633+
// i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5634+
// -> i32 = bitcast (f32 A)
5635+
if (Op.getOpcode() == ISD::BITCAST) {
5636+
SDValue BV = Op.getOperand(0);
5637+
if (BV.getOpcode() == ISD::BUILD_VECTOR &&
5638+
BV.getValueType() == MVT::v2f32) {
5639+
// get lower
5640+
return DCI.DAG.getNode(ISD::BITCAST, DL, ResultVT, BV.getOperand(0));
5641+
}
5642+
}
5643+
5644+
// i32 = truncate (i64 = srl
5645+
// (i64 = bitcast
5646+
// (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5647+
// -> i32 = bitcast (f32 B)
5648+
if (Op.getOpcode() == ISD::SRL) {
5649+
if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand(1));
5650+
ShAmt && ShAmt->getAsAPIntVal() == 32) {
5651+
SDValue Cast = Op.getOperand(0);
5652+
if (Cast.getOpcode() == ISD::BITCAST) {
5653+
SDValue BV = Cast.getOperand(0);
5654+
if (BV.getOpcode() == ISD::BUILD_VECTOR &&
5655+
BV.getValueType() == MVT::v2f32) {
5656+
// get upper
5657+
return DCI.DAG.getNode(ISD::BITCAST, DL, ResultVT,
5658+
BV.getOperand(1));
5659+
}
5660+
}
5661+
}
5662+
}
5663+
}
5664+
5665+
return SDValue();
5666+
}
5667+
56245668
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56255669
DAGCombinerInfo &DCI) const {
56265670
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5659,6 +5703,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
56595703
return combineADDRSPACECAST(N, DCI);
56605704
case ISD::FP_ROUND:
56615705
return PerformFP_ROUNDCombine(N, DCI);
5706+
case ISD::TRUNCATE:
5707+
return PerformTRUNCATECombine(N, DCI);
56625708
}
56635709
return SDValue();
56645710
}

0 commit comments

Comments
 (0)