Skip to content

Commit e6d62c9

Browse files
committed
[X86] IsElementEquivalent - pull out vector element count mismatch code. NFC.
All cases rely on the ops having the same vector count as the masksize, and this is unlikely to change now that we handle bitcasts, so just early out.
1 parent 155fd97 commit e6d62c9

File tree

1 file changed

+21
-24
lines changed

1 file changed

+21
-24
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9782,20 +9782,23 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
97829782
return false;
97839783

97849784
EVT VT = Op.getValueType();
9785+
EVT ExpectedVT = ExpectedOp.getValueType();
9786+
9787+
// Sources must be vectors and match the mask's element count.
9788+
if (!VT.isVector() || !ExpectedVT.isVector() ||
9789+
(int)VT.getVectorNumElements() != MaskSize ||
9790+
(int)ExpectedVT.getVectorNumElements() != MaskSize)
9791+
return false;
9792+
97859793
switch (Op.getOpcode()) {
97869794
case ISD::BUILD_VECTOR:
97879795
// If the values are build vectors, we can look through them to find
97889796
// equivalent inputs that make the shuffles equivalent.
9789-
// TODO: Handle MaskSize != Op.getNumOperands()?
9790-
if (MaskSize == (int)Op.getNumOperands() &&
9791-
MaskSize == (int)ExpectedOp.getNumOperands())
9792-
return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
9793-
break;
9797+
return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
97949798
case ISD::BITCAST: {
97959799
SDValue Src = peekThroughBitcasts(Op);
97969800
EVT SrcVT = Src.getValueType();
9797-
if (Op == ExpectedOp && SrcVT.isVector() &&
9798-
(int)VT.getVectorNumElements() == MaskSize) {
9801+
if (Op == ExpectedOp && SrcVT.isVector()) {
97999802
if ((SrcVT.getScalarSizeInBits() % VT.getScalarSizeInBits()) == 0) {
98009803
unsigned Scale = SrcVT.getScalarSizeInBits() / VT.getScalarSizeInBits();
98019804
return (Idx % Scale) == (ExpectedIdx % Scale) &&
@@ -9816,23 +9819,21 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
98169819
}
98179820
case ISD::VECTOR_SHUFFLE: {
98189821
auto *SVN = cast<ShuffleVectorSDNode>(Op);
9819-
return Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize &&
9822+
return Op == ExpectedOp &&
98209823
SVN->getMaskElt(Idx) == SVN->getMaskElt(ExpectedIdx);
98219824
}
98229825
case X86ISD::VBROADCAST:
98239826
case X86ISD::VBROADCAST_LOAD:
9824-
// TODO: Handle MaskSize != VT.getVectorNumElements()?
9825-
return (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize);
9827+
return Op == ExpectedOp;
98269828
case X86ISD::SUBV_BROADCAST_LOAD:
9827-
// TODO: Handle MaskSize != VT.getVectorNumElements()?
9828-
if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
9829+
if (Op == ExpectedOp) {
98299830
auto *MemOp = cast<MemSDNode>(Op);
98309831
unsigned NumMemElts = MemOp->getMemoryVT().getVectorNumElements();
98319832
return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts);
98329833
}
98339834
break;
98349835
case X86ISD::VPERMI: {
9835-
if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
9836+
if (Op == ExpectedOp) {
98369837
SmallVector<int, 8> Mask;
98379838
DecodeVPERMMask(MaskSize, Op.getConstantOperandVal(1), Mask);
98389839
SDValue Src = Op.getOperand(0);
@@ -9849,20 +9850,16 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
98499850
case X86ISD::PACKSS:
98509851
case X86ISD::PACKUS:
98519852
// HOP(X,X) can refer to the elt from the lower/upper half of a lane.
9852-
// TODO: Handle MaskSize != NumElts?
98539853
// TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases.
98549854
if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) {
98559855
int NumElts = VT.getVectorNumElements();
9856-
if (MaskSize == NumElts) {
9857-
int NumLanes = VT.getSizeInBits() / 128;
9858-
int NumEltsPerLane = NumElts / NumLanes;
9859-
int NumHalfEltsPerLane = NumEltsPerLane / 2;
9860-
bool SameLane =
9861-
(Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
9862-
bool SameElt =
9863-
(Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
9864-
return SameLane && SameElt;
9865-
}
9856+
int NumLanes = VT.getSizeInBits() / 128;
9857+
int NumEltsPerLane = NumElts / NumLanes;
9858+
int NumHalfEltsPerLane = NumEltsPerLane / 2;
9859+
bool SameLane = (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
9860+
bool SameElt =
9861+
(Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
9862+
return SameLane && SameElt;
98669863
}
98679864
break;
98689865
}

0 commit comments

Comments
 (0)