Skip to content

Commit e86a92f

Browse files
[AArch64][SelectionDAG] Add support for 8to64 partial reduction cases (#138269)
--------- Co-authored-by: James Chesterman <james.chesterman@arm.com>
1 parent a8ed244 commit e86a92f

File tree

3 files changed

+93
-81
lines changed

3 files changed

+93
-81
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1868,6 +1868,8 @@ AArch64TargetLowering::AArch64TargetLowering(const TargetMachine &TM,
18681868
// Other pairs will default to 'Expand'.
18691869
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv8i16, Legal);
18701870
setPartialReduceMLAAction(MVT::nxv4i32, MVT::nxv16i8, Legal);
1871+
1872+
setPartialReduceMLAAction(MVT::nxv2i64, MVT::nxv16i8, Custom);
18711873
}
18721874

18731875
// Handle operations that are only available in non-streaming SVE mode.
@@ -7740,6 +7742,9 @@ SDValue AArch64TargetLowering::LowerOperation(SDValue Op,
77407742
return LowerFLDEXP(Op, DAG);
77417743
case ISD::EXPERIMENTAL_VECTOR_HISTOGRAM:
77427744
return LowerVECTOR_HISTOGRAM(Op, DAG);
7745+
case ISD::PARTIAL_REDUCE_SMLA:
7746+
case ISD::PARTIAL_REDUCE_UMLA:
7747+
return LowerPARTIAL_REDUCE_MLA(Op, DAG);
77437748
}
77447749
}
77457750

@@ -29476,6 +29481,40 @@ SDValue AArch64TargetLowering::LowerVECTOR_HISTOGRAM(SDValue Op,
2947629481
return Scatter;
2947729482
}
2947829483

