Skip to content

Commit d14db8c

Browse files
committed
[ARM] Match MVE vqdmulh
This adds ISel matching for a form of VQDMULH. There are several IR patterns that we could match to that instruction; this one is for: min(ashr(mul(sext(a), sext(b)), 7), 127), which is what LLVM will optimize to once it has removed the max that usually makes up the min/max saturate pattern, as in this case the compare will always be false. The additional complication when matching i32 patterns (which extend into an i64) is that the min will be a vselect/setcc, as vmin is not supported for i64 vectors. Tablegen patterns have also been updated to attempt to reuse the MVE_TwoOpPattern patterns. Differential Revision: https://reviews.llvm.org/D90096
1 parent 62286c5 commit d14db8c

File tree

4 files changed

+138
-663
lines changed

4 files changed

+138
-663
lines changed

llvm/lib/Target/ARM/ARMISelLowering.cpp

Lines changed: 96 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1718,6 +1718,7 @@ const char *ARMTargetLowering::getTargetNodeName(unsigned Opcode) const {
17181718
case ARMISD::VCVTL: return "ARMISD::VCVTL";
17191719
case ARMISD::VMULLs: return "ARMISD::VMULLs";
17201720
case ARMISD::VMULLu: return "ARMISD::VMULLu";
1721+
case ARMISD::VQDMULH: return "ARMISD::VQDMULH";
17211722
case ARMISD::VADDVs: return "ARMISD::VADDVs";
17221723
case ARMISD::VADDVu: return "ARMISD::VADDVu";
17231724
case ARMISD::VADDVps: return "ARMISD::VADDVps";
@@ -12206,9 +12207,93 @@ static SDValue PerformSELECTCombine(SDNode *N,
1220612207
return Reduction;
1220712208
}
1220812209

12210+
// A special combine for the vqdmulh family of instructions. This is one of the
12211+
// potential set of patterns that could patch this instruction. The base pattern
12212+
// you would expect to be min(max(ashr(mul(mul(sext(x), 2), sext(y)), 16))).
12213+
// This matches the different min(max(ashr(mul(mul(sext(x), sext(y)), 2), 16))),
12214+
// which llvm will have optimized to min(ashr(mul(sext(x), sext(y)), 15))) as
12215+
// the max is unnecessary.
12216+
static SDValue PerformVQDMULHCombine(SDNode *N, SelectionDAG &DAG) {
12217+
EVT VT = N->getValueType(0);
12218+
SDValue Shft;
12219+
ConstantSDNode *Clamp;
12220+
12221+
if (N->getOpcode() == ISD::SMIN) {
12222+
Shft = N->getOperand(0);
12223+
Clamp = isConstOrConstSplat(N->getOperand(1));
12224+
} else if (N->getOpcode() == ISD::VSELECT) {
12225+
// Detect a SMIN, which for an i64 node will be a vselect/setcc, not a smin.
12226+
SDValue Cmp = N->getOperand(0);
12227+
if (Cmp.getOpcode() != ISD::SETCC ||
12228+
cast<CondCodeSDNode>(Cmp.getOperand(2))->get() != ISD::SETLT ||
12229+
Cmp.getOperand(0) != N->getOperand(1) ||
12230+
Cmp.getOperand(1) != N->getOperand(2))
12231+
return SDValue();
12232+
Shft = N->getOperand(1);
12233+
Clamp = isConstOrConstSplat(N->getOperand(2));
12234+
} else
12235+
return SDValue();
12236+
12237+
if (!Clamp)
12238+
return SDValue();
12239+
12240+
MVT ScalarType;
12241+
int ShftAmt = 0;
12242+
switch (Clamp->getSExtValue()) {
12243+
case (1 << 7) - 1:
12244+
ScalarType = MVT::i8;
12245+
ShftAmt = 7;
12246+
break;
12247+
case (1 << 15) - 1:
12248+
ScalarType = MVT::i16;
12249+
ShftAmt = 15;
12250+
break;
12251+
case (1ULL << 31) - 1:
12252+
ScalarType = MVT::i32;
12253+
ShftAmt = 31;
12254+
break;
12255+
default:
12256+
return SDValue();
12257+
}
12258+
12259+
if (Shft.getOpcode() != ISD::SRA)
12260+
return SDValue();
12261+
ConstantSDNode *N1 = isConstOrConstSplat(Shft.getOperand(1));
12262+
if (!N1 || N1->getSExtValue() != ShftAmt)
12263+
return SDValue();
12264+
12265+
SDValue Mul = Shft.getOperand(0);
12266+
if (Mul.getOpcode() != ISD::MUL)
12267+
return SDValue();
12268+
12269+
SDValue Ext0 = Mul.getOperand(0);
12270+
SDValue Ext1 = Mul.getOperand(1);
12271+
if (Ext0.getOpcode() != ISD::SIGN_EXTEND ||
12272+
Ext1.getOpcode() != ISD::SIGN_EXTEND)
12273+
return SDValue();
12274+
EVT VecVT = Ext0.getOperand(0).getValueType();
12275+
if (VecVT != MVT::v4i32 && VecVT != MVT::v8i16 && VecVT != MVT::v16i8)
12276+
return SDValue();
12277+
if (Ext1.getOperand(0).getValueType() != VecVT ||
12278+
VecVT.getScalarType() != ScalarType ||
12279+
VT.getScalarSizeInBits() < ScalarType.getScalarSizeInBits() * 2)
12280+
return SDValue();
12281+
12282+
SDLoc DL(Mul);
12283+
SDValue VQDMULH = DAG.getNode(ARMISD::VQDMULH, DL, VecVT, Ext0.getOperand(0),
12284+
Ext1.getOperand(0));
12285+
return DAG.getNode(ISD::SIGN_EXTEND, DL, VT, VQDMULH);
12286+
}
12287+
1220912288
static SDValue PerformVSELECTCombine(SDNode *N,
1221012289
TargetLowering::DAGCombinerInfo &DCI,
1221112290
const ARMSubtarget *Subtarget) {
12291+
if (!Subtarget->hasMVEIntegerOps())
12292+
return SDValue();
12293+
12294+
if (SDValue V = PerformVQDMULHCombine(N, DCI.DAG))
12295+
return V;
12296+
1221212297
// Transforms vselect(not(cond), lhs, rhs) into vselect(cond, rhs, lhs).
1221312298
//
1221412299
// We need to re-implement this optimization here as the implementation in the
@@ -12218,9 +12303,6 @@ static SDValue PerformVSELECTCombine(SDNode *N,
1221812303
//
1221912304
// Currently, this is only done for MVE, as it's the only target that benefits
1222012305
// from this transformation (e.g. VPNOT+VPSEL becomes a single VPSEL).
12221-
if (!Subtarget->hasMVEIntegerOps())
12222-
return SDValue();
12223-
1222412306
if (N->getOperand(0).getOpcode() != ISD::XOR)
1222512307
return SDValue();
1222612308
SDValue XOR = N->getOperand(0);
@@ -14582,6 +14664,14 @@ static SDValue PerformSplittingToNarrowingStores(StoreSDNode *St,
1458214664
return true;
1458314665
};
1458414666

14667+
// It may be preferable to keep the store unsplit as the trunc may end up
14668+
// being removed. Check that here.
14669+
if (Trunc.getOperand(0).getOpcode() == ISD::SMIN) {
14670+
if (SDValue U = PerformVQDMULHCombine(Trunc.getOperand(0).getNode(), DAG)) {
14671+
DAG.ReplaceAllUsesWith(Trunc.getOperand(0), U);
14672+
return SDValue();
14673+
}
14674+
}
1458514675
if (auto *Shuffle = dyn_cast<ShuffleVectorSDNode>(Trunc->getOperand(0)))
1458614676
if (isVMOVNOriginalMask(Shuffle->getMask(), false) ||
1458714677
isVMOVNOriginalMask(Shuffle->getMask(), true))
@@ -15555,6 +15645,9 @@ static SDValue PerformMinMaxCombine(SDNode *N, SelectionDAG &DAG,
1555515645
if (!ST->hasMVEIntegerOps())
1555615646
return SDValue();
1555715647

15648+
if (SDValue V = PerformVQDMULHCombine(N, DAG))
15649+
return V;
15650+
1555815651
if (VT != MVT::v4i32 && VT != MVT::v8i16)
1555915652
return SDValue();
1556015653

llvm/lib/Target/ARM/ARMISelLowering.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -216,6 +216,8 @@ class VectorType;
216216
VMULLs, // ...signed
217217
VMULLu, // ...unsigned
218218

219+
VQDMULH, // MVE vqdmulh instruction
220+
219221
// MVE reductions
220222
VADDVs, // sign- or zero-extend the elements of a vector to i32,
221223
VADDVu, // add them all together, and return an i32 of their sum

llvm/lib/Target/ARM/ARMInstrMVE.td

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1955,28 +1955,26 @@ class MVE_VQxDMULH_Base<string iname, string suffix, bits<2> size, bit rounding,
19551955
let validForTailPredication = 1;
19561956
}
19571957

1958+
def MVEvqdmulh : SDNode<"ARMISD::VQDMULH", SDTIntBinOp>;
1959+
19581960
multiclass MVE_VQxDMULH_m<string iname, MVEVectorVTInfo VTI,
1959-
SDNode unpred_op, Intrinsic pred_int,
1961+
SDNode Op, Intrinsic unpred_int, Intrinsic pred_int,
19601962
bit rounding> {
19611963
def "" : MVE_VQxDMULH_Base<iname, VTI.Suffix, VTI.Size, rounding>;
19621964
defvar Inst = !cast<Instruction>(NAME);
1965+
defm : MVE_TwoOpPattern<VTI, Op, pred_int, (? ), Inst>;
19631966

19641967
let Predicates = [HasMVEInt] in {
1965-
// Unpredicated multiply
1966-
def : Pat<(VTI.Vec (unpred_op (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn))),
1968+
// Extra unpredicated multiply intrinsic patterns
1969+
def : Pat<(VTI.Vec (unpred_int (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn))),
19671970
(VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn)))>;
1968-
1969-
// Predicated multiply
1970-
def : Pat<(VTI.Vec (pred_int (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn),
1971-
(VTI.Pred VCCR:$mask), (VTI.Vec MQPR:$inactive))),
1972-
(VTI.Vec (Inst (VTI.Vec MQPR:$Qm), (VTI.Vec MQPR:$Qn),
1973-
ARMVCCThen, (VTI.Pred VCCR:$mask),
1974-
(VTI.Vec MQPR:$inactive)))>;
19751971
}
19761972
}
19771973

19781974
multiclass MVE_VQxDMULH<string iname, MVEVectorVTInfo VTI, bit rounding>
1979-
: MVE_VQxDMULH_m<iname, VTI, !if(rounding, int_arm_mve_vqrdmulh,
1975+
: MVE_VQxDMULH_m<iname, VTI, !if(rounding, null_frag,
1976+
MVEvqdmulh),
1977+
!if(rounding, int_arm_mve_vqrdmulh,
19801978
int_arm_mve_vqdmulh),
19811979
!if(rounding, int_arm_mve_qrdmulh_predicated,
19821980
int_arm_mve_qdmulh_predicated),
@@ -5492,18 +5490,18 @@ class MVE_VxxMUL_qr<string iname, string suffix,
54925490
}
54935491

54945492
multiclass MVE_VxxMUL_qr_m<string iname, MVEVectorVTInfo VTI, bit bit_28,
5495-
Intrinsic int_unpred, Intrinsic int_pred> {
5493+
PatFrag Op, Intrinsic int_unpred, Intrinsic int_pred> {
54965494
def "" : MVE_VxxMUL_qr<iname, VTI.Suffix, bit_28, VTI.Size>;
5497-
defm : MVE_vec_scalar_int_pat_m<!cast<Instruction>(NAME), VTI,
5498-
int_unpred, int_pred>;
5495+
defm : MVE_TwoOpPatternDup<VTI, Op, int_pred, (? ), !cast<Instruction>(NAME)>;
5496+
defm : MVE_vec_scalar_int_pat_m<!cast<Instruction>(NAME), VTI, int_unpred, int_pred>;
54995497
}
55005498

55015499
multiclass MVE_VQDMULH_qr_m<MVEVectorVTInfo VTI> :
5502-
MVE_VxxMUL_qr_m<"vqdmulh", VTI, 0b0,
5500+
MVE_VxxMUL_qr_m<"vqdmulh", VTI, 0b0, MVEvqdmulh,
55035501
int_arm_mve_vqdmulh, int_arm_mve_qdmulh_predicated>;
55045502

55055503
multiclass MVE_VQRDMULH_qr_m<MVEVectorVTInfo VTI> :
5506-
MVE_VxxMUL_qr_m<"vqrdmulh", VTI, 0b1,
5504+
MVE_VxxMUL_qr_m<"vqrdmulh", VTI, 0b1, null_frag,
55075505
int_arm_mve_vqrdmulh, int_arm_mve_qrdmulh_predicated>;
55085506

55095507
defm MVE_VQDMULH_qr_s8 : MVE_VQDMULH_qr_m<MVE_v16s8>;

0 commit comments

Comments
 (0)