@@ -5089,11 +5089,13 @@ PerformFADDCombineWithOperands(SDNode *N, SDValue N0, SDValue N1,
5089
5089
return SDValue ();
5090
5090
}
5091
5091
5092
+ /// OverrideVT - if set, used as both the result and memory type of the load.
5092
5093
static std::optional<std::pair<SDValue, SDValue>>
5093
- convertVectorLoad (SDNode *N, SelectionDAG &DAG, bool BuildVector) {
5094
+ convertVectorLoad (SDNode *N, SelectionDAG &DAG, bool BuildVector,
5095
+ std::optional<EVT> OverrideVT = std::nullopt) {
5094
5096
LoadSDNode *LD = cast<LoadSDNode>(N);
5095
- const EVT ResVT = LD->getValueType (0 );
5096
- const EVT MemVT = LD->getMemoryVT ();
5097
+ const EVT ResVT = OverrideVT. value_or ( LD->getValueType (0 ) );
5098
+ const EVT MemVT = OverrideVT. value_or ( LD->getMemoryVT () );
5097
5099
5098
5100
// If we're doing sign/zero extension as part of the load, avoid lowering to
5099
5101
// a LoadV node. TODO: consider relaxing this restriction.
@@ -5147,33 +5149,31 @@ convertVectorLoad(SDNode *N, SelectionDAG &DAG, bool BuildVector) {
5147
5149
// pass along the extension information
5148
5150
OtherOps.push_back (DAG.getIntPtrConstant (LD->getExtensionType (), DL));
5149
5151
5150
- SDValue NewLD = DAG.getMemIntrinsicNode (Opcode, DL, LdResVTs, OtherOps,
5151
- LD->getMemoryVT (),
5152
+ SDValue NewLD = DAG.getMemIntrinsicNode (Opcode, DL, LdResVTs, OtherOps, MemVT,
5152
5153
LD->getMemOperand ());
5153
-
5154
- SmallVector<SDValue> ScalarRes;
5155
- if (EltVT.isVector ()) {
5156
- assert (EVT (EltVT.getVectorElementType ()) == ResVT.getVectorElementType ());
5157
- assert (NumElts * EltVT.getVectorNumElements () ==
5158
- ResVT.getVectorNumElements ());
5159
- // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5160
- // into individual elements.
5161
- for (const unsigned I : llvm::seq (NumElts)) {
5162
- SDValue SubVector = NewLD.getValue (I);
5163
- DAG.ExtractVectorElements (SubVector, ScalarRes);
5164
- }
5165
- } else {
5166
- for (const unsigned I : llvm::seq (NumElts)) {
5167
- SDValue Res = NewLD.getValue (I);
5168
- if (LoadEltVT != EltVT)
5169
- Res = DAG.getNode (ISD::TRUNCATE, DL, EltVT, Res);
5170
- ScalarRes.push_back (Res);
5171
- }
5172
- }
5173
-
5174
5154
SDValue LoadChain = NewLD.getValue (NumElts);
5175
5155
5176
5156
if (BuildVector) {
5157
+ SmallVector<SDValue> ScalarRes;
5158
+ if (EltVT.isVector ()) {
5159
+ assert (EVT (EltVT.getVectorElementType ()) == ResVT.getVectorElementType ());
5160
+ assert (NumElts * EltVT.getVectorNumElements () ==
5161
+ ResVT.getVectorNumElements ());
5162
+ // Generate EXTRACT_VECTOR_ELTs to split v2[i,f,bf]16/v4i8 subvectors back
5163
+ // into individual elements.
5164
+ for (const unsigned I : llvm::seq (NumElts)) {
5165
+ SDValue SubVector = NewLD.getValue (I);
5166
+ DAG.ExtractVectorElements (SubVector, ScalarRes);
5167
+ }
5168
+ } else {
5169
+ for (const unsigned I : llvm::seq (NumElts)) {
5170
+ SDValue Res = NewLD.getValue (I);
5171
+ if (LoadEltVT != EltVT)
5172
+ Res = DAG.getNode (ISD::TRUNCATE, DL, EltVT, Res);
5173
+ ScalarRes.push_back (Res);
5174
+ }
5175
+ }
5176
+
5177
5177
const MVT BuildVecVT =
5178
5178
MVT::getVectorVT (EltVT.getScalarType (), ScalarRes.size ());
5179
5179
SDValue BuildVec = DAG.getBuildVector (BuildVecVT, DL, ScalarRes);
@@ -5188,23 +5188,20 @@ convertVectorLoad(SDNode *N, SelectionDAG &DAG, bool BuildVector) {
5188
5188
static SDValue PerformLoadCombine (SDNode *N,
5189
5189
TargetLowering::DAGCombinerInfo &DCI) {
5190
5190
auto *MemN = cast<MemSDNode>(N);
5191
- EVT MemVT = MemN->getMemoryVT ();
5192
-
5193
- // ignore volatile loads
5194
- if (MemN->isVolatile ())
5195
- return SDValue ();
5196
-
5197
5191
// only operate on vectors of f32s / i64s
5198
- if (!MemVT.isVector ())
5192
+ if (EVT MemVT = MemN->getMemoryVT ();
5193
+ !(MemVT == MVT::i64 ||
5194
+ (MemVT.isVector () && (MemVT.getVectorElementType () == MVT::f32 ||
5195
+ MemVT.getVectorElementType () == MVT::i64 ))))
5199
5196
return SDValue ();
5200
5197
5201
- EVT ElementVT = MemVT. getVectorElementType ();
5202
- if (!(ElementVT == MVT:: f32 ||
5203
- (ElementVT == MVT::i64 && N-> getOpcode () != ISD::LOAD)))
5204
- return SDValue ( );
5198
+ const unsigned OrigNumResults =
5199
+ llvm::count_if (N-> values (), []( const auto &VT) {
5200
+ return VT == MVT::i64 || VT == MVT:: f32 || VT. isVector ();
5201
+ } );
5205
5202
5206
5203
SmallDenseMap<SDNode *, unsigned > ExtractElts;
5207
- SDNode *ProxyReg = nullptr ;
5204
+ SmallVector< SDNode *> ProxyRegs (OrigNumResults, nullptr ) ;
5208
5205
SmallVector<std::pair<SDNode *, unsigned /* offset*/ >> WorkList{{N, 0 }};
5209
5206
while (!WorkList.empty ()) {
5210
5207
auto [V, Offset] = WorkList.pop_back_val ();
@@ -5217,8 +5214,14 @@ static SDValue PerformLoadCombine(SDNode *N,
5217
5214
5218
5215
SDNode *User = U.getUser ();
5219
5216
if (User->getOpcode () == NVPTXISD::ProxyReg) {
5217
+ Offset = U.getResNo () * 2 ;
5218
+ SDNode *&ProxyReg = ProxyRegs[Offset / 2 ];
5219
+
5220
+ // We shouldn't have multiple proxy regs for the same value from the
5221
+ // load, but bail out anyway since we don't handle this.
5220
5222
if (ProxyReg)
5221
- return SDValue (); // bail out if we've seen a proxy reg?
5223
+ return SDValue ();
5224
+
5222
5225
ProxyReg = User;
5223
5226
} else if (User->getOpcode () == ISD::BITCAST &&
5224
5227
User->getValueType (0 ) == MVT::v2f32 &&
@@ -5308,9 +5311,18 @@ static SDValue PerformLoadCombine(SDNode *N,
5308
5311
if (NewGlueIdx)
5309
5312
NewGlue = NewLoad.getValue (*NewGlueIdx);
5310
5313
} else if (N->getOpcode () == ISD::LOAD) { // rewrite a load
5311
- if (auto Result = convertVectorLoad (N, DCI.DAG , /* BuildVector=*/ false )) {
5314
+ std::optional<EVT> CastToType;
5315
+ EVT ResVT = N->getValueType (0 );
5316
+ if (ResVT == MVT::i64 ) {
5317
+ // ld.b64 is treated as a vector by subsequent code
5318
+ CastToType = MVT::v2f32;
5319
+ }
5320
+ if (auto Result =
5321
+ convertVectorLoad (N, DCI.DAG , /* BuildVector=*/ false , CastToType)) {
5312
5322
std::tie (NewLoad, NewChain) = *Result;
5313
- NumElts = MemVT.getVectorNumElements ();
5323
+ NumElts =
5324
+ CastToType.value_or (cast<MemSDNode>(NewLoad.getNode ())->getMemoryVT ())
5325
+ .getVectorNumElements ();
5314
5326
if (NewLoad->getValueType (NewLoad->getNumValues () - 1 ) == MVT::Glue)
5315
5327
NewGlue = NewLoad.getValue (NewLoad->getNumValues () - 1 );
5316
5328
}
@@ -5322,54 +5334,65 @@ static SDValue PerformLoadCombine(SDNode *N,
5322
5334
// (3) begin rewriting uses
5323
5335
SmallVector<SDValue> NewOutputsF32;
5324
5336
5325
- if (ProxyReg ) {
5326
- // scalarize proxyreg , but first rewrite all uses of chain and glue from the
5327
- // old load to the new load
5337
+ if (llvm::any_of (ProxyRegs, []( const SDNode *PR) { return PR != nullptr ; }) ) {
5338
+ // scalarize proxy regs , but first rewrite all uses of chain and glue from
5339
+ // the old load to the new load
5328
5340
DCI.DAG .ReplaceAllUsesOfValueWith (OldChain, NewChain);
5329
5341
DCI.DAG .ReplaceAllUsesOfValueWith (OldGlue, NewGlue);
5330
5342
5331
- // Update the new chain and glue to be old inputs to the proxyreg, if they
5332
- // came from an intervening instruction between this proxyreg and the
5333
- // original load (ex: callseq_end). Other than bitcasts and extractelts, we
5334
- // followed all other nodes by chain and glue accesses.
5335
- if (SDValue OldInChain = ProxyReg->getOperand (0 ); OldInChain.getNode () != N)
5343
+ for (unsigned ProxyI = 0 , ProxyE = ProxyRegs.size (); ProxyI != ProxyE;
5344
+ ++ProxyI) {
5345
+ SDNode *ProxyReg = ProxyRegs[ProxyI];
5346
+
5347
+ // no proxy reg might mean this result is unused
5348
+ if (!ProxyReg)
5349
+ continue ;
5350
+
5351
+ // Update the new chain and glue to be old inputs to the proxyreg, if they
5352
+ // came from an intervening instruction between this proxyreg and the
5353
+ // original load (ex: callseq_end). Other than bitcasts and extractelts,
5354
+ // we followed all other nodes by chain and glue accesses.
5355
+ if (SDValue OldInChain = ProxyReg->getOperand (0 );
5356
+ OldInChain.getNode () != N)
5336
5357
NewChain = OldInChain;
5337
- if (SDValue OldInGlue = ProxyReg->getOperand (2 ); OldInGlue.getNode () != N)
5358
+ if (SDValue OldInGlue = ProxyReg->getOperand (2 ); OldInGlue.getNode () != N)
5338
5359
NewGlue = OldInGlue;
5339
5360
5340
- // update OldChain, OldGlue to the outputs of ProxyReg, which we will
5341
- // replace later
5342
- OldChain = SDValue (ProxyReg, 1 );
5343
- OldGlue = SDValue (ProxyReg, 2 );
5344
-
5345
- // generate the scalar proxy regs
5346
- for (unsigned I = 0 , E = NumElts; I != E; ++I) {
5347
- SDValue ProxyRegElem =
5348
- DCI.DAG .getNode (NVPTXISD::ProxyReg, SDLoc (ProxyReg),
5349
- DCI.DAG .getVTList (MVT::f32 , MVT::Other, MVT::Glue),
5350
- {NewChain, NewLoad.getValue (I), NewGlue});
5351
- NewChain = ProxyRegElem.getValue (1 );
5352
- NewGlue = ProxyRegElem.getValue (2 );
5353
- NewOutputsF32.push_back (ProxyRegElem);
5361
+ // update OldChain, OldGlue to the outputs of ProxyReg, which we will
5362
+ // replace later
5363
+ OldChain = SDValue (ProxyReg, 1 );
5364
+ OldGlue = SDValue (ProxyReg, 2 );
5365
+
5366
+ // generate the scalar proxy regs
5367
+ for (unsigned I = 0 , E = 2 ; I != E; ++I) {
5368
+ SDValue ProxyRegElem = DCI.DAG .getNode (
5369
+ NVPTXISD::ProxyReg, SDLoc (ProxyReg),
5370
+ DCI.DAG .getVTList (MVT::f32 , MVT::Other, MVT::Glue),
5371
+ {NewChain, NewLoad.getValue (ProxyI * 2 + I), NewGlue});
5372
+ NewChain = ProxyRegElem.getValue (1 );
5373
+ NewGlue = ProxyRegElem.getValue (2 );
5374
+ NewOutputsF32.push_back (ProxyRegElem);
5375
+ }
5376
+
5377
+ // replace all uses of the glue and chain from the old proxy reg
5378
+ DCI.DAG .ReplaceAllUsesOfValueWith (OldChain, NewChain);
5379
+ DCI.DAG .ReplaceAllUsesOfValueWith (OldGlue, NewGlue);
5354
5380
}
5355
5381
} else {
5356
5382
for (unsigned I = 0 , E = NumElts; I != E; ++I)
5357
5383
if (NewLoad->getValueType (I) == MVT::f32 )
5358
5384
NewOutputsF32.push_back (NewLoad.getValue (I));
5385
+
5386
+ // replace all glue and chain nodes
5387
+ DCI.DAG .ReplaceAllUsesOfValueWith (OldChain, NewChain);
5388
+ if (OldGlue)
5389
+ DCI.DAG .ReplaceAllUsesOfValueWith (OldGlue, NewGlue);
5359
5390
}
5360
5391
5361
- // now, for all extractelts, replace them with one of the new outputs
5392
+ // replace all extractelts with the new outputs
5362
5393
for (auto &[Extract, Index] : ExtractElts)
5363
5394
DCI.CombineTo (Extract, NewOutputsF32[Index], false );
5364
5395
5365
- // now replace all glue and chain nodes
5366
- DCI.DAG .ReplaceAllUsesOfValueWith (OldChain, NewChain);
5367
- if (OldGlue)
5368
- DCI.DAG .ReplaceAllUsesOfValueWith (OldGlue, NewGlue);
5369
-
5370
- // cleanup
5371
- if (ProxyReg)
5372
- DCI.recursivelyDeleteUnusedNodes (ProxyReg);
5373
5396
return SDValue ();
5374
5397
}
5375
5398
0 commit comments