@@ -1718,6 +1718,7 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
  case ARMISD::VCVTL: return "ARMISD::VCVTL";
  case ARMISD::VMULLs: return "ARMISD::VMULLs";
  case ARMISD::VMULLu: return "ARMISD::VMULLu";
+   case ARMISD::VQDMULH: return "ARMISD::VQDMULH";
  case ARMISD::VADDVs: return "ARMISD::VADDVs";
  case ARMISD::VADDVu: return "ARMISD::VADDVu";
  case ARMISD::VADDVps: return "ARMISD::VADDVps";
@@ -12206,9 +12207,93 @@ static SDValue PerformSELECTCombine(SDNode *N,
  return Reduction;
}

+ // A special combine for the vqdmulh family of instructions. This is one of
+ // the potential set of patterns that could match this instruction. The base
+ // pattern you would expect is min(max(ashr(mul(mul(sext(x), 2), sext(y)), 16))).
+ // This combine matches the variant min(max(ashr(mul(mul(sext(x), sext(y)), 2), 16))),
+ // which LLVM will have optimized to min(ashr(mul(sext(x), sext(y)), 15)) as
+ // the max is unnecessary.
+ static SDValue PerformVQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
+   EVT VT = N->getValueType(0);
+   SDValue Shft;
+   ConstantSDNode *Clamp;
+
+   if (N->getOpcode() == ISD::SMIN) {
+     Shft = N->getOperand(0);
+     Clamp = isConstOrConstSplat(N->getOperand(1));
+   } else if (N->getOpcode() == ISD::VSELECT) {
+     // Detect a SMIN, which for an i64 node will be a vselect/setcc, not a smin.
+     SDValue Cmp = N->getOperand(0);
+     if (Cmp.getOpcode() != ISD::SETCC ||
+         cast<CondCodeSDNode>(Cmp.getOperand(2))->get() != ISD::SETLT ||
+         Cmp.getOperand(0) != N->getOperand(1) ||
+         Cmp.getOperand(1) != N->getOperand(2))
+       return SDValue();
+     Shft = N->getOperand(1);
+     Clamp = isConstOrConstSplat(N->getOperand(2));
+   } else
+     return SDValue();
+
+   if (!Clamp)
+     return SDValue();
+
+   MVT ScalarType;
+   int ShftAmt = 0;
+   switch (Clamp->getSExtValue()) {
+   case (1 << 7) - 1:
+     ScalarType = MVT::i8;
+     ShftAmt = 7;
+     break;
+   case (1 << 15) - 1:
+     ScalarType = MVT::i16;
+     ShftAmt = 15;
+     break;
+   case (1ULL << 31) - 1:
+     ScalarType = MVT::i32;
+     ShftAmt = 31;
+     break;
+   default:
+     return SDValue();
+   }
+
+   if (Shft.getOpcode() != ISD::SRA)
+     return SDValue();
+   ConstantSDNode *N1 = isConstOrConstSplat(Shft.getOperand(1));
+   if (!N1 || N1->getSExtValue() != ShftAmt)
+     return SDValue();
+
+   SDValue Mul = Shft.getOperand(0);
+   if (Mul.getOpcode() != ISD::MUL)
+     return SDValue();
+
+   SDValue Ext0 = Mul.getOperand(0);
+   SDValue Ext1 = Mul.getOperand(1);
+   if (Ext0.getOpcode() != ISD::SIGN_EXTEND ||
+       Ext1.getOpcode() != ISD::SIGN_EXTEND)
+     return SDValue();
+   EVT VecVT = Ext0.getOperand(0).getValueType();
+   if (VecVT != MVT::v4i32 && VecVT != MVT::v8i16 && VecVT != MVT::v16i8)
+     return SDValue();
+   if (Ext1.getOperand(0).getValueType() != VecVT ||
+       VecVT.getScalarType() != ScalarType ||
+       VT.getScalarSizeInBits() < ScalarType.getScalarSizeInBits() * 2)
+     return SDValue();
+
+   SDLoc DL(Mul);
+   SDValue VQDMULH = DAG.getNode(ARMISD::VQDMULH, DL, VecVT, Ext0.getOperand(0),
+                                 Ext1.getOperand(0));
+   return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, VQDMULH);
+ }
+
static SDValue PerformVSELECTCombine(SDNode *N,
                                     TargetLowering::DAGCombinerInfo &DCI,
                                     const ARMSubtarget *Subtarget) {
+   if (!Subtarget->hasMVEIntegerOps())
+     return SDValue();
+
+   if (SDValue V = PerformVQDMULHCombine(N, DCI.DAG))
+     return V;
+
  // Transforms vselect(not(cond), lhs, rhs) into vselect(cond, rhs, lhs).
  //
  // We need to re-implement this optimization here as the implementation in the
@@ -12218,9 +12303,6 @@ static SDValue PerformVSELECTCombine(SDNode *N,
  //
  // Currently, this is only done for MVE, as it's the only target that benefits
  // from this transformation (e.g. VPNOT+VPSEL becomes a single VPSEL).
-   if (!Subtarget->hasMVEIntegerOps())
-     return SDValue();
-
  if (N->getOperand(0).getOpcode() != ISD::XOR)
    return SDValue();
  SDValue XOR = N->getOperand(0);
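For orientation, here is a scalar model of the i16 pattern the new combine recognizes (an illustrative sketch only, not code from this patch): the clamp against (1 << 15) - 1 and the arithmetic shift by 15 correspond to the second switch case above, and a vectorized loop over a body like this is what produces the smin(ashr(mul(sext, sext), 15), clamp) DAG that folds to a single ARMISD::VQDMULH.

#include <algorithm>
#include <cstdint>

// Scalar model of MVE vqdmulh.s16: (2 * a * b) >> 16 with saturation,
// written in the shift-by-15 form that LLVM canonicalizes the IR to.
int16_t vqdmulh_s16(int16_t a, int16_t b) {
  int32_t Wide = (int32_t(a) * int32_t(b)) >> 15; // ashr(mul(sext, sext), 15)
  return int16_t(std::min<int32_t>(Wide, 32767)); // smin with (1 << 15) - 1
}

The smax side of the clamp never fires here: the most negative product, -32768 * 32767, shifts to exactly -32767, which is why the comment notes that the max is unnecessary.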
@@ -14582,6 +14664,14 @@ static SDValue PerformSplittingToNarrowingStores(StoreSDNode *St,
    return true;
  };

+   // It may be preferable to keep the store unsplit as the trunc may end up
+   // being removed. Check that here.
+   if (Trunc.getOperand(0).getOpcode() == ISD::SMIN) {
+     if (SDValue U = PerformVQDMULHCombine(Trunc.getOperand(0).getNode(), DAG)) {
+       DAG.ReplaceAllUsesWith(Trunc.getOperand(0), U);
+       return SDValue();
+     }
+   }
  if (auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Trunc->getOperand(0)))
    if (isVMOVNOriginalMask(Shuffle->getMask(), false) ||
        isVMOVNOriginalMask(Shuffle->getMask(), true))
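The early-out added above leans on the fact that once the pattern is folded, the truncation the narrowing store would have performed is a no-op: the clamped product always fits the narrow element type, so the sign_extend produced by the combine cancels against the store's trunc. A quick exhaustive check of that claim for the i16 case (illustrative, not part of the patch):

#include <cassert>
#include <cstdint>

int main() {
  // Every clamped i16 product lands in [-32768, 32767], so truncating
  // the sign-extended VQDMULH result back to i16 loses nothing.
  for (int32_t a = -32768; a <= 32767; ++a)
    for (int32_t b = -32768; b <= 32767; ++b) {
      int32_t Wide = (a * b) >> 15;
      if (Wide > 32767)
        Wide = 32767; // the smin clamp
      assert(Wide >= -32768 && Wide <= 32767);
    }
  return 0;
}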
@@ -15555,6 +15645,9 @@ static SDValue PerformMinMaxCombine(SDNode *N, SelectionDAG &DAG,
  if (!ST->hasMVEIntegerOps())
    return SDValue();

+   if (SDValue V = PerformVQDMULHCombine(N, DAG))
+     return V;
+
  if (VT != MVT::v4i32 && VT != MVT::v8i16)
    return SDValue();
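With this final hook, the combine is reachable from three roots: VSELECT (the setcc form an i64 smin lowers to), narrowing stores, and SMIN here. In each case the clamp constant alone determines the lane type and shift amount; below is a compile-time restatement of that mapping, plus the i8 flavour of the scalar model (both illustrative sketches, not code from the patch).

#include <algorithm>
#include <cstdint>

// The clamp constant uniquely identifies the element type and shift:
static_assert((1 << 7) - 1 == 127, "i8 lanes shift by 7");
static_assert((1 << 15) - 1 == 32767, "i16 lanes shift by 15");
static_assert((1ULL << 31) - 1 == 2147483647, "i32 lanes shift by 31");

// The i8 flavour of the same scalar model, matching the v16i8 case.
int8_t vqdmulh_s8(int8_t a, int8_t b) {
  int32_t Wide = (int32_t(a) * int32_t(b)) >> 7; // ashr(mul(sext, sext), 7)
  return int8_t(std::min<int32_t>(Wide, 127));   // smin with (1 << 7) - 1
}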