@@ -582,6 +582,7 @@ namespace {
582
582
SDValue reassociateReduction(unsigned RedOpc, unsigned Opc, const SDLoc &DL,
583
583
EVT VT, SDValue N0, SDValue N1,
584
584
SDNodeFlags Flags = SDNodeFlags());
585
+ SDValue foldReductionWithUndefLane(SDNode *N);
585
586
586
587
SDValue visitShiftByConstant(SDNode *N);
587
588
@@ -1349,6 +1350,75 @@ SDValue DAGCombiner::reassociateReduction(unsigned RedOpc, unsigned Opc,
1349
1350
return SDValue();
1350
1351
}
1351
1352
1353
+ // Convert:
1354
+ // (op.x2 (vector_shuffle<i,u> A), B) -> <(op A:i, B:0) undef>
1355
+ // ...or...
1356
+ // (op.x2 (vector_shuffle<u,i> A), B) -> <undef (op A:i, B:1)>
1357
+ // ...where i is a valid index and u is poison.
1358
+ SDValue DAGCombiner::foldReductionWithUndefLane(SDNode *N) {
1359
+ const EVT VectorVT = N->getValueType(0);
1360
+
1361
+ // Only support 2-packed vectors for now.
1362
+ if (!VectorVT.isVector() || VectorVT.isScalableVector()
1363
+ || VectorVT.getVectorNumElements() != 2)
1364
+ return SDValue();
1365
+
1366
+ // If the operation is already unsupported, we don't need to do this
1367
+ // operation.
1368
+ if (!TLI.isOperationLegal(N->getOpcode(), VectorVT))
1369
+ return SDValue();
1370
+
1371
+ // If vector shuffle is supported on the target, this optimization may
1372
+ // increase register pressure.
1373
+ if (TLI.isOperationLegalOrCustomOrPromote(ISD::VECTOR_SHUFFLE, VectorVT))
1374
+ return SDValue();
1375
+
1376
+ SDLoc DL(N);
1377
+
1378
+ SDValue ShufOp = N->getOperand(0);
1379
+ SDValue VectOp = N->getOperand(1);
1380
+ bool Swapped = false;
1381
+
1382
+ // canonicalize shuffle op
1383
+ if (VectOp.getOpcode() == ISD::VECTOR_SHUFFLE) {
1384
+ std::swap(ShufOp, VectOp);
1385
+ Swapped = true;
1386
+ }
1387
+
1388
+ if (ShufOp.getOpcode() != ISD::VECTOR_SHUFFLE)
1389
+ return SDValue();
1390
+
1391
+ auto *ShuffleOp = cast<ShuffleVectorSDNode>(ShufOp);
1392
+ int LiveLane; // exclusively live lane
1393
+ for (LiveLane = 0; LiveLane < 2; ++LiveLane) {
1394
+ // check if the current lane is live and the other lane is dead
1395
+ if (ShuffleOp->getMaskElt(LiveLane) != PoisonMaskElem &&
1396
+ ShuffleOp->getMaskElt(!LiveLane) == PoisonMaskElem)
1397
+ break;
1398
+ }
1399
+ if (LiveLane == 2)
1400
+ return SDValue();
1401
+
1402
+ const int ElementIdx = ShuffleOp->getMaskElt(LiveLane);
1403
+ const EVT ScalarVT = VectorVT.getScalarType();
1404
+ SDValue Lanes[2] = {};
1405
+ for (auto [LaneID, LaneVal] : enumerate(Lanes)) {
1406
+ if (LaneID == (unsigned)LiveLane) {
1407
+ SDValue Operands[2] = {
1408
+ DAG.getExtractVectorElt(DL, ScalarVT, ShufOp.getOperand(0),
1409
+ ElementIdx),
1410
+ DAG.getExtractVectorElt(DL, ScalarVT, VectOp, LiveLane)};
1411
+ // preserve the order of operands
1412
+ if (Swapped)
1413
+ std::swap(Operands[0], Operands[1]);
1414
+ LaneVal = DAG.getNode(N->getOpcode(), DL, ScalarVT, Operands);
1415
+ } else {
1416
+ LaneVal = DAG.getUNDEF(ScalarVT);
1417
+ }
1418
+ }
1419
+ return DAG.getBuildVector(VectorVT, DL, Lanes);
1420
+ }
1421
+
1352
1422
SDValue DAGCombiner::CombineTo(SDNode *N, const SDValue *To, unsigned NumTo,
1353
1423
bool AddTo) {
1354
1424
assert(N->getNumValues() == NumTo && "Broken CombineTo call!");
@@ -3058,6 +3128,9 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
3058
3128
return DAG.getNode(ISD::ADD, DL, VT, N0.getOperand(0), SV);
3059
3129
}
3060
3130
3131
+ if (SDValue R = foldReductionWithUndefLane(N))
3132
+ return R;
3133
+
3061
3134
return SDValue();
3062
3135
}
3063
3136
@@ -6001,6 +6074,9 @@ SDValue DAGCombiner::visitIMINMAX(SDNode *N) {
6001
6074
SDLoc(N), VT, N0, N1))
6002
6075
return SD;
6003
6076
6077
+ if (SDValue SD = foldReductionWithUndefLane(N))
6078
+ return SD;
6079
+
6004
6080
// Simplify the operands using demanded-bits information.
6005
6081
if (SimplifyDemandedBits(SDValue(N, 0)))
6006
6082
return SDValue(N, 0);
@@ -7301,6 +7377,9 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
7301
7377
}
7302
7378
}
7303
7379
}
7380
+
7381
+ if (SDValue R = foldReductionWithUndefLane(N))
7382
+ return R;
7304
7383
}
7305
7384
7306
7385
// fold (and x, -1) -> x
@@ -8260,6 +8339,9 @@ SDValue DAGCombiner::visitOR(SDNode *N) {
8260
8339
}
8261
8340
}
8262
8341
}
8342
+
8343
+ if (SDValue R = foldReductionWithUndefLane(N))
8344
+ return R;
8263
8345
}
8264
8346
8265
8347
// fold (or x, 0) -> x
@@ -9941,6 +10023,9 @@ SDValue DAGCombiner::visitXOR(SDNode *N) {
9941
10023
if (SDValue Combined = combineCarryDiamond(DAG, TLI, N0, N1, N))
9942
10024
return Combined;
9943
10025
10026
+ if (SDValue R = foldReductionWithUndefLane(N))
10027
+ return R;
10028
+
9944
10029
return SDValue();
9945
10030
}
9946
10031
@@ -17557,6 +17642,10 @@ SDValue DAGCombiner::visitFADD(SDNode *N) {
17557
17642
AddToWorklist(Fused.getNode());
17558
17643
return Fused;
17559
17644
}
17645
+
17646
+ if (SDValue R = foldReductionWithUndefLane(N))
17647
+ return R;
17648
+
17560
17649
return SDValue();
17561
17650
}
17562
17651
@@ -17925,6 +18014,9 @@ SDValue DAGCombiner::visitFMUL(SDNode *N) {
17925
18014
if (SDValue R = combineFMulOrFDivWithIntPow2(N))
17926
18015
return R;
17927
18016
18017
+ if (SDValue R = foldReductionWithUndefLane(N))
18018
+ return R;
18019
+
17928
18020
return SDValue();
17929
18021
}
17930
18022
@@ -19030,6 +19122,9 @@ SDValue DAGCombiner::visitFMinMax(SDNode *N) {
19030
19122
Opc, SDLoc(N), VT, N0, N1, Flags))
19031
19123
return SD;
19032
19124
19125
+ if (SDValue SD = foldReductionWithUndefLane(N))
19126
+ return SD;
19127
+
19033
19128
return SDValue();
19034
19129
}
19035
19130
0 commit comments