@@ -46047,23 +46047,22 @@ static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
46047
46047
DpBuilder, false);
46048
46048
}
46049
46049
46050
- // Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
46051
- // to these zexts.
46052
- static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
46053
- const SDValue &Zext1, const SDLoc &DL,
46054
- const X86Subtarget &Subtarget) {
46050
+ // Create a PSADBW given two sources representable as zexts of vXi8.
46051
+ static SDValue createPSADBW(SelectionDAG &DAG, SDValue N0, SDValue N1,
46052
+ const SDLoc &DL, const X86Subtarget &Subtarget) {
46055
46053
// Find the appropriate width for the PSADBW.
46056
- EVT InVT = Zext0.getOperand(0).getValueType();
46057
- unsigned RegSize = std::max(128u, (unsigned)InVT.getSizeInBits());
46058
-
46059
- // "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
46060
- // fill in the missing vector elements with 0.
46061
- unsigned NumConcat = RegSize / InVT.getSizeInBits();
46062
- SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
46063
- Ops[0] = Zext0.getOperand(0);
46054
+ EVT DstVT = N0.getValueType();
46055
+ EVT SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i8,
46056
+ DstVT.getVectorElementCount());
46057
+ unsigned RegSize = std::max(128u, (unsigned)SrcVT.getSizeInBits());
46058
+
46059
+ // Widen the vXi8 vectors, padding with zero vector elements.
46060
+ unsigned NumConcat = RegSize / SrcVT.getSizeInBits();
46061
+ SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, SrcVT));
46062
+ Ops[0] = DAG.getZExtOrTrunc(N0, DL, SrcVT);
46064
46063
MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
46065
46064
SDValue SadOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
46066
- Ops[0] = Zext1.getOperand(0 );
46065
+ Ops[0] = DAG.getZExtOrTrunc(N1, DL, SrcVT );
46067
46066
SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
46068
46067
46069
46068
// Actually build the SAD, split as 128/256/512 bits for SSE/AVX2/AVX512BW.
@@ -46073,7 +46072,7 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
46073
46072
return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops);
46074
46073
};
46075
46074
MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
46076
- return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, { SadOp0, SadOp1 },
46075
+ return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, {SadOp0, SadOp1},
46077
46076
PSADBWBuilder);
46078
46077
}
46079
46078
@@ -46372,9 +46371,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
46372
46371
return SDValue();
46373
46372
46374
46373
EVT ExtractVT = Extract->getValueType(0);
46375
- // Verify the type we're extracting is either i32 or i64.
46376
- // FIXME: Could support other types, but this is what we have coverage for.
46377
- if (ExtractVT != MVT::i32 && ExtractVT != MVT::i64)
46374
+ if (ExtractVT != MVT::i8 && ExtractVT != MVT::i16 && ExtractVT != MVT::i32 &&
46375
+ ExtractVT != MVT::i64)
46378
46376
return SDValue();
46379
46377
46380
46378
EVT VT = Extract->getOperand(0).getValueType();
@@ -46399,20 +46397,27 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
46399
46397
Root.getOpcode() == ISD::ANY_EXTEND)
46400
46398
Root = Root.getOperand(0);
46401
46399
46402
- // Check whether we have an abdu pattern.
46403
- // TODO: Add handling for ISD::ABDU.
46404
- SDValue Zext0, Zext1 ;
46400
+ // Check whether we have an vXi8 abdu pattern.
46401
+ // TODO: Just match ISD::ABDU once the DAG is topological sorted .
46402
+ SDValue Src0, Src1 ;
46405
46403
if (!sd_match(
46406
46404
Root,
46407
- m_Abs(m_Sub(m_AllOf(m_Value(Zext0),
46408
- m_ZExt(m_SpecificVectorElementVT(MVT::i8))),
46409
- m_AllOf(m_Value(Zext1),
46410
- m_ZExt(m_SpecificVectorElementVT(MVT::i8)))))))
46405
+ m_AnyOf(
46406
+ m_SpecificVectorElementVT(
46407
+ MVT::i8, m_c_BinOp(ISD::ABDU, m_Value(Src0), m_Value(Src1))),
46408
+ m_SpecificVectorElementVT(
46409
+ MVT::i8, m_Sub(m_UMax(m_Value(Src0), m_Value(Src1)),
46410
+ m_UMin(m_Deferred(Src0), m_Deferred(Src1)))),
46411
+ m_Abs(
46412
+ m_Sub(m_AllOf(m_Value(Src0),
46413
+ m_ZExt(m_SpecificVectorElementVT(MVT::i8))),
46414
+ m_AllOf(m_Value(Src1),
46415
+ m_ZExt(m_SpecificVectorElementVT(MVT::i8))))))))
46411
46416
return SDValue();
46412
46417
46413
46418
// Create the SAD instruction.
46414
46419
SDLoc DL(Extract);
46415
- SDValue SAD = createPSADBW(DAG, Zext0, Zext1 , DL, Subtarget);
46420
+ SDValue SAD = createPSADBW(DAG, Src0, Src1 , DL, Subtarget);
46416
46421
46417
46422
// If the original vector was wider than 8 elements, sum over the results
46418
46423
// in the SAD vector.
0 commit comments