Skip to content

Commit 391dafd

Browse files
authored
[RISCV] Consolidate both copies of getLMUL1VT [nfc] (#144568)
Put one copy on RISCVTargetLowering as a static function so that both locations can use it, and rename the method to getM1VT for slightly improved readability.
1 parent 80f3a28 commit 391dafd

File tree

3 files changed

+39
-45
lines changed

3 files changed

+39
-45
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 29 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3498,14 +3498,6 @@ getVSlideup(SelectionDAG &DAG, const RISCVSubtarget &Subtarget, const SDLoc &DL,
34983498
return DAG.getNode(RISCVISD::VSLIDEUP_VL, DL, VT, Ops);
34993499
}
35003500

3501-
static MVT getLMUL1VT(MVT VT) {
3502-
assert(VT.getVectorElementType().getSizeInBits() <= RISCV::RVVBitsPerBlock &&
3503-
"Unexpected vector MVT");
3504-
return MVT::getScalableVectorVT(
3505-
VT.getVectorElementType(),
3506-
RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
3507-
}
3508-
35093501
struct VIDSequence {
35103502
int64_t StepNumerator;
35113503
unsigned StepDenominator;
@@ -4316,7 +4308,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
43164308
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
43174309
MVT OneRegVT = MVT::getVectorVT(ElemVT, ElemsPerVReg);
43184310
MVT M1VT = getContainerForFixedLengthVector(DAG, OneRegVT, Subtarget);
4319-
assert(M1VT == getLMUL1VT(M1VT));
4311+
assert(M1VT == RISCVTargetLowering::getM1VT(M1VT));
43204312

43214313
// The following semantically builds up a fixed length concat_vector
43224314
// of the component build_vectors. We eagerly lower to scalable and
@@ -4356,7 +4348,7 @@ static SDValue lowerBUILD_VECTOR(SDValue Op, SelectionDAG &DAG,
43564348
count_if(Op->op_values(), [](const SDValue &V) { return V.isUndef(); });
43574349
unsigned NumDefElts = NumElts - NumUndefElts;
43584350
if (NumDefElts >= 8 && NumDefElts > NumElts / 2 &&
4359-
ContainerVT.bitsLE(getLMUL1VT(ContainerVT))) {
4351+
ContainerVT.bitsLE(RISCVTargetLowering::getM1VT(ContainerVT))) {
43604352
SmallVector<SDValue> SubVecAOps, SubVecBOps;
43614353
SmallVector<SDValue> MaskVals;
43624354
SDValue UndefElem = DAG.getUNDEF(Op->getOperand(0)->getValueType(0));
@@ -5114,7 +5106,8 @@ static SDValue lowerVZIP(unsigned Opc, SDValue Op0, SDValue Op1,
51145106

51155107
MVT InnerVT = ContainerVT;
51165108
auto [Mask, VL] = getDefaultVLOps(IntVT, InnerVT, DL, DAG, Subtarget);
5117-
if (Op1.isUndef() && ContainerVT.bitsGT(getLMUL1VT(ContainerVT)) &&
5109+
if (Op1.isUndef() &&
5110+
ContainerVT.bitsGT(RISCVTargetLowering::getM1VT(ContainerVT)) &&
51185111
(RISCVISD::RI_VUNZIP2A_VL == Opc || RISCVISD::RI_VUNZIP2B_VL == Opc)) {
51195112
InnerVT = ContainerVT.getHalfNumVectorElementsVT();
51205113
VL = DAG.getConstant(VT.getVectorNumElements() / 2, DL,
@@ -5382,7 +5375,7 @@ static SDValue lowerShuffleViaVRegSplitting(ShuffleVectorSDNode *SVN,
53825375
EVT ContainerVT = getContainerForFixedLengthVector(DAG, VT, Subtarget);
53835376
MVT OneRegVT = MVT::getVectorVT(ElemVT, ElemsPerVReg);
53845377
MVT M1VT = getContainerForFixedLengthVector(DAG, OneRegVT, Subtarget);
5385-
assert(M1VT == getLMUL1VT(M1VT));
5378+
assert(M1VT == RISCVTargetLowering::getM1VT(M1VT));
53865379
unsigned NumOpElts = M1VT.getVectorMinNumElements();
53875380
unsigned NumElts = ContainerVT.getVectorMinNumElements();
53885381
unsigned NumOfSrcRegs = NumElts / NumOpElts;
@@ -6152,7 +6145,7 @@ static SDValue lowerVECTOR_SHUFFLE(SDValue Op, SelectionDAG &DAG,
61526145
return convertFromScalableVector(VT, Gather, DAG, Subtarget);
61536146
}
61546147

6155-
const MVT M1VT = getLMUL1VT(ContainerVT);
6148+
const MVT M1VT = RISCVTargetLowering::getM1VT(ContainerVT);
61566149
EVT SubIndexVT = M1VT.changeVectorElementType(IndexVT.getScalarType());
61576150
auto [InnerTrueMask, InnerVL] =
61586151
getDefaultScalableVLOps(M1VT, DL, DAG, Subtarget);
@@ -7801,7 +7794,7 @@ SDValue RISCVTargetLowering::LowerOperation(SDValue Op,
78017794
// This reduces the length of the chain of vslideups and allows us to
78027795
// perform the vslideups at a smaller LMUL, limited to MF2.
78037796
if (Op.getNumOperands() > 2 &&
7804-
ContainerVT.bitsGE(getLMUL1VT(ContainerVT))) {
7797+
ContainerVT.bitsGE(RISCVTargetLowering::getM1VT(ContainerVT))) {
78057798
MVT HalfVT = VT.getHalfNumVectorElementsVT();
78067799
assert(isPowerOf2_32(Op.getNumOperands()));
78077800
size_t HalfNumOps = Op.getNumOperands() / 2;
@@ -9821,11 +9814,12 @@ getSmallestVTForIndex(MVT VecVT, unsigned MaxIdx, SDLoc DL, SelectionDAG &DAG,
98219814
const unsigned MinVLMAX = VectorBitsMin / EltSize;
98229815
MVT SmallerVT;
98239816
if (MaxIdx < MinVLMAX)
9824-
SmallerVT = getLMUL1VT(VecVT);
9817+
SmallerVT = RISCVTargetLowering::getM1VT(VecVT);
98259818
else if (MaxIdx < MinVLMAX * 2)
9826-
SmallerVT = getLMUL1VT(VecVT).getDoubleNumVectorElementsVT();
9819+
SmallerVT =
9820+
RISCVTargetLowering::getM1VT(VecVT).getDoubleNumVectorElementsVT();
98279821
else if (MaxIdx < MinVLMAX * 4)
9828-
SmallerVT = getLMUL1VT(VecVT)
9822+
SmallerVT = RISCVTargetLowering::getM1VT(VecVT)
98299823
.getDoubleNumVectorElementsVT()
98309824
.getDoubleNumVectorElementsVT();
98319825
if (!SmallerVT.isValid() || !VecVT.bitsGT(SmallerVT))
@@ -9898,9 +9892,8 @@ SDValue RISCVTargetLowering::lowerINSERT_VECTOR_ELT(SDValue Op,
98989892
// If we're compiling for an exact VLEN value, we can always perform
98999893
// the insert in m1 as we can determine the register corresponding to
99009894
// the index in the register group.
9901-
const MVT M1VT = getLMUL1VT(ContainerVT);
9902-
if (auto VLEN = Subtarget.getRealVLen();
9903-
VLEN && ContainerVT.bitsGT(M1VT)) {
9895+
const MVT M1VT = RISCVTargetLowering::getM1VT(ContainerVT);
9896+
if (auto VLEN = Subtarget.getRealVLen(); VLEN && ContainerVT.bitsGT(M1VT)) {
99049897
EVT ElemVT = VecVT.getVectorElementType();
99059898
unsigned ElemsPerVReg = *VLEN / ElemVT.getFixedSizeInBits();
99069899
unsigned RemIdx = OrigIdx % ElemsPerVReg;
@@ -10127,7 +10120,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
1012710120
const auto VLen = Subtarget.getRealVLen();
1012810121
if (auto *IdxC = dyn_cast<ConstantSDNode>(Idx);
1012910122
IdxC && VLen && VecVT.getSizeInBits().getKnownMinValue() > *VLen) {
10130-
MVT M1VT = getLMUL1VT(ContainerVT);
10123+
MVT M1VT = RISCVTargetLowering::getM1VT(ContainerVT);
1013110124
unsigned OrigIdx = IdxC->getZExtValue();
1013210125
EVT ElemVT = VecVT.getVectorElementType();
1013310126
unsigned ElemsPerVReg = *VLen / ElemVT.getFixedSizeInBits();
@@ -10175,7 +10168,8 @@ SDValue RISCVTargetLowering::lowerEXTRACT_VECTOR_ELT(SDValue Op,
1017510168
// TODO: We don't have the same code for insert_vector_elt because we
1017610169
// have BUILD_VECTOR and handle the degenerate case there. Should we
1017710170
// consider adding an inverse BUILD_VECTOR node?
10178-
MVT LMUL2VT = getLMUL1VT(ContainerVT).getDoubleNumVectorElementsVT();
10171+
MVT LMUL2VT =
10172+
RISCVTargetLowering::getM1VT(ContainerVT).getDoubleNumVectorElementsVT();
1017910173
if (ContainerVT.bitsGT(LMUL2VT) && VecVT.isFixedLengthVector())
1018010174
return SDValue();
1018110175

@@ -11107,7 +11101,7 @@ static SDValue lowerReductionSeq(unsigned RVVOpcode, MVT ResVT,
1110711101
SDValue VL, const SDLoc &DL, SelectionDAG &DAG,
1110811102
const RISCVSubtarget &Subtarget) {
1110911103
const MVT VecVT = Vec.getSimpleValueType();
11110-
const MVT M1VT = getLMUL1VT(VecVT);
11104+
const MVT M1VT = RISCVTargetLowering::getM1VT(VecVT);
1111111105
const MVT XLenVT = Subtarget.getXLenVT();
1111211106
const bool NonZeroAVL = isNonZeroAVL(VL);
1111311107

@@ -11485,8 +11479,8 @@ SDValue RISCVTargetLowering::lowerINSERT_SUBVECTOR(SDValue Op,
1148511479
assert(VLen);
1148611480
AlignedIdx /= *VLen / RISCV::RVVBitsPerBlock;
1148711481
}
11488-
if (ContainerVecVT.bitsGT(getLMUL1VT(ContainerVecVT))) {
11489-
InterSubVT = getLMUL1VT(ContainerVecVT);
11482+
if (ContainerVecVT.bitsGT(RISCVTargetLowering::getM1VT(ContainerVecVT))) {
11483+
InterSubVT = RISCVTargetLowering::getM1VT(ContainerVecVT);
1149011484
// Extract a subvector equal to the nearest full vector register type. This
1149111485
// should resolve to a EXTRACT_SUBREG instruction.
1149211486
AlignedExtract = DAG.getExtractSubvector(DL, InterSubVT, Vec, AlignedIdx);
@@ -11677,7 +11671,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
1167711671
// If the vector type is an LMUL-group type, extract a subvector equal to the
1167811672
// nearest full vector register type.
1167911673
MVT InterSubVT = VecVT;
11680-
if (VecVT.bitsGT(getLMUL1VT(VecVT))) {
11674+
if (VecVT.bitsGT(RISCVTargetLowering::getM1VT(VecVT))) {
1168111675
// If VecVT has an LMUL > 1, then SubVecVT should have a smaller LMUL, and
1168211676
// we should have successfully decomposed the extract into a subregister.
1168311677
// We use an extract_subvector that will resolve to a subreg extract.
@@ -11688,7 +11682,7 @@ SDValue RISCVTargetLowering::lowerEXTRACT_SUBVECTOR(SDValue Op,
1168811682
assert(VLen);
1168911683
Idx /= *VLen / RISCV::RVVBitsPerBlock;
1169011684
}
11691-
InterSubVT = getLMUL1VT(VecVT);
11685+
InterSubVT = RISCVTargetLowering::getM1VT(VecVT);
1169211686
Vec = DAG.getExtractSubvector(DL, InterSubVT, Vec, Idx);
1169311687
}
1169411688

@@ -11805,7 +11799,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_DEINTERLEAVE(SDValue Op,
1180511799
// For fractional LMUL, check if we can use a higher LMUL
1180611800
// instruction to avoid a vslidedown.
1180711801
if (SDValue Src = foldConcatVector(V1, V2);
11808-
Src && getLMUL1VT(VT).bitsGT(VT)) {
11802+
Src && RISCVTargetLowering::getM1VT(VT).bitsGT(VT)) {
1180911803
EVT NewVT = VT.getDoubleNumVectorElementsVT();
1181011804
Src = DAG.getExtractSubvector(DL, NewVT, Src, 0);
1181111805
// Freeze the source so we can increase its use count.
@@ -12187,7 +12181,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
1218712181
// vrgather.vv v14, v9, v16
1218812182
// vrgather.vv v13, v10, v16
1218912183
// vrgather.vv v12, v11, v16
12190-
if (ContainerVT.bitsGT(getLMUL1VT(ContainerVT)) &&
12184+
if (ContainerVT.bitsGT(RISCVTargetLowering::getM1VT(ContainerVT)) &&
1219112185
ContainerVT.getVectorElementCount().isKnownMultipleOf(2)) {
1219212186
auto [Lo, Hi] = DAG.SplitVector(Vec, DL);
1219312187
Lo = DAG.getNode(ISD::VECTOR_REVERSE, DL, Lo.getSimpleValueType(), Lo);
@@ -12252,7 +12246,7 @@ SDValue RISCVTargetLowering::lowerVECTOR_REVERSE(SDValue Op,
1225212246
// At LMUL > 1, do the index computation in 16 bits to reduce register
1225312247
// pressure.
1225412248
if (IntVT.getScalarType().bitsGT(MVT::i16) &&
12255-
IntVT.bitsGT(getLMUL1VT(IntVT))) {
12249+
IntVT.bitsGT(RISCVTargetLowering::getM1VT(IntVT))) {
1225612250
assert(isUInt<16>(MaxVLMAX - 1)); // Largest VLMAX is 65536 @ zvl65536b
1225712251
GatherOpc = RISCVISD::VRGATHEREI16_VV_VL;
1225812252
IntVT = IntVT.changeVectorElementType(MVT::i16);
@@ -12339,7 +12333,7 @@ RISCVTargetLowering::lowerFixedLengthVectorLoadToRVV(SDValue Op,
1233912333
const auto [MinVLMAX, MaxVLMAX] =
1234012334
RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
1234112335
if (MinVLMAX == MaxVLMAX && MinVLMAX == VT.getVectorNumElements() &&
12342-
getLMUL1VT(ContainerVT).bitsLE(ContainerVT)) {
12336+
RISCVTargetLowering::getM1VT(ContainerVT).bitsLE(ContainerVT)) {
1234312337
MachineMemOperand *MMO = Load->getMemOperand();
1234412338
SDValue NewLoad =
1234512339
DAG.getLoad(ContainerVT, DL, Load->getChain(), Load->getBasePtr(),
@@ -12400,7 +12394,7 @@ RISCVTargetLowering::lowerFixedLengthVectorStoreToRVV(SDValue Op,
1240012394
const auto [MinVLMAX, MaxVLMAX] =
1240112395
RISCVTargetLowering::computeVLMAXBounds(ContainerVT, Subtarget);
1240212396
if (MinVLMAX == MaxVLMAX && MinVLMAX == VT.getVectorNumElements() &&
12403-
getLMUL1VT(ContainerVT).bitsLE(ContainerVT)) {
12397+
RISCVTargetLowering::getM1VT(ContainerVT).bitsLE(ContainerVT)) {
1240412398
MachineMemOperand *MMO = Store->getMemOperand();
1240512399
return DAG.getStore(Store->getChain(), DL, NewValue, Store->getBasePtr(),
1240612400
MMO->getPointerInfo(), MMO->getBaseAlign(),
@@ -20368,7 +20362,7 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
2036820362
return Scalar.getOperand(0);
2036920363

2037020364
// Use M1 or smaller to avoid over constraining register allocation
20371-
const MVT M1VT = getLMUL1VT(VT);
20365+
const MVT M1VT = RISCVTargetLowering::getM1VT(VT);
2037220366
if (M1VT.bitsLT(VT)) {
2037320367
SDValue M1Passthru = DAG.getExtractSubvector(DL, M1VT, Passthru, 0);
2037420368
SDValue Result =
@@ -20382,15 +20376,15 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
2038220376
// no purpose.
2038320377
if (ConstantSDNode *Const = dyn_cast<ConstantSDNode>(Scalar);
2038420378
Const && !Const->isZero() && isInt<5>(Const->getSExtValue()) &&
20385-
VT.bitsLE(getLMUL1VT(VT)) && Passthru.isUndef())
20379+
VT.bitsLE(RISCVTargetLowering::getM1VT(VT)) && Passthru.isUndef())
2038620380
return DAG.getNode(RISCVISD::VMV_V_X_VL, DL, VT, Passthru, Scalar, VL);
2038720381

2038820382
break;
2038920383
}
2039020384
case RISCVISD::VMV_X_S: {
2039120385
SDValue Vec = N->getOperand(0);
2039220386
MVT VecVT = N->getOperand(0).getSimpleValueType();
20393-
const MVT M1VT = getLMUL1VT(VecVT);
20387+
const MVT M1VT = RISCVTargetLowering::getM1VT(VecVT);
2039420388
if (M1VT.bitsLT(VecVT)) {
2039520389
Vec = DAG.getExtractSubvector(DL, M1VT, Vec, 0);
2039620390
return DAG.getNode(RISCVISD::VMV_X_S, DL, N->getSimpleValueType(0), Vec);

llvm/lib/Target/RISCV/RISCVISelLowering.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,15 @@ class RISCVTargetLowering : public TargetLowering {
363363
static std::pair<unsigned, unsigned>
364364
computeVLMAXBounds(MVT ContainerVT, const RISCVSubtarget &Subtarget);
365365

366+
/// Given a vector (either fixed or scalable), return the scalable vector
367+
/// corresponding to a vector register (i.e. an m1 register group).
368+
static MVT getM1VT(MVT VT) {
369+
unsigned EltSizeInBits = VT.getVectorElementType().getSizeInBits();
370+
assert(EltSizeInBits <= RISCV::RVVBitsPerBlock && "Unexpected vector MVT");
371+
return MVT::getScalableVectorVT(VT.getVectorElementType(),
372+
RISCV::RVVBitsPerBlock / EltSizeInBits);
373+
}
374+
366375
static unsigned getRegClassIDForLMUL(RISCVVType::VLMUL LMul);
367376
static unsigned getSubregIndexByMVT(MVT VT, unsigned Index);
368377
static unsigned getRegClassIDForVecVT(MVT VT);

llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -602,15 +602,6 @@ InstructionCost RISCVTTIImpl::getSlideCost(FixedVectorType *Tp,
602602
return FirstSlideCost + SecondSlideCost + MaskCost;
603603
}
604604

605-
// Consolidate!
606-
static MVT getLMUL1VT(MVT VT) {
607-
assert(VT.getVectorElementType().getSizeInBits() <= RISCV::RVVBitsPerBlock &&
608-
"Unexpected vector MVT");
609-
return MVT::getScalableVectorVT(
610-
VT.getVectorElementType(),
611-
RISCV::RVVBitsPerBlock / VT.getVectorElementType().getSizeInBits());
612-
}
613-
614605
InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
615606
VectorType *Tp, ArrayRef<int> Mask,
616607
TTI::TargetCostKind CostKind,
@@ -870,7 +861,7 @@ InstructionCost RISCVTTIImpl::getShuffleCost(TTI::ShuffleKind Kind,
870861
MVT ContainerVT = LT.second;
871862
if (LT.second.isFixedLengthVector())
872863
ContainerVT = TLI->getContainerForFixedLengthVector(LT.second);
873-
MVT M1VT = getLMUL1VT(ContainerVT);
864+
MVT M1VT = RISCVTargetLowering::getM1VT(ContainerVT);
874865
if (ContainerVT.bitsLE(M1VT)) {
875866
// Example sequence:
876867
// csrr a0, vlenb

0 commit comments

Comments
 (0)