Skip to content

Commit c80645c

Browse files
committed
[NVPTX] add combiner rule to peek through bitcast of BUILD_VECTOR
1 parent 92c3d6a commit c80645c

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -864,7 +864,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
864864
// We have some custom DAG combine patterns for these nodes
865865
setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
866866
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
867-
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND});
867+
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
868+
ISD::TRUNCATE});
868869

869870
// setcc for f16x2 and bf16x2 needs special handling to prevent
870871
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -5858,6 +5859,49 @@ static SDValue PerformFP_ROUNDCombine(SDNode *N,
58585859
return SDValue();
58595860
}
58605861

5862+
static SDValue PerformTRUNCATECombine(SDNode *N,
5863+
TargetLowering::DAGCombinerInfo &DCI) {
5864+
SDLoc DL(N);
5865+
SDValue Op = N->getOperand(0);
5866+
EVT FromVT = Op.getValueType();
5867+
EVT ResultVT = N->getValueType(0);
5868+
5869+
if (FromVT == MVT::i64 && ResultVT == MVT::i32) {
5870+
// i32 = truncate (i64 = bitcast (v2f32 = BUILD_VECTOR (f32 A, f32 B)))
5871+
// -> i32 = bitcast (f32 A)
5872+
if (Op.getOpcode() == ISD::BITCAST) {
5873+
SDValue BV = Op.getOperand(0);
5874+
if (BV.getOpcode() == ISD::BUILD_VECTOR &&
5875+
BV.getValueType() == MVT::v2f32) {
5876+
// get lower
5877+
return DCI.DAG.getNode(ISD::BITCAST, DL, ResultVT, BV.getOperand(0));
5878+
}
5879+
}
5880+
5881+
// i32 = truncate (i64 = srl
5882+
// (i64 = bitcast
5883+
// (v2f32 = BUILD_VECTOR (f32 A, f32 B))), 32)
5884+
// -> i32 = bitcast (f32 B)
5885+
if (Op.getOpcode() == ISD::SRL) {
5886+
if (auto *ShAmt = dyn_cast<ConstantSDNode>(Op.getOperand(1));
5887+
ShAmt && ShAmt->getAsAPIntVal() == 32) {
5888+
SDValue Cast = Op.getOperand(0);
5889+
if (Cast.getOpcode() == ISD::BITCAST) {
5890+
SDValue BV = Cast.getOperand(0);
5891+
if (BV.getOpcode() == ISD::BUILD_VECTOR &&
5892+
BV.getValueType() == MVT::v2f32) {
5893+
// get upper
5894+
return DCI.DAG.getNode(ISD::BITCAST, DL, ResultVT,
5895+
BV.getOperand(1));
5896+
}
5897+
}
5898+
}
5899+
}
5900+
}
5901+
5902+
return SDValue();
5903+
}
5904+
58615905
SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58625906
DAGCombinerInfo &DCI) const {
58635907
CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -5896,6 +5940,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
58965940
return combineADDRSPACECAST(N, DCI);
58975941
case ISD::FP_ROUND:
58985942
return PerformFP_ROUNDCombine(N, DCI);
5943+
case ISD::TRUNCATE:
5944+
return PerformTRUNCATECombine(N, DCI);
58995945
}
59005946
return SDValue();
59015947
}

0 commit comments

Comments
 (0)