Skip to content

Commit d61734e

Browse files
committed
write F32X2 result into two i32 registers
Allows better codegen as each register can be forwarded through subsequent EXTRACT_VECTOR_ELT nodes.
1 parent 60a73cc commit d61734e

File tree

3 files changed

+47
-2
lines changed

3 files changed

+47
-2
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,10 @@ void NVPTXDAGToDAGISel::Select(SDNode *N) {
196196
SelectI128toV2I64(N);
197197
return;
198198
}
199+
if (N->getOperand(1).getValueType() == MVT::i64 && N->getNumValues() == 3) {
200+
SelectI64ToV2I32(N);
201+
return;
202+
}
199203
break;
200204
}
201205
case ISD::FADD:
@@ -2795,6 +2799,17 @@ void NVPTXDAGToDAGISel::SelectI128toV2I64(SDNode *N) {
27952799
ReplaceNode(N, Mov);
27962800
}
27972801

2802+
/// Replaces \p N with an NVPTX::I64toV2I32 machine node that yields two i32
/// values plus a chain result. Select() dispatches here when operand 1 of
/// \p N has type i64 and \p N produces exactly three values.
void NVPTXDAGToDAGISel::SelectI64ToV2I32(SDNode *N) {
2803+
// Operand 0 is taken as the chain; operand 1 is the source value.
SDValue Ch = N->getOperand(0);
2804+
SDValue Src = N->getOperand(1);
2805+
SDLoc DL(N);
2806+
2807+
// Result types are (i32, i32, chain); the incoming chain is passed as the
// last operand of the new node.
// NOTE(review): presumably threading the chain through preserves ordering
// relative to the node being replaced -- confirm.
SDNode *Mov = CurDAG->getMachineNode(NVPTX::I64toV2I32, DL,
2808+
{MVT::i32, MVT::i32, Ch.getValueType()},
2809+
{Src, Ch});
2810+
// Substitute the new machine node for N.
ReplaceNode(N, Mov);
2811+
}
2812+
27982813
/// GetConvertOpcode - Returns the CVT_ instruction opcode that implements a
27992814
/// conversion from \p SrcTy to \p DestTy.
28002815
unsigned NVPTXDAGToDAGISel::GetConvertOpcode(MVT DestTy, MVT SrcTy,

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ class LLVM_LIBRARY_VISIBILITY NVPTXDAGToDAGISel : public SelectionDAGISel {
9292
bool tryEXTRACT_VECTOR_ELEMENT(SDNode *N);
9393
void SelectV2I64toI128(SDNode *N);
9494
void SelectI128toV2I64(SDNode *N);
95+
void SelectI64ToV2I32(SDNode *N);
9596
void SelectCpAsyncBulkG2S(SDNode *N);
9697
void SelectCpAsyncBulkS2G(SDNode *N);
9798
void SelectCpAsyncBulkPrefetchL2(SDNode *N);

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
872872
setOperationAction(Op, MVT::v2f32, Custom);
873873
// Handle custom lowering for: i64 = bitcast v2f32
874874
setOperationAction(ISD::BITCAST, MVT::v2f32, Custom);
875+
// Handle custom lowering for: f32 = extract_vector_elt v2f32
876+
setOperationAction(ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
875877
}
876878

877879
// These map to conversion instructions for scalar FP types.
@@ -2253,6 +2255,20 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
22532255
return DAG.getAnyExtOrTrunc(BFE, DL, Op->getValueType(0));
22542256
}
22552257

2258+
if (VectorVT == MVT::v2f32) {
2259+
if (Vector.getOpcode() == ISD::BITCAST) {
2260+
// peek through v2f32 = bitcast (i64 = build_pair (i32 A, i32 B))
2261+
// where A:i32, B:i32 = CopyFromReg (i64 = F32X2 Operation ...)
2262+
SDValue Pair = Vector.getOperand(0);
2263+
assert(Pair.getOpcode() == ISD::BUILD_PAIR);
2264+
return DAG.getNode(
2265+
ISD::BITCAST, DL, Op.getValueType(),
2266+
Pair.getOperand(cast<ConstantSDNode>(Index)->getZExtValue()));
2267+
}
2268+
if (Vector.getOpcode() == ISD::BUILD_VECTOR)
2269+
return Vector.getOperand(cast<ConstantSDNode>(Index)->getZExtValue());
2270+
}
2271+
22562272
// Constant index will be matched by tablegen.
22572273
if (isa<ConstantSDNode>(Index.getNode()))
22582274
return Op;
@@ -5565,9 +5581,22 @@ static void ReplaceF32x2Op(SDNode *N, SelectionDAG &DAG,
55655581
for (const SDValue &Op : N->ops())
55665582
NewOps.push_back(DAG.getNode(ISD::BITCAST, DL, MVT::i64, Op));
55675583

5568-
// cast i64 result of new op back to <2 x float>
5584+
SDValue Chain = DAG.getEntryNode();
5585+
5586+
// break i64 result into two i32 registers for later instructions that may
5587+
// access element #0 or #1. otherwise, this code will be eliminated
55695588
SDValue NewValue = DAG.getNode(Opcode, DL, MVT::i64, NewOps);
5570-
Results.push_back(DAG.getBitcast(OldResultTy, NewValue));
5589+
MachineRegisterInfo &RegInfo = DAG.getMachineFunction().getRegInfo();
5590+
Register DestReg = RegInfo.createVirtualRegister(
5591+
DAG.getTargetLoweringInfo().getRegClassFor(MVT::i64));
5592+
SDValue RegCopy = DAG.getCopyToReg(Chain, DL, DestReg, NewValue);
5593+
SDValue Explode = DAG.getNode(ISD::CopyFromReg, DL,
5594+
{MVT::i32, MVT::i32, Chain.getValueType()},
5595+
{RegCopy, DAG.getRegister(DestReg, MVT::i64)});
5596+
// cast i64 result of new op back to <2 x float>
5597+
Results.push_back(DAG.getBitcast(
5598+
OldResultTy, DAG.getNode(ISD::BUILD_PAIR, DL, MVT::i64,
5599+
{Explode.getValue(0), Explode.getValue(1)})));
55715600
}
55725601

55735602
void NVPTXTargetLowering::ReplaceNodeResults(

0 commit comments

Comments
 (0)