Skip to content

Commit e21b5c4

Browse files
committed
[NVPTX] add combiner rule to peek through bitcast of BUILD_VECTOR
1 parent e4f9ae6 commit e21b5c4

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
832832
// We have some custom DAG combine patterns for these nodes
833833
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
834834
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
835-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
835+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
836+
ISD::TRUNCATE});
836837

837838
// setcc for f16x2 and bf16x2 needs special handling to prevent
838839
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5732,6 +5733,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
57325733
return SDValue();
57335734
}
57345735

5736+
/// Fold an i64 -> i32 TRUNCATE that only extracts one element of a v2f32
/// BUILD_VECTOR which was bitcast to i64.
///
/// Two shapes are recognised:
///   i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
///     -> i32 = bitcast (f32 A)
///   i32 = truncate (i64 = srl (i64 = bitcast
///                               (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
///     -> i32 = bitcast (f32 B)
///
/// \param N   The ISD::TRUNCATE node being combined.
/// \param DCI Combiner state; its DAG is used to create the replacement node.
/// \returns The replacement BITCAST node, or an empty SDValue when neither
///          pattern matches.
static SDValue PerformTRUNCATECombine(SDNode *N,
                                      TargetLowering::DAGCombinerInfo &DCI) {
  SDValue Src = N->getOperand(0);
  EVT DstVT = N->getValueType(0);

  // Only the i64 -> i32 truncation is handled.
  if (!(Src.getValueType() == MVT::i64 && DstVT == MVT::i32))
    return SDValue();

  // Yields the v2f32 BUILD_VECTOR that V is a direct BITCAST of, or an empty
  // SDValue when V does not have that shape.
  auto GetBitcastV2F32BuildVector = [](SDValue V) -> SDValue {
    if (V.getOpcode() != ISD::BITCAST)
      return SDValue();
    SDValue Vec = V.getOperand(0);
    const bool IsV2F32BuildVector = Vec.getOpcode() == ISD::BUILD_VECTOR &&
                                    Vec.getValueType() == MVT::v2f32;
    return IsV2F32BuildVector ? Vec : SDValue();
  };

  SDValue Packed;   // The i64 value expected to be the bitcast vector.
  unsigned EltIdx;  // BUILD_VECTOR operand selected by the truncate.

  switch (Src.getOpcode()) {
  case ISD::BITCAST:
    // A plain truncate picks the first BUILD_VECTOR operand ("lower").
    Packed = Src;
    EltIdx = 0;
    break;
  case ISD::SRL: {
    // A logical shift right by exactly 32 followed by the truncate picks the
    // second BUILD_VECTOR operand ("upper").
    auto *ShiftC = dyn_cast<ConstantSDNode>(Src.getOperand(1));
    const bool ShiftsBy32 = ShiftC && ShiftC->getAsAPIntVal() == 32;
    if (!ShiftsBy32)
      return SDValue();
    Packed = Src.getOperand(0);
    EltIdx = 1;
    break;
  }
  default:
    return SDValue();
  }

  SDValue Vec = GetBitcastV2F32BuildVector(Packed);
  if (!Vec)
    return SDValue();

  SDLoc DL(N);
  return DCI.DAG.getNode(ISD::BITCAST, DL, DstVT, Vec.getOperand(EltIdx));
}
5778+
57355779
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57365780
DAGCombinerInfo &DCI) const {
57375781
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5770,6 +5814,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57705814
return combineADDRSPACECAST(N, DCI);
57715815
case ISD::FP_ROUND:
57725816
return PerformFP_ROUNDCombine(N, DCI);
5817+
case ISD::TRUNCATE:
5818+
return PerformTRUNCATECombine(N, DCI);
57735819
}
57745820
return SDValue();
57755821
}

0 commit comments

Comments
 (0)