@@ -9782,20 +9782,23 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
9782
9782
return false;
9783
9783
9784
9784
EVT VT = Op.getValueType();
9785
+ EVT ExpectedVT = ExpectedOp.getValueType();
9786
+
9787
+ // Sources must be vectors and match the mask's element count.
9788
+ if (!VT.isVector() || !ExpectedVT.isVector() ||
9789
+ (int)VT.getVectorNumElements() != MaskSize ||
9790
+ (int)ExpectedVT.getVectorNumElements() != MaskSize)
9791
+ return false;
9792
+
9785
9793
switch (Op.getOpcode()) {
9786
9794
case ISD::BUILD_VECTOR:
9787
9795
// If the values are build vectors, we can look through them to find
9788
9796
// equivalent inputs that make the shuffles equivalent.
9789
- // TODO: Handle MaskSize != Op.getNumOperands()?
9790
- if (MaskSize == (int)Op.getNumOperands() &&
9791
- MaskSize == (int)ExpectedOp.getNumOperands())
9792
- return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
9793
- break;
9797
+ return Op.getOperand(Idx) == ExpectedOp.getOperand(ExpectedIdx);
9794
9798
case ISD::BITCAST: {
9795
9799
SDValue Src = peekThroughBitcasts(Op);
9796
9800
EVT SrcVT = Src.getValueType();
9797
- if (Op == ExpectedOp && SrcVT.isVector() &&
9798
- (int)VT.getVectorNumElements() == MaskSize) {
9801
+ if (Op == ExpectedOp && SrcVT.isVector()) {
9799
9802
if ((SrcVT.getScalarSizeInBits() % VT.getScalarSizeInBits()) == 0) {
9800
9803
unsigned Scale = SrcVT.getScalarSizeInBits() / VT.getScalarSizeInBits();
9801
9804
return (Idx % Scale) == (ExpectedIdx % Scale) &&
@@ -9816,23 +9819,21 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
9816
9819
}
9817
9820
case ISD::VECTOR_SHUFFLE: {
9818
9821
auto *SVN = cast<ShuffleVectorSDNode>(Op);
9819
- return Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize &&
9822
+ return Op == ExpectedOp &&
9820
9823
SVN->getMaskElt(Idx) == SVN->getMaskElt(ExpectedIdx);
9821
9824
}
9822
9825
case X86ISD::VBROADCAST:
9823
9826
case X86ISD::VBROADCAST_LOAD:
9824
- // TODO: Handle MaskSize != VT.getVectorNumElements()?
9825
- return (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize);
9827
+ return Op == ExpectedOp;
9826
9828
case X86ISD::SUBV_BROADCAST_LOAD:
9827
- // TODO: Handle MaskSize != VT.getVectorNumElements()?
9828
- if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize) {
9829
+ if (Op == ExpectedOp) {
9829
9830
auto *MemOp = cast<MemSDNode>(Op);
9830
9831
unsigned NumMemElts = MemOp->getMemoryVT().getVectorNumElements();
9831
9832
return (Idx % NumMemElts) == (ExpectedIdx % NumMemElts);
9832
9833
}
9833
9834
break;
9834
9835
case X86ISD::VPERMI: {
9835
- if (Op == ExpectedOp && (int)VT.getVectorNumElements() == MaskSize ) {
9836
+ if (Op == ExpectedOp) {
9836
9837
SmallVector<int, 8> Mask;
9837
9838
DecodeVPERMMask(MaskSize, Op.getConstantOperandVal(1), Mask);
9838
9839
SDValue Src = Op.getOperand(0);
@@ -9849,20 +9850,16 @@ static bool IsElementEquivalent(int MaskSize, SDValue Op, SDValue ExpectedOp,
9849
9850
case X86ISD::PACKSS:
9850
9851
case X86ISD::PACKUS:
9851
9852
// HOP(X,X) can refer to the elt from the lower/upper half of a lane.
9852
- // TODO: Handle MaskSize != NumElts?
9853
9853
// TODO: Handle HOP(X,Y) vs HOP(Y,X) equivalence cases.
9854
9854
if (Op == ExpectedOp && Op.getOperand(0) == Op.getOperand(1)) {
9855
9855
int NumElts = VT.getVectorNumElements();
9856
- if (MaskSize == NumElts) {
9857
- int NumLanes = VT.getSizeInBits() / 128;
9858
- int NumEltsPerLane = NumElts / NumLanes;
9859
- int NumHalfEltsPerLane = NumEltsPerLane / 2;
9860
- bool SameLane =
9861
- (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
9862
- bool SameElt =
9863
- (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
9864
- return SameLane && SameElt;
9865
- }
9856
+ int NumLanes = VT.getSizeInBits() / 128;
9857
+ int NumEltsPerLane = NumElts / NumLanes;
9858
+ int NumHalfEltsPerLane = NumEltsPerLane / 2;
9859
+ bool SameLane = (Idx / NumEltsPerLane) == (ExpectedIdx / NumEltsPerLane);
9860
+ bool SameElt =
9861
+ (Idx % NumHalfEltsPerLane) == (ExpectedIdx % NumHalfEltsPerLane);
9862
+ return SameLane && SameElt;
9866
9863
}
9867
9864
break;
9868
9865
}
0 commit comments