Skip to content

Commit f0bc411

Browse files
authored
[X86] combineBasicSADPattern - pattern match various vXi8 ABDU patterns (#147570)
We were previously limited to abs(sub(zext(),zext())) patterns, but this adds handling for a number of other abdu patterns until a topologically sorted DAG allows us to rely on an ABDU node having already been created. Now that we don't just match zext() sources, I've generalised the createPSADBW helper to explicitly zext/truncate to the expected vXi8 source type - it still assumes the sources are correct for a PSADBW node. Fixes #143456
1 parent 0b49f2f commit f0bc411

File tree

2 files changed

+31
-35
lines changed

2 files changed

+31
-35
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 31 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -46047,23 +46047,22 @@ static SDValue createVPDPBUSD(SelectionDAG &DAG, SDValue LHS, SDValue RHS,
4604746047
DpBuilder, false);
4604846048
}
4604946049

46050-
// Given two zexts of <k x i8> to <k x i32>, create a PSADBW of the inputs
46051-
// to these zexts.
46052-
static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
46053-
const SDValue &Zext1, const SDLoc &DL,
46054-
const X86Subtarget &Subtarget) {
46050+
// Create a PSADBW given two sources representable as zexts of vXi8.
46051+
static SDValue createPSADBW(SelectionDAG &DAG, SDValue N0, SDValue N1,
46052+
const SDLoc &DL, const X86Subtarget &Subtarget) {
4605546053
// Find the appropriate width for the PSADBW.
46056-
EVT InVT = Zext0.getOperand(0).getValueType();
46057-
unsigned RegSize = std::max(128u, (unsigned)InVT.getSizeInBits());
46058-
46059-
// "Zero-extend" the i8 vectors. This is not a per-element zext, rather we
46060-
// fill in the missing vector elements with 0.
46061-
unsigned NumConcat = RegSize / InVT.getSizeInBits();
46062-
SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, InVT));
46063-
Ops[0] = Zext0.getOperand(0);
46054+
EVT DstVT = N0.getValueType();
46055+
EVT SrcVT = EVT::getVectorVT(*DAG.getContext(), MVT::i8,
46056+
DstVT.getVectorElementCount());
46057+
unsigned RegSize = std::max(128u, (unsigned)SrcVT.getSizeInBits());
46058+
46059+
// Widen the vXi8 vectors, padding with zero vector elements.
46060+
unsigned NumConcat = RegSize / SrcVT.getSizeInBits();
46061+
SmallVector<SDValue, 16> Ops(NumConcat, DAG.getConstant(0, DL, SrcVT));
46062+
Ops[0] = DAG.getZExtOrTrunc(N0, DL, SrcVT);
4606446063
MVT ExtendedVT = MVT::getVectorVT(MVT::i8, RegSize / 8);
4606546064
SDValue SadOp0 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
46066-
Ops[0] = Zext1.getOperand(0);
46065+
Ops[0] = DAG.getZExtOrTrunc(N1, DL, SrcVT);
4606746066
SDValue SadOp1 = DAG.getNode(ISD::CONCAT_VECTORS, DL, ExtendedVT, Ops);
4606846067

4606946068
// Actually build the SAD, split as 128/256/512 bits for SSE/AVX2/AVX512BW.
@@ -46073,7 +46072,7 @@ static SDValue createPSADBW(SelectionDAG &DAG, const SDValue &Zext0,
4607346072
return DAG.getNode(X86ISD::PSADBW, DL, VT, Ops);
4607446073
};
4607546074
MVT SadVT = MVT::getVectorVT(MVT::i64, RegSize / 64);
46076-
return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, { SadOp0, SadOp1 },
46075+
return SplitOpsAndApply(DAG, Subtarget, DL, SadVT, {SadOp0, SadOp1},
4607746076
PSADBWBuilder);
4607846077
}
4607946078

@@ -46372,9 +46371,8 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
4637246371
return SDValue();
4637346372

4637446373
EVT ExtractVT = Extract->getValueType(0);
46375-
// Verify the type we're extracting is either i32 or i64.
46376-
// FIXME: Could support other types, but this is what we have coverage for.
46377-
if (ExtractVT != MVT::i32 && ExtractVT != MVT::i64)
46374+
if (ExtractVT != MVT::i8 && ExtractVT != MVT::i16 && ExtractVT != MVT::i32 &&
46375+
ExtractVT != MVT::i64)
4637846376
return SDValue();
4637946377

