@@ -833,7 +833,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
                        ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
-                       ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
+                       ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD::BITCAST});
 
   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
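Note: adding ISD::STORE to this list is what opts generic store nodes into the target combine hook used later in this patch. A minimal sketch of the mechanism, assuming the standard TargetLowering contract (names simplified from the generic DAGCombiner):

    // Inside the generic combiner loop, roughly:
    if (TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
      TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this);
      if (SDValue Res = TLI.PerformDAGCombine(N, DCI))
        return Res;
    }

Without this registration, PerformStoreCombine (added below) would never fire for target-independent stores.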
@@ -3092,10 +3092,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);
 
-  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
-  // unaligned loads and have to handle it here.
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned loads and have to handle it here.
   EVT VT = Op.getValueType();
-  if (Isv2x16VT(VT) || VT == MVT::v4i8) {
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
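For context (the rest of this if-block is unchanged): when allowsMemoryAccessForAlignment rejects the access, the load cannot be emitted as a single vector load and must be expanded by hand. A hedged sketch of what that expansion amounts to for v2f32, assuming at least 4-byte alignment; the actual code defers to the generic expansion helper rather than open-coding this:

    // Split an under-aligned v2f32 load into two naturally aligned f32
    // loads, then rebuild the vector and merge the chains.
    SDLoc DL(Load);
    SDValue Lo = DAG.getLoad(MVT::f32, DL, Load->getChain(), Load->getBasePtr(),
                             Load->getPointerInfo(), Align(4));
    SDValue HiPtr = DAG.getMemBasePlusOffset(Load->getBasePtr(),
                                             TypeSize::getFixed(4), DL);
    SDValue Hi = DAG.getLoad(MVT::f32, DL, Load->getChain(), HiPtr,
                             Load->getPointerInfo().getWithOffset(4), Align(4));
    SDValue Vec = DAG.getBuildVector(MVT::v2f32, DL, {Lo, Hi});
    SDValue Chain = DAG.getNode(ISD::TokenFactor, DL, MVT::Other,
                                Lo.getValue(1), Hi.getValue(1));
    return DAG.getMergeValues({Vec, Chain}, DL);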
@@ -3139,15 +3139,15 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   if (VT == MVT::i1)
     return LowerSTOREi1(Op, DAG);
 
-  // v2f16 is legal, so we can't rely on legalizer to handle unaligned
-  // stores and have to handle it here.
-  if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned stores and have to handle it here.
+  if ((Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);
 
-  // v2f16, v2bf16 and v2i16 don't need special handling.
-  if (Isv2x16VT(VT) || VT == MVT::v4i8)
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 don't need special handling.
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32)
     return SDValue();
 
   if (VT.isVector())
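The store side is symmetric; a hedged sketch of what expandUnalignedStore boils down to for v2f32 (again assuming at least 4-byte alignment):

    // Extract both f32 lanes and emit two naturally aligned scalar stores.
    SDLoc DL(Store);
    SDValue V0 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32,
                             Store->getValue(), DAG.getVectorIdxConstant(0, DL));
    SDValue V1 = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, MVT::f32,
                             Store->getValue(), DAG.getVectorIdxConstant(1, DL));
    SDValue St0 = DAG.getStore(Store->getChain(), DL, V0, Store->getBasePtr(),
                               Store->getPointerInfo(), Align(4));
    SDValue HiPtr = DAG.getMemBasePlusOffset(Store->getBasePtr(),
                                             TypeSize::getFixed(4), DL);
    SDValue St1 = DAG.getStore(Store->getChain(), DL, V1, HiPtr,
                               Store->getPointerInfo().getWithOffset(4), Align(4));
    return DAG.getNode(ISD::TokenFactor, DL, MVT::Other, St0, St1);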
@@ -3156,8 +3156,8 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
     return SDValue();
 }
 
-SDValue
-NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG,
+                                  const SmallVectorImpl<SDValue> &Elements) {
   SDNode *N = Op.getNode();
   SDValue Val = N->getOperand(1);
   SDLoc DL(N);
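Design note on this refactor: demoting the member function to a static helper with an explicit Elements parameter lets the new ISD::STORE combine (below) reuse the StoreV2/StoreV4 emission path while supplying scalars it already has on hand, instead of synthesizing EXTRACT_VECTOR_ELT nodes that would then need to be folded away. Usage, mirroring PerformStoreCombine later in this patch:

    // The caller already holds the operands of a BUILD_VECTOR:
    SmallVector<SDValue> Elements(StoredVal->op_values());
    SDValue NewStore = convertVectorStore(SDValue(N, 0), DAG, Elements);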
@@ -3224,6 +3224,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
         SDValue SubVector = DAG.getBuildVector(EltVT, DL, SubVectorElts);
         Ops.push_back(SubVector);
       }
+    } else if (!Elements.empty()) {
+      Ops.insert(Ops.end(), Elements.begin(), Elements.end());
     } else {
       for (unsigned i = 0; i < NumElts; ++i) {
         SDValue ExtVal = DAG.getNode(ISD::EXTRACT_VECTOR_ELT, DL, EltVT, Val,
@@ -3241,10 +3243,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
                               MemSD->getMemoryVT(), MemSD->getMemOperand());
 
-  // return DCI.CombineTo(N, NewSt, true);
   return NewSt;
 }
 
+// Default variant where we don't pass in elements.
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG) {
+  return convertVectorStore(Op, DAG, SmallVector<SDValue>{});
+}
+
+SDValue NVPTXTargetLowering::LowerSTOREVector(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  return convertVectorStore(Op, DAG);
+}
+
 // st i1 v, addr
 // =>
 //    v1 = zxt v to i16
@@ -5400,6 +5411,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
   // -->
   // StoreRetvalV2 {a, b}
   // likewise for V2 -> V4 case
+  //
+  // We also handle target-independent stores, which require us to first
+  // convert to StoreV2.
 
   std::optional<NVPTXISD::NodeType> NewOpcode;
   switch (N->getOpcode()) {
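A concrete before/after for the existing param-store case this comment describes (hedged illustration; operand order simplified):

    t1: v2f32 = BUILD_VECTOR a, b
    t2: ch    = NVPTXISD::StoreRetval t0, /*offset=*/0, t1
      -->
    t2: ch    = NVPTXISD::StoreRetvalV2 t0, /*offset=*/0, a, b

The new comment lines extend this to generic ISD::STORE nodes, which are first rewritten to NVPTXISD::StoreV2 by the combine added below.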
@@ -5425,8 +5439,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
       SDValue CurrentOp = N->getOperand(I);
       if (CurrentOp->getOpcode() == ISD::BUILD_VECTOR) {
         assert(CurrentOp.getValueType() == MVT::v2f32);
-        NewOps.push_back(CurrentOp.getNode()->getOperand(0));
-        NewOps.push_back(CurrentOp.getNode()->getOperand(1));
+        NewOps.push_back(CurrentOp.getOperand(0));
+        NewOps.push_back(CurrentOp.getOperand(1));
       } else {
         NewOps.clear();
         break;
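The operand-access change above is purely cosmetic: SDValue::getOperand is a thin forwarder to the underlying node, roughly

    const SDValue &SDValue::getOperand(unsigned i) const {
      return Node->getOperand(i);
    }

so dropping the explicit getNode() changes nothing about behavior.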
@@ -6197,6 +6211,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
   return SDValue();
 }
 
+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI) {
+  // Check if the stored value can be scalarized.
+  SDValue StoredVal = N->getOperand(1);
+  if (StoredVal.getValueType() == MVT::v2f32 &&
+      StoredVal.getOpcode() == ISD::BUILD_VECTOR) {
+    SmallVector<SDValue> Elements(StoredVal->op_values());
+    return convertVectorStore(SDValue(N, 0), DCI.DAG, Elements);
+  }
+  return SDValue();
+}
+
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
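Net effect of the new combine (hedged illustration of the DAG shapes; operand order simplified): a target-independent store of a freshly built v2f32 is rewritten into the target's scalar-operand vector store,

    st: ch = store t0, (BUILD_VECTOR v2f32 a, b), addr
      -->
    st: ch = NVPTXISD::StoreV2 t0, a, b, addr

so instruction selection can emit a two-element vector store without first packing a and b into a single 64-bit register.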
@@ -6226,6 +6252,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case NVPTXISD::LoadParam:
   case NVPTXISD::LoadParamV2:
     return PerformLoadCombine(N, DCI);
+  case ISD::STORE:
+    return PerformStoreCombine(N, DCI);
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4:
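End-to-end, the expected PTX difference for such a store looks roughly like this (hedged illustration; register names invented):

    // before: pack the two floats, then a 64-bit store
    mov.b64   %rd1, {%f1, %f2};
    st.b64    [%rd2], %rd1;
    // after: a two-element vector store, no packing
    st.v2.f32 [%rd2], {%f1, %f2};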