Skip to content

Commit cdc7864

Browse files
committed
[SystemZ] Optimize widening and high-word vector multiplication
Detect (non-intrinsic) IR patterns corresponding to the semantics of the various widening and high-word multiplication instructions. Specifically, this is done by: - Recognizing even/odd widening multiplication patterns in DAGCombine - Recognizing widening multiply-and-add on top during ISel - Implementing the standard MULHS/MUHLU IR opcodes - Detecting high-word multiply-and-add (which common code does not) Depending on architecture level, this can support all integer vector types as well as the scalar i128 type. Fixes: llvm#129705
1 parent 7af3d39 commit cdc7864

File tree

13 files changed

+1239
-83
lines changed

13 files changed

+1239
-83
lines changed

llvm/lib/Target/SystemZ/SystemZISelLowering.cpp

Lines changed: 298 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,8 +454,11 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
454454
setOperationAction(ISD::INSERT_VECTOR_ELT, VT, Legal);
455455
setOperationAction(ISD::ADD, VT, Legal);
456456
setOperationAction(ISD::SUB, VT, Legal);
457-
if (VT != MVT::v2i64 || Subtarget.hasVectorEnhancements3())
457+
if (VT != MVT::v2i64 || Subtarget.hasVectorEnhancements3()) {
458458
setOperationAction(ISD::MUL, VT, Legal);
459+
setOperationAction(ISD::MULHS, VT, Legal);
460+
setOperationAction(ISD::MULHU, VT, Legal);
461+
}
459462
if (Subtarget.hasVectorEnhancements3() &&
460463
VT != MVT::v16i8 && VT != MVT::v8i16) {
461464
setOperationAction(ISD::SDIV, VT, Legal);
@@ -775,6 +778,9 @@ SystemZTargetLowering::SystemZTargetLowering(const TargetMachine &TM,
775778
ISD::STRICT_FP_EXTEND,
776779
ISD::BSWAP,
777780
ISD::SETCC,
781+
ISD::SRL,
782+
ISD::SRA,
783+
ISD::MUL,
778784
ISD::SDIV,
779785
ISD::UDIV,
780786
ISD::SREM,
@@ -5345,6 +5351,94 @@ SystemZTargetLowering::lowerINTRINSIC_WO_CHAIN(SDValue Op,
53455351
case Intrinsic::s390_vsbcbiq:
53465352
return DAG.getNode(SystemZISD::VSBCBI, SDLoc(Op), Op.getValueType(),
53475353
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5354+
5355+
case Intrinsic::s390_vmhb:
5356+
case Intrinsic::s390_vmhh:
5357+
case Intrinsic::s390_vmhf:
5358+
case Intrinsic::s390_vmhg:
5359+
case Intrinsic::s390_vmhq:
5360+
return DAG.getNode(ISD::MULHS, SDLoc(Op), Op.getValueType(),
5361+
Op.getOperand(1), Op.getOperand(2));
5362+
case Intrinsic::s390_vmlhb:
5363+
case Intrinsic::s390_vmlhh:
5364+
case Intrinsic::s390_vmlhf:
5365+
case Intrinsic::s390_vmlhg:
5366+
case Intrinsic::s390_vmlhq:
5367+
return DAG.getNode(ISD::MULHU, SDLoc(Op), Op.getValueType(),
5368+
Op.getOperand(1), Op.getOperand(2));
5369+
5370+
case Intrinsic::s390_vmahb:
5371+
case Intrinsic::s390_vmahh:
5372+
case Intrinsic::s390_vmahf:
5373+
case Intrinsic::s390_vmahg:
5374+
case Intrinsic::s390_vmahq:
5375+
return DAG.getNode(SystemZISD::VMAH, SDLoc(Op), Op.getValueType(),
5376+
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5377+
case Intrinsic::s390_vmalhb:
5378+
case Intrinsic::s390_vmalhh:
5379+
case Intrinsic::s390_vmalhf:
5380+
case Intrinsic::s390_vmalhg:
5381+
case Intrinsic::s390_vmalhq:
5382+
return DAG.getNode(SystemZISD::VMALH, SDLoc(Op), Op.getValueType(),
5383+
Op.getOperand(1), Op.getOperand(2), Op.getOperand(3));
5384+
5385+
case Intrinsic::s390_vmeb:
5386+
case Intrinsic::s390_vmeh:
5387+
case Intrinsic::s390_vmef:
5388+
case Intrinsic::s390_vmeg:
5389+
return DAG.getNode(SystemZISD::VME, SDLoc(Op), Op.getValueType(),
5390+
Op.getOperand(1), Op.getOperand(2));
5391+
case Intrinsic::s390_vmleb:
5392+
case Intrinsic::s390_vmleh:
5393+
case Intrinsic::s390_vmlef:
5394+
case Intrinsic::s390_vmleg:
5395+
return DAG.getNode(SystemZISD::VMLE, SDLoc(Op), Op.getValueType(),
5396+
Op.getOperand(1), Op.getOperand(2));
5397+
case Intrinsic::s390_vmob:
5398+
case Intrinsic::s390_vmoh:
5399+
case Intrinsic::s390_vmof:
5400+
case Intrinsic::s390_vmog:
5401+
return DAG.getNode(SystemZISD::VMO, SDLoc(Op), Op.getValueType(),
5402+
Op.getOperand(1), Op.getOperand(2));
5403+
case Intrinsic::s390_vmlob:
5404+
case Intrinsic::s390_vmloh:
5405+
case Intrinsic::s390_vmlof:
5406+
case Intrinsic::s390_vmlog:
5407+
return DAG.getNode(SystemZISD::VMLO, SDLoc(Op), Op.getValueType(),
5408+
Op.getOperand(1), Op.getOperand(2));
5409+
5410+
case Intrinsic::s390_vmaeb:
5411+
case Intrinsic::s390_vmaeh:
5412+
case Intrinsic::s390_vmaef:
5413+
case Intrinsic::s390_vmaeg:
5414+
return DAG.getNode(ISD::ADD, SDLoc(Op), Op.getValueType(),
5415+
DAG.getNode(SystemZISD::VME, SDLoc(Op), Op.getValueType(),
5416+
Op.getOperand(1), Op.getOperand(2)),
5417+
Op.getOperand(3));
5418+
case Intrinsic::s390_vmaleb:
5419+
case Intrinsic::s390_vmaleh:
5420+
case Intrinsic::s390_vmalef:
5421+
case Intrinsic::s390_vmaleg:
5422+
return DAG.getNode(ISD::ADD, SDLoc(Op), Op.getValueType(),
5423+
DAG.getNode(SystemZISD::VMLE, SDLoc(Op), Op.getValueType(),
5424+
Op.getOperand(1), Op.getOperand(2)),
5425+
Op.getOperand(3));
5426+
case Intrinsic::s390_vmaob:
5427+
case Intrinsic::s390_vmaoh:
5428+
case Intrinsic::s390_vmaof:
5429+
case Intrinsic::s390_vmaog:
5430+
return DAG.getNode(ISD::ADD, SDLoc(Op), Op.getValueType(),
5431+
DAG.getNode(SystemZISD::VMO, SDLoc(Op), Op.getValueType(),
5432+
Op.getOperand(1), Op.getOperand(2)),
5433+
Op.getOperand(3));
5434+
case Intrinsic::s390_vmalob:
5435+
case Intrinsic::s390_vmaloh:
5436+
case Intrinsic::s390_vmalof:
5437+
case Intrinsic::s390_vmalog:
5438+
return DAG.getNode(ISD::ADD, SDLoc(Op), Op.getValueType(),
5439+
DAG.getNode(SystemZISD::VMLO, SDLoc(Op), Op.getValueType(),
5440+
Op.getOperand(1), Op.getOperand(2)),
5441+
Op.getOperand(3));
53485442
}
53495443

53505444
return SDValue();
@@ -6912,6 +7006,12 @@ const char *SystemZTargetLowering::getTargetNodeName(unsigned Opcode) const {
69127006
OPCODE(VSBI);
69137007
OPCODE(VACCC);
69147008
OPCODE(VSBCBI);
7009+
OPCODE(VMAH);
7010+
OPCODE(VMALH);
7011+
OPCODE(VME);
7012+
OPCODE(VMLE);
7013+
OPCODE(VMO);
7014+
OPCODE(VMLO);
69157015
OPCODE(VICMPE);
69167016
OPCODE(VICMPH);
69177017
OPCODE(VICMPHL);
@@ -8311,6 +8411,200 @@ SDValue SystemZTargetLowering::combineIntDIVREM(
83118411
return SDValue();
83128412
}
83138413

8414+
8415+
// Transform a right shift of a multiply-and-add into a multiply-and-add-high.
8416+
// This is closely modeled after the common-code combineShiftToMULH.
8417+
SDValue SystemZTargetLowering::combineShiftToMulAddHigh(
8418+
SDNode *N, DAGCombinerInfo &DCI) const {
8419+
SelectionDAG &DAG = DCI.DAG;
8420+
SDLoc DL(N);
8421+
8422+
assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
8423+
"SRL or SRA node is required here!");
8424+
8425+
if (!Subtarget.hasVector())
8426+
return SDValue();
8427+
8428+
// Check the shift amount. Proceed with the transformation if the shift
8429+
// amount is constant.
8430+
ConstantSDNode *ShiftAmtSrc = isConstOrConstSplat(N->getOperand(1));
8431+
if (!ShiftAmtSrc)
8432+
return SDValue();
8433+
8434+
// The operation feeding into the shift must be an add.
8435+
SDValue ShiftOperand = N->getOperand(0);
8436+
if (ShiftOperand.getOpcode() != ISD::ADD)
8437+
return SDValue();
8438+
8439+
// One operand of the add must be a multiply.
8440+
SDValue MulOp = ShiftOperand.getOperand(0);
8441+
SDValue AddOp = ShiftOperand.getOperand(1);
8442+
if (MulOp.getOpcode() != ISD::MUL) {
8443+
if (AddOp.getOpcode() != ISD::MUL)
8444+
return SDValue();
8445+
std::swap(MulOp, AddOp);
8446+
}
8447+
8448+
// All operands must be equivalent extend nodes.
8449+
SDValue LeftOp = MulOp.getOperand(0);
8450+
SDValue RightOp = MulOp.getOperand(1);
8451+
8452+
bool IsSignExt = LeftOp.getOpcode() == ISD::SIGN_EXTEND;
8453+
bool IsZeroExt = LeftOp.getOpcode() == ISD::ZERO_EXTEND;
8454+
8455+
if (!IsSignExt && !IsZeroExt)
8456+
return SDValue();
8457+
8458+
EVT NarrowVT = LeftOp.getOperand(0).getValueType();
8459+
unsigned NarrowVTSize = NarrowVT.getScalarSizeInBits();
8460+
8461+
SDValue MulhRightOp;
8462+
if (ConstantSDNode *Constant = isConstOrConstSplat(RightOp)) {
8463+
unsigned ActiveBits = IsSignExt
8464+
? Constant->getAPIntValue().getSignificantBits()
8465+
: Constant->getAPIntValue().getActiveBits();
8466+
if (ActiveBits > NarrowVTSize)
8467+
return SDValue();
8468+
MulhRightOp = DAG.getConstant(
8469+
Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
8470+
NarrowVT);
8471+
} else {
8472+
if (LeftOp.getOpcode() != RightOp.getOpcode())
8473+
return SDValue();
8474+
// Check that the two extend nodes are the same type.
8475+
if (NarrowVT != RightOp.getOperand(0).getValueType())
8476+
return SDValue();
8477+
MulhRightOp = RightOp.getOperand(0);
8478+
}
8479+
8480+
SDValue MulhAddOp;
8481+
if (ConstantSDNode *Constant = isConstOrConstSplat(AddOp)) {
8482+
unsigned ActiveBits = IsSignExt
8483+
? Constant->getAPIntValue().getSignificantBits()
8484+
: Constant->getAPIntValue().getActiveBits();
8485+
if (ActiveBits > NarrowVTSize)
8486+
return SDValue();
8487+
MulhAddOp = DAG.getConstant(
8488+
Constant->getAPIntValue().trunc(NarrowVT.getScalarSizeInBits()), DL,
8489+
NarrowVT);
8490+
} else {
8491+
if (LeftOp.getOpcode() != AddOp.getOpcode())
8492+
return SDValue();
8493+
// Check that the two extend nodes are the same type.
8494+
if (NarrowVT != AddOp.getOperand(0).getValueType())
8495+
return SDValue();
8496+
MulhAddOp = AddOp.getOperand(0);
8497+
}
8498+
8499+
EVT WideVT = LeftOp.getValueType();
8500+
// Proceed with the transformation if the wide types match.
8501+
assert((WideVT == RightOp.getValueType()) &&
8502+
"Cannot have a multiply node with two different operand types.");
8503+
assert((WideVT == AddOp.getValueType()) &&
8504+
"Cannot have an add node with two different operand types.");
8505+
8506+
// Proceed with the transformation if the wide type is twice as large
8507+
// as the narrow type.
8508+
if (WideVT.getScalarSizeInBits() != 2 * NarrowVTSize)
8509+
return SDValue();
8510+
8511+
// Check the shift amount with the narrow type size.
8512+
// Proceed with the transformation if the shift amount is the width
8513+
// of the narrow type.
8514+
unsigned ShiftAmt = ShiftAmtSrc->getZExtValue();
8515+
if (ShiftAmt != NarrowVTSize)
8516+
return SDValue();
8517+
8518+
// Proceed if we support the multiply-and-add-high operation.
8519+
if (!(NarrowVT == MVT::v16i8 || NarrowVT == MVT::v8i16 ||
8520+
NarrowVT == MVT::v4i32 ||
8521+
(Subtarget.hasVectorEnhancements3() &&
8522+
(NarrowVT == MVT::v2i64 || NarrowVT == MVT::i128))))
8523+
return SDValue();
8524+
8525+
// Emit the VMAH (signed) or VMALH (unsigned) operation.
8526+
SDValue Result = DAG.getNode(IsSignExt ? SystemZISD::VMAH : SystemZISD::VMALH,
8527+
DL, NarrowVT, LeftOp.getOperand(0),
8528+
MulhRightOp, MulhAddOp);
8529+
bool IsSigned = N->getOpcode() == ISD::SRA;
8530+
return DAG.getExtOrTrunc(IsSigned, Result, DL, WideVT);
8531+
}
8532+
8533+
// Op is an operand of a multiplication. Check whether this can be folded
8534+
// into an even/odd widening operation; if so, return the opcode to be used
8535+
// and update Op to the appropriate sub-operand. Note that the caller must
8536+
// verify that *both* operands of the multiplication support the operation.
8537+
static unsigned detectEvenOddMultiplyOperand(const SelectionDAG &DAG,
8538+
const SystemZSubtarget &Subtarget,
8539+
SDValue &Op) {
8540+
EVT VT = Op.getValueType();
8541+
8542+
// Check for (sign/zero_extend_vector_inreg (vector_shuffle)) corresponding
8543+
// to selecting the even or odd vector elements.
8544+
if (VT.isVector() && DAG.getTargetLoweringInfo().isTypeLegal(VT) &&
8545+
(Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG ||
8546+
Op.getOpcode() == ISD::ZERO_EXTEND_VECTOR_INREG)) {
8547+
bool IsSigned = Op.getOpcode() == ISD::SIGN_EXTEND_VECTOR_INREG;
8548+
unsigned NumElts = VT.getVectorNumElements();
8549+
Op = Op.getOperand(0);
8550+
if (Op.getValueType().getVectorNumElements() == 2 * NumElts &&
8551+
Op.getOpcode() == ISD::VECTOR_SHUFFLE) {
8552+
ShuffleVectorSDNode *SVN = cast<ShuffleVectorSDNode>(Op.getNode());
8553+
ArrayRef<int> ShuffleMask = SVN->getMask();
8554+
bool CanUseEven = true, CanUseOdd = true;
8555+
for (unsigned Elt = 0; Elt < NumElts; Elt++) {
8556+
if (ShuffleMask[Elt] == -1)
8557+
continue;
8558+
if (unsigned(ShuffleMask[Elt]) != 2 * Elt)
8559+
CanUseEven = false;
8560+
if (unsigned(ShuffleMask[Elt]) != 2 * Elt + 1)
8561+
CanUseEven = true;
8562+
}
8563+
Op = Op.getOperand(0);
8564+
if (CanUseEven)
8565+
return IsSigned ? SystemZISD::VME : SystemZISD::VMLE;
8566+
if (CanUseOdd)
8567+
return IsSigned ? SystemZISD::VMO : SystemZISD::VMLO;
8568+
}
8569+
}
8570+
8571+
// For arch15, we can also support the v2i64->i128 case, which looks like
8572+
// (sign/zero_extend (extract_vector_elt X 0/1))
8573+
if (VT == MVT::i128 && Subtarget.hasVectorEnhancements3() &&
8574+
(Op.getOpcode() == ISD::SIGN_EXTEND ||
8575+
Op.getOpcode() == ISD::ZERO_EXTEND)) {
8576+
bool IsSigned = Op.getOpcode() == ISD::SIGN_EXTEND;
8577+
Op = Op.getOperand(0);
8578+
if (Op.getOpcode() == ISD::EXTRACT_VECTOR_ELT &&
8579+
Op.getOperand(0).getValueType() == MVT::v2i64 &&
8580+
Op.getOperand(1).getOpcode() == ISD::Constant) {
8581+
unsigned Elem = Op.getConstantOperandVal(1);
8582+
Op = Op.getOperand(0);
8583+
if (Elem == 0)
8584+
return IsSigned ? SystemZISD::VME : SystemZISD::VMLE;
8585+
if (Elem == 1)
8586+
return IsSigned ? SystemZISD::VMO : SystemZISD::VMLO;
8587+
}
8588+
}
8589+
8590+
return 0;
8591+
}
8592+
8593+
SDValue SystemZTargetLowering::combineMUL(
8594+
SDNode *N, DAGCombinerInfo &DCI) const {
8595+
SelectionDAG &DAG = DCI.DAG;
8596+
8597+
// Detect even/odd widening multiplication.
8598+
SDValue Op0 = N->getOperand(0);
8599+
SDValue Op1 = N->getOperand(1);
8600+
unsigned OpcodeCand0 = detectEvenOddMultiplyOperand(DAG, Subtarget, Op0);
8601+
unsigned OpcodeCand1 = detectEvenOddMultiplyOperand(DAG, Subtarget, Op1);
8602+
if (OpcodeCand0 && OpcodeCand0 == OpcodeCand1)
8603+
return DAG.getNode(OpcodeCand0, SDLoc(N), N->getValueType(0), Op0, Op1);
8604+
8605+
return SDValue();
8606+
}
8607+
83148608
SDValue SystemZTargetLowering::combineINTRINSIC(
83158609
SDNode *N, DAGCombinerInfo &DCI) const {
83168610
SelectionDAG &DAG = DCI.DAG;
@@ -8370,6 +8664,9 @@ SDValue SystemZTargetLowering::PerformDAGCombine(SDNode *N,
83708664
case SystemZISD::BR_CCMASK: return combineBR_CCMASK(N, DCI);
83718665
case SystemZISD::SELECT_CCMASK: return combineSELECT_CCMASK(N, DCI);
83728666
case SystemZISD::GET_CCMASK: return combineGET_CCMASK(N, DCI);
8667+
case ISD::SRL:
8668+
case ISD::SRA: return combineShiftToMulAddHigh(N, DCI);
8669+
case ISD::MUL: return combineMUL(N, DCI);
83738670
case ISD::SDIV:
83748671
case ISD::UDIV:
83758672
case ISD::SREM:

llvm/lib/Target/SystemZ/SystemZISelLowering.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,11 @@ enum NodeType : unsigned {
234234
// Compute carry/borrow indication for add/subtract with carry/borrow.
235235
VACCC, VSBCBI,
236236

237+
// High-word multiply-and-add.
238+
VMAH, VMALH,
239+
// Widen and multiply even/odd vector elements.
240+
VME, VMLE, VMO, VMLO,
241+
237242
// Compare integer vector operands 0 and 1 to produce the usual 0/-1
238243
// vector result. VICMPE is for equality, VICMPH for "signed greater than"
239244
// and VICMPHL for "unsigned greater than".
@@ -759,6 +764,8 @@ class SystemZTargetLowering : public TargetLowering {
759764
SDValue combineBR_CCMASK(SDNode *N, DAGCombinerInfo &DCI) const;
760765
SDValue combineSELECT_CCMASK(SDNode *N, DAGCombinerInfo &DCI) const;
761766
SDValue combineGET_CCMASK(SDNode *N, DAGCombinerInfo &DCI) const;
767+
SDValue combineShiftToMulAddHigh(SDNode *N, DAGCombinerInfo &DCI) const;
768+
SDValue combineMUL(SDNode *N, DAGCombinerInfo &DCI) const;
762769
SDValue combineIntDIVREM(SDNode *N, DAGCombinerInfo &DCI) const;
763770
SDValue combineINTRINSIC(SDNode *N, DAGCombinerInfo &DCI) const;
764771

0 commit comments

Comments
 (0)