Skip to content

Commit 24f7f11

Browse files
committed
[NVPTX] add combiner rule to peek through bitcast of BUILD_VECTOR
1 parent ca194b2 commit 24f7f11

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -832,7 +832,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
832832
// We have some custom DAG combine patterns for these nodes
833833
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
834834
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
835-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
835+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
836+
ISD::TRUNCATE});
836837

837838
// setcc for f16x2 and bf16x2 needs special handling to prevent
838839
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5734,6 +5735,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
57345735
return SDValue();
57355736
}
57365737

5738+
/// Fold an i64 -> i32 truncate that only reads one element of a v2f32
/// BUILD_VECTOR which has been bitcast to i64:
///
///   i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
///     -> i32 = bitcast (f32 A)
///
///   i32 = truncate (i64 = srl (i64 = bitcast
///                                (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
///     -> i32 = bitcast (f32 B)
///
/// \param N   The ISD::TRUNCATE node being combined.
/// \param DCI Combiner state; its DAG is used to create the replacement node.
/// \returns The replacement bitcast, or an empty SDValue when neither pattern
///          matches.
static SDValue PerformTRUNCATECombine(SDNode *N,
                                      TargetLowering::DAGCombinerInfo &DCI) {
  const SDValue Src = N->getOperand(0);
  const EVT DstVT = N->getValueType(0);

  // Only the 64-bit to 32-bit integer truncate is handled.
  if (Src.getValueType() != MVT::i64 || DstVT != MVT::i32)
    return SDValue();

  // When V is a bitcast of a v2f32 BUILD_VECTOR, yields that BUILD_VECTOR's
  // operand number Lane; otherwise yields an empty SDValue.
  auto LaneOfCastedPair = [](SDValue V, unsigned Lane) -> SDValue {
    if (V.getOpcode() != ISD::BITCAST)
      return SDValue();
    SDValue Vec = V.getOperand(0);
    if (Vec.getOpcode() != ISD::BUILD_VECTOR ||
        Vec.getValueType() != MVT::v2f32)
      return SDValue();
    return Vec.getOperand(Lane);
  };

  SDValue Elt;
  if (Src.getOpcode() == ISD::SRL) {
    // A logical shift right by exactly 32 selects the upper half, i.e. the
    // second BUILD_VECTOR operand.
    auto *ShiftC = dyn_cast<ConstantSDNode>(Src.getOperand(1));
    if (ShiftC && ShiftC->getAsAPIntVal() == 32)
      Elt = LaneOfCastedPair(Src.getOperand(0), 1);
  } else {
    // A direct truncate keeps the lower half, i.e. the first operand.
    Elt = LaneOfCastedPair(Src, 0);
  }

  if (!Elt)
    return SDValue();

  return DCI.DAG.getNode(ISD::BITCAST, SDLoc(N), DstVT, Elt);
}
5780+
57375781
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57385782
DAGCombinerInfo &DCI) const {
57395783
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5772,6 +5816,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
57725816
return combineADDRSPACECAST(N, DCI);
57735817
case ISD::FP_ROUND:
57745818
return PerformFP_ROUNDCombine(N, DCI);
5819+
case ISD::TRUNCATE:
5820+
return PerformTRUNCATECombine(N, DCI);
57755821
}
57765822
return SDValue();
57775823
}

0 commit comments

Comments
 (0)