@@ -49838,31 +49838,24 @@ static SDValue combineMul(SDNode *N, SelectionDAG &DAG,
49838
49838
static SDValue combineShiftToPMULH(SDNode *N, SelectionDAG &DAG,
49839
49839
const SDLoc &DL,
49840
49840
const X86Subtarget &Subtarget) {
49841
+ using namespace SDPatternMatch;
49841
49842
assert((N->getOpcode() == ISD::SRL || N->getOpcode() == ISD::SRA) &&
49842
- "SRL or SRA node is required here!");
49843
+ "SRL or SRA node is required here!");
49843
49844
49844
49845
if (!Subtarget.hasSSE2())
49845
49846
return SDValue();
49846
49847
49847
- // The operation feeding into the shift must be a multiply.
49848
- SDValue ShiftOperand = N->getOperand(0);
49849
- if (ShiftOperand.getOpcode() != ISD::MUL || !ShiftOperand.hasOneUse())
49850
- return SDValue();
49851
-
49852
49848
// Input type should be at least vXi32.
49853
49849
EVT VT = N->getValueType(0);
49854
49850
if (!VT.isVector() || VT.getVectorElementType().getSizeInBits() < 32)
49855
49851
return SDValue();
49856
49852
49857
- // Need a shift by 16.
49858
- APInt ShiftAmt ;
49859
- if (!ISD::isConstantSplatVector (N->getOperand(1).getNode(), ShiftAmt ) ||
49860
- ShiftAmt != 16 )
49853
+ // The operation must be a multiply shifted right by 16.
49854
+ SDValue LHS, RHS ;
49855
+ if (!sd_match (N->getOperand(1), m_SpecificInt(16) ) ||
49856
+ !sd_match(N->getOperand(0), m_OneUse(m_Mul(m_Value(LHS), m_Value(RHS)))) )
49861
49857
return SDValue();
49862
49858
49863
- SDValue LHS = ShiftOperand.getOperand(0);
49864
- SDValue RHS = ShiftOperand.getOperand(1);
49865
-
49866
49859
unsigned ExtOpc = LHS.getOpcode();
49867
49860
if ((ExtOpc != ISD::SIGN_EXTEND && ExtOpc != ISD::ZERO_EXTEND) ||
49868
49861
RHS.getOpcode() != ExtOpc)
0 commit comments