4638046378
EVT VT = Extract->getOperand(0).getValueType();
@@ -46399,20 +46397,27 @@ static SDValue combineBasicSADPattern(SDNode *Extract, SelectionDAG &DAG,
4639946397
Root.getOpcode() == ISD::ANY_EXTEND)
4640046398
Root = Root.getOperand(0);
4640146399

46402-
// Check whether we have an abdu pattern.
46403-
// TODO: Add handling for ISD::ABDU.
46404-
SDValue Zext0, Zext1;
46400+
// Check whether we have a vXi8 abdu pattern.
46401+
// TODO: Just match ISD::ABDU once the DAG is topologically sorted.
46402+
SDValue Src0, Src1;
4640546403
if (!sd_match(
4640646404
Root,
46407-
m_Abs(m_Sub(m_AllOf(m_Value(Zext0),
46408-
m_ZExt(m_SpecificVectorElementVT(MVT::i8))),
46409-
m_AllOf(m_Value(Zext1),
46410-
m_ZExt(m_SpecificVectorElementVT(MVT::i8)))))))
46405+
m_AnyOf(
46406+
m_SpecificVectorElementVT(
46407+
MVT::i8, m_c_BinOp(ISD::ABDU, m_Value(Src0), m_Value(Src1))),
46408+
m_SpecificVectorElementVT(
46409+
MVT::i8, m_Sub(m_UMax(m_Value(Src0), m_Value(Src1)),
46410+
m_UMin(m_Deferred(Src0), m_Deferred(Src1)))),
46411+
m_Abs(
46412+
m_Sub(m_AllOf(m_Value(Src0),
46413+
m_ZExt(m_SpecificVectorElementVT(MVT::i8))),
46414+
m_AllOf(m_Value(Src1),
46415+
m_ZExt(m_SpecificVectorElementVT(MVT::i8))))))))
4641146416
return SDValue();
4641246417

4641346418
// Create the SAD instruction.
4641446419
SDLoc DL(Extract);
46415-
SDValue SAD = createPSADBW(DAG, Zext0, Zext1, DL, Subtarget);
46420+
SDValue SAD = createPSADBW(DAG, Src0, Src1, DL, Subtarget);
4641646421

4641746422
// If the original vector was wider than 8 elements, sum over the results
4641846423
// in the SAD vector.

llvm/test/CodeGen/X86/sad.ll

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,11 +1184,6 @@ define i32 @PR143456(ptr %p0, ptr %p1) {
11841184
; SSE2: # %bb.0:
11851185
; SSE2-NEXT: movq {{.*#+}} xmm0 = mem[0],zero
11861186
; SSE2-NEXT: movq {{.*#+}} xmm1 = mem[0],zero
1187-
; SSE2-NEXT: movdqa %xmm0, %xmm2
1188-
; SSE2-NEXT: pminub %xmm1, %xmm2
1189-
; SSE2-NEXT: pmaxub %xmm1, %xmm0
1190-
; SSE2-NEXT: psubb %xmm2, %xmm0
1191-
; SSE2-NEXT: pxor %xmm1, %xmm1
11921187
; SSE2-NEXT: psadbw %xmm0, %xmm1
11931188
; SSE2-NEXT: movd %xmm1, %eax
11941189
; SSE2-NEXT: movzbl %al, %eax
@@ -1198,10 +1193,6 @@ define i32 @PR143456(ptr %p0, ptr %p1) {
11981193
; AVX: # %bb.0:
11991194
; AVX-NEXT: vmovq {{.*#+}} xmm0 = mem[0],zero
12001195
; AVX-NEXT: vmovq {{.*#+}} xmm1 = mem[0],zero
1201-
; AVX-NEXT: vpminub %xmm1, %xmm0, %xmm2
1202-
; AVX-NEXT: vpmaxub %xmm1, %xmm0, %xmm0
1203-
; AVX-NEXT: vpsubb %xmm2, %xmm0, %xmm0
1204-
; AVX-NEXT: vpxor %xmm1, %xmm1, %xmm1
12051196
; AVX-NEXT: vpsadbw %xmm1, %xmm0, %xmm0
12061197
; AVX-NEXT: vpextrb $0, %xmm0, %eax
12071198
; AVX-NEXT: retq

0 commit comments

Comments
 (0)