29484+
/// If a PARTIAL_REDUCE_MLA node comes in with an accumulator-input type pairing
29485+
/// of nxv2i64/nxv16i8, we cannot directly lower it to a (u|s)dot. We can
29486+
/// however still make use of the dot product instruction by instead
29487+
/// accumulating over two steps: nxv16i8 -> nxv4i32 -> nxv2i64.
29488+
SDValue
29489+
AArch64TargetLowering::LowerPARTIAL_REDUCE_MLA(SDValue Op,
29490+
SelectionDAG &DAG) const {
29491+
SDLoc DL(Op);
29492+
29493+
SDValue Acc = Op.getOperand(0);
29494+
SDValue LHS = Op.getOperand(1);
29495+
SDValue RHS = Op.getOperand(2);
29496+
EVT ResultVT = Op.getValueType();
29497+
assert(ResultVT == MVT::nxv2i64 && LHS.getValueType() == MVT::nxv16i8);
29498+
29499+
SDValue DotNode = DAG.getNode(Op.getOpcode(), DL, MVT::nxv4i32,
29500+
DAG.getConstant(0, DL, MVT::nxv4i32), LHS, RHS);
29501+
29502+
bool IsUnsigned = Op.getOpcode() == ISD::PARTIAL_REDUCE_UMLA;
29503+
if (Subtarget->hasSVE2() || Subtarget->isStreamingSVEAvailable()) {
29504+
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UADDWB : AArch64ISD::SADDWB;
29505+
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UADDWT : AArch64ISD::SADDWT;
29506+
SDValue Lo = DAG.getNode(LoOpcode, DL, ResultVT, Acc, DotNode);
29507+
return DAG.getNode(HiOpcode, DL, ResultVT, Lo, DotNode);
29508+
}
29509+
29510+
unsigned LoOpcode = IsUnsigned ? AArch64ISD::UUNPKLO : AArch64ISD::SUNPKLO;
29511+
unsigned HiOpcode = IsUnsigned ? AArch64ISD::UUNPKHI : AArch64ISD::SUNPKHI;
29512+
auto Lo = DAG.getNode(LoOpcode, DL, ResultVT, DotNode);
29513+
auto Hi = DAG.getNode(HiOpcode, DL, ResultVT, DotNode);
29514+
auto Extended = DAG.getNode(ISD::ADD, DL, ResultVT, Lo, Hi);
29515+
return DAG.getNode(ISD::ADD, DL, ResultVT, Acc, Extended);
29516+
}
29517+
2947929518
SDValue
2948029519
AArch64TargetLowering::LowerFixedLengthFPToIntToSVE(SDValue Op,
2948129520
SelectionDAG &DAG) const {

llvm/lib/Target/AArch64/AArch64ISelLowering.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1181,6 +1181,7 @@ class AArch64TargetLowering : public TargetLowering {
11811181
SDValue LowerVECTOR_DEINTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
11821182
SDValue LowerVECTOR_INTERLEAVE(SDValue Op, SelectionDAG &DAG) const;
11831183
SDValue LowerVECTOR_HISTOGRAM(SDValue Op, SelectionDAG &DAG) const;
1184+
SDValue LowerPARTIAL_REDUCE_MLA(SDValue Op, SelectionDAG &DAG) const;
11841185
SDValue LowerDIV(SDValue Op, SelectionDAG &DAG) const;
11851186
SDValue LowerMUL(SDValue Op, SelectionDAG &DAG) const;
11861187
SDValue LowerVectorSRA_SRL_SHL(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/AArch64/sve-partial-reduce-dot-product.ll

Lines changed: 53 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
22
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-I8MM
33
; RUN: llc -mtriple=aarch64 -mattr=+sve2 %s -o - | FileCheck %s --check-prefixes=CHECK,CHECK-NOI8MM
4-
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING
4+
; RUN: llc -mtriple=aarch64 -mattr=+sve,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE
5+
; RUN: llc -mtriple=aarch64 -mattr=+sve2,+i8mm -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SVE2
6+
; RUN: llc -mtriple=aarch64 -mattr=+sme -force-streaming -aarch64-enable-partial-reduce-nodes %s -o - | FileCheck %s --check-prefixes=CHECK-NEWLOWERING,CHECK-NEWLOWERING-SME
57

68
define <vscale x 4 x i32> @udot(<vscale x 4 x i32> %acc, <vscale x 16 x i8> %a, <vscale x 16 x i8> %b) {
79
; CHECK-LABEL: udot:
@@ -196,46 +198,31 @@ define <vscale x 4 x i64> @udot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
196198
; CHECK-NEXT: add z1.d, z1.d, z3.d
197199
; CHECK-NEXT: ret
198200
;
199-
; CHECK-NEWLOWERING-LABEL: udot_8to64:
200-
; CHECK-NEWLOWERING: // %bb.0: // %entry
201-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.h, z2.b
202-
; CHECK-NEWLOWERING-NEXT: uunpklo z2.h, z2.b
203-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.h, z3.b
204-
; CHECK-NEWLOWERING-NEXT: uunpklo z3.h, z3.b
205-
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
206-
; CHECK-NEWLOWERING-NEXT: uunpklo z6.s, z4.h
207-
; CHECK-NEWLOWERING-NEXT: uunpklo z7.s, z2.h
208-
; CHECK-NEWLOWERING-NEXT: uunpklo z24.s, z5.h
209-
; CHECK-NEWLOWERING-NEXT: uunpklo z25.s, z3.h
210-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.s, z4.h
211-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.s, z2.h
212-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.s, z5.h
213-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.s, z3.h
214-
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z6.s
215-
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z7.s
216-
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z24.s
217-
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z25.s
218-
; CHECK-NEWLOWERING-NEXT: uunpkhi z6.d, z6.s
219-
; CHECK-NEWLOWERING-NEXT: uunpkhi z7.d, z7.s
220-
; CHECK-NEWLOWERING-NEXT: uunpkhi z24.d, z24.s
221-
; CHECK-NEWLOWERING-NEXT: uunpkhi z25.d, z25.s
222-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
223-
; CHECK-NEWLOWERING-NEXT: uunpklo z26.d, z4.s
224-
; CHECK-NEWLOWERING-NEXT: uunpklo z28.d, z5.s
225-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
226-
; CHECK-NEWLOWERING-NEXT: uunpklo z27.d, z2.s
227-
; CHECK-NEWLOWERING-NEXT: uunpklo z29.d, z3.s
228-
; CHECK-NEWLOWERING-NEXT: uunpkhi z4.d, z4.s
229-
; CHECK-NEWLOWERING-NEXT: uunpkhi z2.d, z2.s
230-
; CHECK-NEWLOWERING-NEXT: uunpkhi z5.d, z5.s
231-
; CHECK-NEWLOWERING-NEXT: uunpkhi z3.d, z3.s
232-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
233-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
234-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
235-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
236-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
237-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
238-
; CHECK-NEWLOWERING-NEXT: ret
201+
; CHECK-NEWLOWERING-SVE-LABEL: udot_8to64:
202+
; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
203+
; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
204+
; CHECK-NEWLOWERING-SVE-NEXT: udot z4.s, z2.b, z3.b
205+
; CHECK-NEWLOWERING-SVE-NEXT: uunpkhi z2.d, z4.s
206+
; CHECK-NEWLOWERING-SVE-NEXT: uunpklo z3.d, z4.s
207+
; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
208+
; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
209+
; CHECK-NEWLOWERING-SVE-NEXT: ret
210+
;
211+
; CHECK-NEWLOWERING-SVE2-LABEL: udot_8to64:
212+
; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
213+
; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
214+
; CHECK-NEWLOWERING-SVE2-NEXT: udot z4.s, z2.b, z3.b
215+
; CHECK-NEWLOWERING-SVE2-NEXT: uaddwb z0.d, z0.d, z4.s
216+
; CHECK-NEWLOWERING-SVE2-NEXT: uaddwt z0.d, z0.d, z4.s
217+
; CHECK-NEWLOWERING-SVE2-NEXT: ret
218+
;
219+
; CHECK-NEWLOWERING-SME-LABEL: udot_8to64:
220+
; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
221+
; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
222+
; CHECK-NEWLOWERING-SME-NEXT: udot z4.s, z2.b, z3.b
223+
; CHECK-NEWLOWERING-SME-NEXT: uaddwb z0.d, z0.d, z4.s
224+
; CHECK-NEWLOWERING-SME-NEXT: uaddwt z0.d, z0.d, z4.s
225+
; CHECK-NEWLOWERING-SME-NEXT: ret
239226
entry:
240227
%a.wide = zext <vscale x 16 x i8> %a to <vscale x 16 x i64>
241228
%b.wide = zext <vscale x 16 x i8> %b to <vscale x 16 x i64>
@@ -256,46 +243,31 @@ define <vscale x 4 x i64> @sdot_8to64(<vscale x 4 x i64> %acc, <vscale x 16 x i8
256243
; CHECK-NEXT: add z1.d, z1.d, z3.d
257244
; CHECK-NEXT: ret
258245
;
259-
; CHECK-NEWLOWERING-LABEL: sdot_8to64:
260-
; CHECK-NEWLOWERING: // %bb.0: // %entry
261-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.h, z2.b
262-
; CHECK-NEWLOWERING-NEXT: sunpklo z2.h, z2.b
263-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.h, z3.b
264-
; CHECK-NEWLOWERING-NEXT: sunpklo z3.h, z3.b
265-
; CHECK-NEWLOWERING-NEXT: ptrue p0.d
266-
; CHECK-NEWLOWERING-NEXT: sunpklo z6.s, z4.h
267-
; CHECK-NEWLOWERING-NEXT: sunpklo z7.s, z2.h
268-
; CHECK-NEWLOWERING-NEXT: sunpklo z24.s, z5.h
269-
; CHECK-NEWLOWERING-NEXT: sunpklo z25.s, z3.h
270-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.s, z4.h
271-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.s, z2.h
272-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.s, z5.h
273-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.s, z3.h
274-
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z6.s
275-
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z7.s
276-
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z24.s
277-
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z25.s
278-
; CHECK-NEWLOWERING-NEXT: sunpkhi z6.d, z6.s
279-
; CHECK-NEWLOWERING-NEXT: sunpkhi z7.d, z7.s
280-
; CHECK-NEWLOWERING-NEXT: sunpkhi z24.d, z24.s
281-
; CHECK-NEWLOWERING-NEXT: sunpkhi z25.d, z25.s
282-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
283-
; CHECK-NEWLOWERING-NEXT: sunpklo z26.d, z4.s
284-
; CHECK-NEWLOWERING-NEXT: sunpklo z28.d, z5.s
285-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
286-
; CHECK-NEWLOWERING-NEXT: sunpklo z27.d, z2.s
287-
; CHECK-NEWLOWERING-NEXT: sunpklo z29.d, z3.s
288-
; CHECK-NEWLOWERING-NEXT: sunpkhi z4.d, z4.s
289-
; CHECK-NEWLOWERING-NEXT: sunpkhi z2.d, z2.s
290-
; CHECK-NEWLOWERING-NEXT: sunpkhi z5.d, z5.s
291-
; CHECK-NEWLOWERING-NEXT: sunpkhi z3.d, z3.s
292-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z6.d, z24.d
293-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z7.d, z25.d
294-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z26.d, z28.d
295-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z27.d, z29.d
296-
; CHECK-NEWLOWERING-NEXT: mla z1.d, p0/m, z4.d, z5.d
297-
; CHECK-NEWLOWERING-NEXT: mla z0.d, p0/m, z2.d, z3.d
298-
; CHECK-NEWLOWERING-NEXT: ret
246+
; CHECK-NEWLOWERING-SVE-LABEL: sdot_8to64:
247+
; CHECK-NEWLOWERING-SVE: // %bb.0: // %entry
248+
; CHECK-NEWLOWERING-SVE-NEXT: movi v4.2d, #0000000000000000
249+
; CHECK-NEWLOWERING-SVE-NEXT: sdot z4.s, z2.b, z3.b
250+
; CHECK-NEWLOWERING-SVE-NEXT: sunpkhi z2.d, z4.s
251+
; CHECK-NEWLOWERING-SVE-NEXT: sunpklo z3.d, z4.s
252+
; CHECK-NEWLOWERING-SVE-NEXT: add z2.d, z3.d, z2.d
253+
; CHECK-NEWLOWERING-SVE-NEXT: add z0.d, z0.d, z2.d
254+
; CHECK-NEWLOWERING-SVE-NEXT: ret
255+
;
256+
; CHECK-NEWLOWERING-SVE2-LABEL: sdot_8to64:
257+
; CHECK-NEWLOWERING-SVE2: // %bb.0: // %entry
258+
; CHECK-NEWLOWERING-SVE2-NEXT: movi v4.2d, #0000000000000000
259+
; CHECK-NEWLOWERING-SVE2-NEXT: sdot z4.s, z2.b, z3.b
260+
; CHECK-NEWLOWERING-SVE2-NEXT: saddwb z0.d, z0.d, z4.s
261+
; CHECK-NEWLOWERING-SVE2-NEXT: saddwt z0.d, z0.d, z4.s
262+
; CHECK-NEWLOWERING-SVE2-NEXT: ret
263+
;
264+
; CHECK-NEWLOWERING-SME-LABEL: sdot_8to64:
265+
; CHECK-NEWLOWERING-SME: // %bb.0: // %entry
266+
; CHECK-NEWLOWERING-SME-NEXT: mov z4.s, #0 // =0x0
267+
; CHECK-NEWLOWERING-SME-NEXT: sdot z4.s, z2.b, z3.b
268+
; CHECK-NEWLOWERING-SME-NEXT: saddwb z0.d, z0.d, z4.s
269+
; CHECK-NEWLOWERING-SME-NEXT: saddwt z0.d, z0.d, z4.s
270+
; CHECK-NEWLOWERING-SME-NEXT: ret
299271
entry:
300272
%a.wide = sext <vscale x 16 x i8> %a to <vscale x 16 x i64>
301273
%b.wide = sext <vscale x 16 x i8> %b to <vscale x 16 x i64>

0 commit comments

Comments
 (0)