@@ -865,7 +865,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
   setTargetDAGCombine({ISD::ADD, ISD::AND, ISD::EXTRACT_VECTOR_ELT, ISD::FADD,
                        ISD::MUL, ISD::SHL, ISD::SREM, ISD::UREM, ISD::VSELECT,
                        ISD::BUILD_VECTOR, ISD::ADDRSPACECAST, ISD::FP_ROUND,
-                       ISD::TRUNCATE, ISD::LOAD, ISD::BITCAST});
+                       ISD::TRUNCATE, ISD::LOAD, ISD::STORE, ISD::BITCAST});

   // setcc for f16x2 and bf16x2 needs special handling to prevent
   // legalizer's attempt to scalarize it due to v2i1 not being legal.
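
Note: adding ISD::STORE to this list is what routes generic store nodes into
the target store combine introduced further down in this patch. For context,
the combiner-side dispatch looks roughly like this (paraphrased from
DAGCombiner.cpp; not part of this diff):

SDValue DAGCombiner::combine(SDNode *N) {
  SDValue RV = visit(N); // target-independent combines run first
  if (!RV.getNode() &&
      TLI.hasTargetDAGCombine((ISD::NodeType)N->getOpcode())) {
    // Opcodes registered via setTargetDAGCombine() reach the target hook,
    // i.e. NVPTXTargetLowering::PerformDAGCombine below.
    TargetLowering::DAGCombinerInfo DCI(DAG, Level, false, this);
    RV = TLI.PerformDAGCombine(N, DCI);
  }
  return RV;
}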
@@ -3242,10 +3242,10 @@ SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
   if (Op.getValueType() == MVT::i1)
     return LowerLOADi1(Op, DAG);

-  // v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to handle
-  // unaligned loads and have to handle it here.
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned loads and have to handle it here.
   EVT VT = Op.getValueType();
-  if (Isv2x16VT(VT) || VT == MVT::v4i8) {
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) {
     LoadSDNode *Load = cast<LoadSDNode>(Op);
     EVT MemVT = Load->getMemoryVT();
     if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
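
Note: this matters because v2f32 is now a legal type, so type legalization
leaves such loads untouched and an under-aligned access would otherwise
survive to instruction selection. A hypothetical trigger, as DAG pseudocode
in comments (illustrative; not from the patch):

// t1: v2f32,ch = load<(load (s64) from %ptr, align 1)> t0, %ptr
// allowsMemoryAccessForAlignment() rejects the align-1 access, so the code
// just past this hunk expands the load manually, exactly as already done
// for v2f16/v2bf16/v2i16/v4i8. The store path below gets the same treatment.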
@@ -3289,22 +3289,23 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
   if (VT == MVT::i1)
     return LowerSTOREi1(Op, DAG);

-  // v2f16 is legal, so we can't rely on legalizer to handle unaligned
-  // stores and have to handle it here.
-  if ((Isv2x16VT(VT) || VT == MVT::v4i8) &&
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 are legal, so we can't rely on legalizer to
+  // handle unaligned stores and have to handle it here.
+  if ((Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32) &&
       !allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
                                       VT, *Store->getMemOperand()))
     return expandUnalignedStore(Store, DAG);

-  // v2f16, v2bf16 and v2i16 don't need special handling.
-  if (Isv2x16VT(VT) || VT == MVT::v4i8)
+  // v2f16/v2bf16/v2i16/v4i8/v2f32 don't need special handling.
+  if (Isv2x16VT(VT) || VT == MVT::v4i8 || VT == MVT::v2f32)
     return SDValue();

   return LowerSTOREVector(Op, DAG);
 }

-SDValue
-NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG,
+                                  const SmallVectorImpl<SDValue> &Elements,
+                                  const NVPTXSubtarget &STI) {
   MemSDNode *N = cast<MemSDNode>(Op.getNode());
   SDValue Val = N->getOperand(1);
   SDLoc DL(N);
@@ -3369,6 +3370,8 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
                                 NumEltsPerSubVector);
       Ops.push_back(DAG.getBuildVector(EltVT, DL, SubVectorElts));
     }
+  } else if (!Elements.empty()) {
+    Ops.insert(Ops.end(), Elements.begin(), Elements.end());
   } else {
     SDValue V = DAG.getBitcast(MVT::getVectorVT(EltVT, NumElts), Val);
     for (const unsigned I : llvm::seq(NumElts)) {
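
Note: this new branch is the point of the refactor. When the caller already
holds the scalar lanes (the store combine below takes them straight off a
BUILD_VECTOR), they are appended to the operand list as-is instead of
round-tripping through the bitcast-and-extract loop in the else branch.
Roughly, for a store of (build_vector a, b):

// fast path, Elements = {a, b}:  Ops += {a, b}
// generic path, Elements empty:  V = bitcast Val to v2f32;
//                                Ops += {extractelt V, 0; extractelt V, 1}
// Both assemble the same store node; the fast path avoids creating extract
// nodes that would only be folded away again.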
@@ -3392,10 +3395,20 @@ NVPTXTargetLowering::LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const {
       DAG.getMemIntrinsicNode(Opcode, DL, DAG.getVTList(MVT::Other), Ops,
                               N->getMemoryVT(), N->getMemOperand());

-  // return DCI.CombineTo(N, NewSt, true);
   return NewSt;
 }

+// Default variant where we don't pass in elements.
+static SDValue convertVectorStore(SDValue Op, SelectionDAG &DAG,
+                                  const NVPTXSubtarget &STI) {
+  return convertVectorStore(Op, DAG, SmallVector<SDValue>{}, STI);
+}
+
+SDValue NVPTXTargetLowering::LowerSTOREVector(SDValue Op,
+                                              SelectionDAG &DAG) const {
+  return convertVectorStore(Op, DAG, STI);
+}
+
 // st i1 v, addr
 // =>
 // v1 = zxt v to i16
@@ -5539,6 +5552,9 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
   // -->
   // StoreRetvalV2 {a, b}
   // likewise for V2 -> V4 case
+  //
+  // We also handle target-independent stores, which require us to first
+  // convert to StoreV2.

   std::optional<NVPTXISD::NodeType> NewOpcode;
   switch (N->getOpcode()) {
@@ -5564,8 +5580,8 @@ static SDValue PerformStoreCombineHelper(SDNode *N,
     SDValue CurrentOp = N->getOperand(I);
     if (CurrentOp->getOpcode() == ISD::BUILD_VECTOR) {
       assert(CurrentOp.getValueType() == MVT::v2f32);
-      NewOps.push_back(CurrentOp.getNode()->getOperand(0));
-      NewOps.push_back(CurrentOp.getNode()->getOperand(1));
+      NewOps.push_back(CurrentOp.getOperand(0));
+      NewOps.push_back(CurrentOp.getOperand(1));
     } else {
       NewOps.clear();
       break;
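
Note: the simplification is behavior-preserving because SDValue forwards
operand access to its node; paraphrasing the inline definition in
llvm/include/llvm/CodeGen/SelectionDAGNodes.h:

// SDValue::getOperand is sugar for getNode()->getOperand, so the shorter
// spelling is equivalent.
inline const SDValue &SDValue::getOperand(unsigned i) const {
  return Node->getOperand(i);
}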
@@ -6342,6 +6358,19 @@ static SDValue PerformBITCASTCombine(SDNode *N,
   return SDValue();
 }

+static SDValue PerformStoreCombine(SDNode *N,
+                                   TargetLowering::DAGCombinerInfo &DCI,
+                                   const NVPTXSubtarget &STI) {
+  // Check if the stored value can be scalarized.
+  SDValue StoredVal = N->getOperand(1);
+  if (StoredVal.getValueType() == MVT::v2f32 &&
+      StoredVal.getOpcode() == ISD::BUILD_VECTOR) {
+    SmallVector<SDValue> Elements(StoredVal->op_values());
+    return convertVectorStore(SDValue(N, 0), DCI.DAG, Elements, STI);
+  }
+  return SDValue();
+}
+
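Note: the effect of this combine, as DAG pseudocode (illustrative; exact
dumps differ):

// before: t2 = BUILD_VECTOR v2f32, a, b
//         t3: ch = store<(store (s64) into %ptr)> t0, t2, %ptr
// after:  t3: ch = NVPTXISD::StoreV2 t0, a, b, %ptr
// convertVectorStore() receives the BUILD_VECTOR operands via Elements and
// emits the vector store directly, so the two f32 lanes are never packed
// into a v2f32 register just to be stored.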
 SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
                                                DAGCombinerInfo &DCI) const {
   CodeGenOptLevel OptLevel = getTargetMachine().getOptLevel();
@@ -6371,6 +6400,8 @@ SDValue NVPTXTargetLowering::PerformDAGCombine(SDNode *N,
   case NVPTXISD::LoadParam:
   case NVPTXISD::LoadParamV2:
     return PerformLoadCombine(N, DCI, STI);
+  case ISD::STORE:
+    return PerformStoreCombine(N, DCI, STI);
   case NVPTXISD::StoreParam:
   case NVPTXISD::StoreParamV2:
   case NVPTXISD::StoreParamV4: