@@ -872,6 +872,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
872
872
setOperationAction (Op, MVT::v2f32, Custom);
873
873
// Handle custom lowering for: i64 = bitcast v2f32
874
874
setOperationAction (ISD::BITCAST, MVT::v2f32, Custom);
875
+ // Handle custom lowering for: f32 = extract_vector_elt v2f32
876
+ setOperationAction (ISD::EXTRACT_VECTOR_ELT, MVT::v2f32, Custom);
875
877
}
876
878
877
879
// These map to conversion instructions for scalar FP types.
@@ -2253,6 +2255,20 @@ SDValue NVPTXTargetLowering::LowerEXTRACT_VECTOR_ELT(SDValue Op,
2253
2255
return DAG.getAnyExtOrTrunc (BFE, DL, Op->getValueType (0 ));
2254
2256
}
2255
2257
2258
+ if (VectorVT == MVT::v2f32) {
2259
+ if (Vector.getOpcode () == ISD::BITCAST) {
2260
+ // Peek through v2f32 = bitcast (i64 = build_pair (i32 A, i32 B))
2261
+ // where A:i32, B:i32 = CopyFromReg (i64 = F32X2 Operation ...)
2262
+ SDValue Pair = Vector.getOperand (0 );
2263
+ assert (Pair.getOpcode () == ISD::BUILD_PAIR);
2264
+ return DAG.getNode (
2265
+ ISD::BITCAST, DL, Op.getValueType (),
2266
+ Pair.getOperand (cast<ConstantSDNode>(Index)->getZExtValue ()));
2267
+ }
2268
+ if (Vector.getOpcode () == ISD::BUILD_VECTOR)
2269
+ return Vector.getOperand (cast<ConstantSDNode>(Index)->getZExtValue ());
2270
+ }
2271
+
2256
2272
// Constant index will be matched by tablegen.
2257
2273
if (isa<ConstantSDNode>(Index.getNode ()))
2258
2274
return Op;
@@ -5565,9 +5581,22 @@ static void ReplaceF32x2Op(SDNode *N, SelectionDAG &DAG,
5565
5581
for (const SDValue &Op : N->ops ())
5566
5582
NewOps.push_back (DAG.getNode (ISD::BITCAST, DL, MVT::i64 , Op));
5567
5583
5568
- // cast i64 result of new op back to <2 x float>
5584
+ SDValue Chain = DAG.getEntryNode ();
5585
+
5586
+ // Break the i64 result into two i32 registers for later instructions that may
5587
+ // access element #0 or #1. Otherwise, this code will be eliminated.
5569
5588
SDValue NewValue = DAG.getNode (Opcode, DL, MVT::i64 , NewOps);
5570
- Results.push_back (DAG.getBitcast (OldResultTy, NewValue));
5589
+ MachineRegisterInfo &RegInfo = DAG.getMachineFunction ().getRegInfo ();
5590
+ Register DestReg = RegInfo.createVirtualRegister (
5591
+ DAG.getTargetLoweringInfo ().getRegClassFor (MVT::i64 ));
5592
+ SDValue RegCopy = DAG.getCopyToReg (Chain, DL, DestReg, NewValue);
5593
+ SDValue Explode = DAG.getNode (ISD::CopyFromReg, DL,
5594
+ {MVT::i32 , MVT::i32 , Chain.getValueType ()},
5595
+ {RegCopy, DAG.getRegister (DestReg, MVT::i64 )});
5596
+ // Cast the i64 result of the new op back to <2 x float>.
5597
+ Results.push_back (DAG.getBitcast (
5598
+ OldResultTy, DAG.getNode (ISD::BUILD_PAIR, DL, MVT::i64 ,
5599
+ {Explode.getValue (0 ), Explode.getValue (1 )})));
5571
5600
}
5572
5601
5573
5602
void NVPTXTargetLowering::ReplaceNodeResults (
0 commit comments