Skip to content

Commit 7b0d4db

Browse files
committed
[NVPTX] add combiner rule to peek through bitcast of BUILD_VECTOR
1 parent e1b2232 commit 7b0d4db

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -828,7 +828,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
828828
// We have some custom DAG combine patterns for these nodes
829829
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
830830
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
831-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
831+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
832+
ISD::TRUNCATE});
832833

833834
// setcc for f16x2 and bf16x2 needs special handling to prevent
834835
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5758,6 +5759,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
57585759
return SDValue();
57595760
}
57605761

5762+
/// Fold an i64 -> i32 TRUNCATE that merely extracts one half of a bitcast
/// v2f32 BUILD_VECTOR into a direct bitcast of the corresponding f32 operand.
///
/// Two shapes are recognized:
///   i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
///     -> i32 = bitcast (f32 A)
///   i32 = truncate (i64 = srl (i64 = bitcast
///                               (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
///     -> i32 = bitcast (f32 B)
///
/// \param N   The ISD::TRUNCATE node being combined.
/// \param DCI Combiner state; its DAG is used to create the replacement node.
/// \returns The replacement value, or an empty SDValue when N does not match
///          either shape.
static SDValue PerformTRUNCATECombine(SDNode *N,
                                      TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Src = N->getOperand(0);
  EVT DstVT = N->getValueType(0);

  // Only the i64 -> i32 truncation is of interest here.
  if (!(Src.getValueType() == MVT::i64 && DstVT == MVT::i32))
    return SDValue();

  // If V is a BITCAST whose input is a v2f32 BUILD_VECTOR, yield that
  // BUILD_VECTOR; otherwise yield an empty SDValue.
  auto GetBitcastV2F32BuildVector = [](SDValue V) -> SDValue {
    if (V.getOpcode() != ISD::BITCAST)
      return SDValue();
    SDValue Vec = V.getOperand(0);
    const bool IsV2F32BuildVector = Vec.getOpcode() == ISD::BUILD_VECTOR &&
                                    Vec.getValueType() == MVT::v2f32;
    return IsV2F32BuildVector ? Vec : SDValue();
  };

  SDValue BuildVec;
  unsigned EltIdx = 0;
  switch (Src.getOpcode()) {
  case ISD::BITCAST:
    // Plain truncate of the bitcast: take the lower half, i.e. operand 0.
    BuildVec = GetBitcastV2F32BuildVector(Src);
    EltIdx = 0;
    break;
  case ISD::SRL: {
    // Truncate after a logical shift right by exactly 32: take the upper
    // half, i.e. operand 1.
    auto *ShiftC = dyn_cast<ConstantSDNode>(Src.getOperand(1));
    const bool ShiftsBy32 = ShiftC && ShiftC->getAsAPIntVal() == 32;
    if (!ShiftsBy32)
      return SDValue();
    BuildVec = GetBitcastV2F32BuildVector(Src.getOperand(0));
    EltIdx = 1;
    break;
  }
  default:
    return SDValue();
  }

  if (!BuildVec)
    return SDValue();

  SDLoc DL(N);
  return DCI.DAG.getNode(ISD::BITCAST, DL, DstVT, BuildVec.getOperand(EltIdx));
}
5804+
57615805
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57625806
DAGCombinerInfo &DCI) const {
57635807
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5796,6 +5840,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57965840
return combineADDRSPACECAST(N, DCI);
57975841
case ISD::FP_ROUND:
57985842
return PerformFP_ROUNDCombine(N, DCI);
5843+
case ISD::TRUNCATE:
5844+
return PerformTRUNCATECombine(N, DCI);
57995845
}
58005846
return SDValue();
58015847
}

0 commit comments

Comments
 (0)