@@ -829,7 +829,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
829
829
setTargetDAGCombine ({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
830
830
ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
831
831
ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
832
- ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
832
+ ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD:: BITCAST});
833
833
834
834
// setcc for f16x2 and bf16x2 needs special handling to prevent
835
835
// legalizer's attempt to scalarize it due to v2i1 not being legal.
@@ -3143,10 +3143,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3143
3143
if (Op.getValueType () == MVT::i1)
3144
3144
return LowerLOADi1 (Op, DAG);
3145
3145
3146
- // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
3147
- // unaligned loads and have to handle it here.
3146
+ // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3147
+ // handle unaligned loads and have to handle it here.
3148
3148
EVT VT = Op.getValueType ();
3149
- if (Isv2x16VT (VT) || VT == MVT::v4i8) {
3149
+ if (Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 ) {
3150
3150
LoadSDNode *Load = cast<LoadSDNode>(Op);
3151
3151
EVT MemVT = Load->getMemoryVT ();
3152
3152
if (!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
@@ -3190,22 +3190,22 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
3190
3190
if (VT == MVT::i1)
3191
3191
return LowerSTOREi1 (Op, DAG);
3192
3192
3193
- // v2f16 is legal, so we can't rely on legalizer to handle unaligned
3194
- // stores and have to handle it here.
3195
- if ((Isv2x16VT (VT) || VT == MVT::v4i8) &&
3193
+ // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
3194
+ // handle unaligned stores and have to handle it here.
3195
+ if ((Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 ) &&
3196
3196
!allowsMemoryAccessForAlignment (*DAG.getContext (), DAG.getDataLayout (),
3197
3197
VT, *Store->getMemOperand ()))
3198
3198
return expandUnalignedStore (Store, DAG);
3199
3199
3200
- // v2f16, v2bf16 and v2i16 don't need special handling.
3201
- if (Isv2x16VT (VT) || VT == MVT::v4i8)
3200
+ // v2f16/v2bf16/v2i16/v4i8/v2f32 don't need special handling.
3201
+ if (Isv2x16VT (VT) || VT == MVT::v4i8 || VT == MVT::v2f32 )
3202
3202
return SDValue ();
3203
3203
3204
3204
return LowerSTOREVector (Op, DAG);
3205
3205
}
3206
3206
3207
- SDValue
3208
- NVPTXTargetLowering::LowerSTOREVector (SDValue Op, SelectionDAG &DAG) const {
3207
+ static SDValue convertVectorStore (SDValue Op, SelectionDAG &DAG,
3208
+ const SmallVectorImpl<SDValue> &Elements) {
3209
3209
MemSDNode *N = cast<MemSDNode>(Op.getNode ());
3210
3210
SDValue Val = N->getOperand (1 );
3211
3211
SDLoc DL (N);
@@ -3266,6 +3266,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
3266
3266
NumEltsPerSubVector);
3267
3267
Ops.push_back (DAG.getBuildVector (EltVT, DL, SubVectorElts));
3268
3268
}
3269
+ } else if (!Elements.empty ()) {
3270
+ Ops.insert (Ops.end (), Elements.begin (), Elements.end ());
3269
3271
} else {
3270
3272
SDValue V = DAG.getBitcast (MVT::getVectorVT (EltVT, NumElts), Val);
3271
3273
for (const unsigned I : llvm::seq (NumElts)) {
@@ -3289,10 +3291,19 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
3289
3291
DAG.getMemIntrinsicNode (Opcode, DL, DAG.getVTList (MVT::Other), Ops,
3290
3292
N->getMemoryVT (), N->getMemOperand ());
3291
3293
3292
- // return DCI.CombineTo(N, NewSt, true);
3293
3294
return NewSt;
3294
3295
}
3295
3296
3297
+ // Default variant where we don't pass in elements.
3298
+ static SDValue convertVectorStore (SDValue Op, SelectionDAG &DAG) {
3299
+ return convertVectorStore (Op, DAG, SmallVector<SDValue>{});
3300
+ }
3301
+
3302
+ SDValue NVPTXTargetLowering::LowerSTOREVector (SDValue Op,
3303
+ SelectionDAG &DAG) const {
3304
+ return convertVectorStore (Op, DAG);
3305
+ }
3306
+
3296
3307
// st i1 v, addr
3297
3308
// =>
3298
3309
// v1 = zxt v to i16
@@ -5413,6 +5424,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
5413
5424
// -->
5414
5425
// StoreRetvalV2 {a, b}
5415
5426
// likewise for V2 -> V4 case
5427
+ //
5428
+ // We also handle target-independent stores, which require us to first
5429
+ // convert to StoreV2.
5416
5430
5417
5431
std::optional<NVPTXISD::NodeType> NewOpcode;
5418
5432
switch (N->getOpcode ()) {
@@ -5438,8 +5452,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
5438
5452
SDValue CurrentOp = N->getOperand (I);
5439
5453
if (CurrentOp->getOpcode () == ISD::BUILD_VECTOR) {
5440
5454
assert (CurrentOp.getValueType () == MVT::v2f32);
5441
- NewOps.push_back (CurrentOp.getNode ()-> getOperand (0 ));
5442
- NewOps.push_back (CurrentOp.getNode ()-> getOperand (1 ));
5455
+ NewOps.push_back (CurrentOp.getOperand (0 ));
5456
+ NewOps.push_back (CurrentOp.getOperand (1 ));
5443
5457
} else {
5444
5458
NewOps.clear ();
5445
5459
break ;
@@ -6216,6 +6230,18 @@ static SDValue PerformBITCASTCombine(SDNode *N,
6216
6230
return SDValue ();
6217
6231
}
6218
6232
6233
+ static SDValue PerformStoreCombine (SDNode *N,
6234
+ TargetLowering::DAGCombinerInfo &DCI) {
6235
+ // check if the store'd value can be scalarized
6236
+ SDValue StoredVal = N->getOperand (1 );
6237
+ if (StoredVal.getValueType () == MVT::v2f32 &&
6238
+ StoredVal.getOpcode () == ISD::BUILD_VECTOR) {
6239
+ SmallVector<SDValue> Elements (StoredVal->op_values ());
6240
+ return convertVectorStore (SDValue (N, 0 ), DCI.DAG , Elements);
6241
+ }
6242
+ return SDValue ();
6243
+ }
6244
+
6219
6245
SDValue NVPTXTargetLowering::PerformDAGCombine (SDNode *N,
6220
6246
DAGCombinerInfo &DCI) const {
6221
6247
CodeGenOptLevel OptLevel = getTargetMachine ().getOptLevel ();
@@ -6245,6 +6271,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
6245
6271
case NVPTXISD::LoadParam:
6246
6272
case NVPTXISD::LoadParamV2:
6247
6273
return PerformLoadCombine (N, DCI);
6274
+ case ISD::STORE:
6275
+ return PerformStoreCombine (N, DCI);
6248
6276
case NVPTXISD::StoreParam:
6249
6277
case NVPTXISD::StoreParamV2:
6250
6278
case NVPTXISD::StoreParamV4:
0 commit comments