@@ -866,6 +866,14 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
866
866
setBF16OperationAction (ISD::FNEG, MVT::v2bf16, Legal, Expand);
867
867
// (would be) Library functions.
868
868
869
+ if (STI.hasF32x2Instructions ()) {
870
+ // Handle custom lowering for: v2f32 = OP v2f32, v2f32
871
+ for (const auto &Op : {ISD::FADD, ISD::FSUB, ISD::FMUL, ISD::FMA})
872
+ setOperationAction (Op, MVT::v2f32, Custom);
873
+ // Handle custom lowering for: i64 = bitcast v2f32
874
+ setOperationAction (ISD::BITCAST, MVT::v2f32, Custom);
875
+ }
876
+
869
877
// These map to conversion instructions for scalar FP types.
870
878
for (const auto &Op : {ISD::FCEIL, ISD::FFLOOR, ISD::FNEARBYINT, ISD::FRINT,
871
879
ISD::FROUNDEVEN, ISD::FTRUNC}) {
@@ -1066,6 +1074,10 @@ const char *NVPTXTargetLowering::getTargetNodeName(unsigned Opcode) const {
1066
1074
MAKE_CASE (NVPTXISD::STACKSAVE)
1067
1075
MAKE_CASE (NVPTXISD::SETP_F16X2)
1068
1076
MAKE_CASE (NVPTXISD::SETP_BF16X2)
1077
+ MAKE_CASE (NVPTXISD::FADD_F32X2)
1078
+ MAKE_CASE (NVPTXISD::FSUB_F32X2)
1079
+ MAKE_CASE (NVPTXISD::FMUL_F32X2)
1080
+ MAKE_CASE (NVPTXISD::FMA_F32X2)
1069
1081
MAKE_CASE (NVPTXISD::Dummy)
1070
1082
MAKE_CASE (NVPTXISD::MUL_WIDE_SIGNED)
1071
1083
MAKE_CASE (NVPTXISD::MUL_WIDE_UNSIGNED)
@@ -2099,24 +2111,58 @@ SDValue NVPTXTargetLowering::LowerBITCAST(SDValue Op, SelectionDAG &DAG) const {
2099
2111
// Handle bitcasting from v2i8 without hitting the default promotion
2100
2112
// strategy which goes through stack memory.
2101
2113
EVT FromVT = Op->getOperand (0 )->getValueType (0 );
2102
- if (FromVT != MVT::v2i8) {
2103
- return Op;
2104
- }
2105
-
2106
- // Pack vector elements into i16 and bitcast to final type
2107
- SDLoc DL (Op);
2108
- SDValue Vec0 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2109
- Op->getOperand (0 ), DAG.getIntPtrConstant (0 , DL));
2110
- SDValue Vec1 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2111
- Op->getOperand (0 ), DAG.getIntPtrConstant (1 , DL));
2112
- SDValue Extend0 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec0);
2113
- SDValue Extend1 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec1);
2114
- SDValue Const8 = DAG.getConstant (8 , DL, MVT::i16 );
2115
- SDValue AsInt = DAG.getNode (
2116
- ISD::OR, DL, MVT::i16 ,
2117
- {Extend0, DAG.getNode (ISD::SHL, DL, MVT::i16 , {Extend1, Const8})});
2118
2114
EVT ToVT = Op->getValueType (0 );
2119
- return MaybeBitcast (DAG, DL, ToVT, AsInt);
2115
+ SDLoc DL (Op);
2116
+
2117
+ if (FromVT == MVT::v2i8) {
2118
+ // Pack vector elements into i16 and bitcast to final type
2119
+ SDValue Vec0 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2120
+ Op->getOperand (0 ), DAG.getIntPtrConstant (0 , DL));
2121
+ SDValue Vec1 = DAG.getNode (ISD::EXTRACT_VECTOR_ELT, DL, MVT::i8 ,
2122
+ Op->getOperand (0 ), DAG.getIntPtrConstant (1 , DL));
2123
+ SDValue Extend0 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec0);
2124
+ SDValue Extend1 = DAG.getNode (ISD::ZERO_EXTEND, DL, MVT::i16 , Vec1);
2125
+ SDValue Const8 = DAG.getConstant (8 , DL, MVT::i16 );
2126
+ SDValue AsInt = DAG.getNode (
2127
+ ISD::OR, DL, MVT::i16 ,
2128
+ {Extend0, DAG.getNode (ISD::SHL, DL, MVT::i16 , {Extend1, Const8})});
2129
+ EVT ToVT = Op->getValueType (0 );
2130
+ return MaybeBitcast (DAG, DL, ToVT, AsInt);
2131
+ }
2132
+
2133
+ if (FromVT == MVT::v2f32) {
2134
+ assert (ToVT == MVT::i64 );
2135
+
2136
+ // A bitcast to i64 from v2f32.
2137
+ // See if we can legalize the operand.
2138
+ const SDValue &Operand = Op->getOperand (0 );
2139
+ if (Operand.getOpcode () == ISD::BUILD_VECTOR) {
2140
+ const SDValue &BVOp0 = Operand.getOperand (0 );
2141
+ const SDValue &BVOp1 = Operand.getOperand (1 );
2142
+
2143
+ auto CastToAPInt = [](SDValue Op) -> APInt {
2144
+ if (Op->isUndef ())
2145
+ return APInt (64 , 0 ); // undef values default to 0
2146
+ return cast<ConstantFPSDNode>(Op)->getValueAPF ().bitcastToAPInt ().zext (
2147
+ 64 );
2148
+ };
2149
+
2150
+ if ((BVOp0->isUndef () || isa<ConstantFPSDNode>(BVOp0)) &&
2151
+ (BVOp1->isUndef () || isa<ConstantFPSDNode>(BVOp1))) {
2152
+ // cast two constants
2153
+ APInt Value (64 , 0 );
2154
+ Value = CastToAPInt (BVOp0) | CastToAPInt (BVOp1).shl (32 );
2155
+ SDValue Const = DAG.getConstant (Value, DL, MVT::i64 );
2156
+ return DAG.getBitcast (ToVT, Const);
2157
+ }
2158
+
2159
+ // otherwise build an i64
2160
+ return DAG.getNode (ISD::BUILD_PAIR, DL, MVT::i64 ,
2161
+ DAG.getBitcast (MVT::i32 , BVOp0),
2162
+ DAG.getBitcast (MVT::i32 , BVOp1));
2163
+ }
2164
+ }
2165
+ return Op;
2120
2166
}
2121
2167
2122
2168
// We can init constant f16x2/v2i16/v4i8 with a single .b32 move. Normally it
@@ -3055,6 +3101,13 @@ bool NVPTXTargetLowering::splitValueIntoRegisterParts(
3055
3101
return false ;
3056
3102
}
3057
3103
3104
+ const TargetRegisterClass *
3105
+ NVPTXTargetLowering::getRegClassFor (MVT VT, bool isDivergent) const {
3106
+ if (VT == MVT::v2f32)
3107
+ return &NVPTX::Int64RegsRegClass;
3108
+ return TargetLowering::getRegClassFor (VT, isDivergent);
3109
+ }
3110
+
3058
3111
// This creates target external symbol for a function parameter.
3059
3112
// Name of the symbol is composed from its index and the function name.
3060
3113
// Negative index corresponds to special parameter (unsized array) used for
@@ -5055,10 +5108,10 @@ static SDValue PerformEXTRACTCombine(SDNode *N,
5055
5108
IsPTXVectorType (VectorVT.getSimpleVT ()))
5056
5109
return SDValue (); // Native vector loads already combine nicely w/
5057
5110
// extract_vector_elt.
5058
- // Don't mess with singletons or v2*16, v4i8 and v8i8 types, we already
5111
+ // Don't mess with singletons or v2*16, v4i8, v8i8, or v2f32 types, we already
5059
5112
// handle them OK.
5060
5113
if (VectorVT.getVectorNumElements () == 1 || Isv2x16VT (VectorVT) ||
5061
- VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8)
5114
+ VectorVT == MVT::v4i8 || VectorVT == MVT::v8i8 || VectorVT == MVT::v2f32 )
5062
5115
return SDValue ();
5063
5116
5064
5117
// Don't mess with undef values as sra may be simplified to 0, not undef.
@@ -5478,6 +5531,45 @@ static void ReplaceCopyFromReg_128(SDNode *N, SelectionDAG &DAG,
5478
5531
Results.push_back (NewValue.getValue (3 ));
5479
5532
}
5480
5533
5534
+ static void ReplaceF32x2Op (SDNode *N, SelectionDAG &DAG,
5535
+ SmallVectorImpl<SDValue> &Results,
5536
+ bool UseFTZ) {
5537
+ SDLoc DL (N);
5538
+ EVT OldResultTy = N->getValueType (0 ); // <2 x float>
5539
+ assert (OldResultTy == MVT::v2f32 && " Unexpected result type for F32x2 op!" );
5540
+
5541
+ SmallVector<SDValue> NewOps;
5542
+
5543
+ // whether we use FTZ (TODO)
5544
+
5545
+ // replace with NVPTX F32x2 op:
5546
+ unsigned Opcode;
5547
+ switch (N->getOpcode ()) {
5548
+ case ISD::FADD:
5549
+ Opcode = NVPTXISD::FADD_F32X2;
5550
+ break ;
5551
+ case ISD::FSUB:
5552
+ Opcode = NVPTXISD::FSUB_F32X2;
5553
+ break ;
5554
+ case ISD::FMUL:
5555
+ Opcode = NVPTXISD::FMUL_F32X2;
5556
+ break ;
5557
+ case ISD::FMA:
5558
+ Opcode = NVPTXISD::FMA_F32X2;
5559
+ break ;
5560
+ default :
5561
+ llvm_unreachable (" Unexpected opcode" );
5562
+ }
5563
+
5564
+ // bitcast operands: <2 x float> -> i64
5565
+ for (const SDValue &Op : N->ops ())
5566
+ NewOps.push_back (DAG.getNode (ISD::BITCAST, DL, MVT::i64 , Op));
5567
+
5568
+ // cast i64 result of new op back to <2 x float>
5569
+ SDValue NewValue = DAG.getNode (Opcode, DL, MVT::i64 , NewOps);
5570
+ Results.push_back (DAG.getBitcast (OldResultTy, NewValue));
5571
+ }
5572
+
5481
5573
void NVPTXTargetLowering::ReplaceNodeResults (
5482
5574
SDNode *N, SmallVectorImpl<SDValue> &Results, SelectionDAG &DAG) const {
5483
5575
switch (N->getOpcode ()) {
@@ -5495,6 +5587,12 @@ void NVPTXTargetLowering::ReplaceNodeResults(
5495
5587
case ISD::CopyFromReg:
5496
5588
ReplaceCopyFromReg_128 (N, DAG, Results);
5497
5589
return ;
5590
+ case ISD::FADD:
5591
+ case ISD::FSUB:
5592
+ case ISD::FMUL:
5593
+ case ISD::FMA:
5594
+ ReplaceF32x2Op (N, DAG, Results, useF32FTZ (DAG.getMachineFunction ()));
5595
+ return ;
5498
5596
}
5499
5597
}
5500
5598
0 commit comments