Skip to content

Commit 7e64ade

Browse files
preames and lukel97 authored
[RISCV] Extend zvqdot matching to handle reduction trees (#138965)
Now that we have matching for vqdot in its basic variants, we can extend the matcher to handle reduction trees instead of individual reductions. This is important as we canonicalize reductions by performing a tree of adds in the vector domain before the root reduction instruction. The particular approach taken here has the unfortunate implication that non-matches visit the entire reduction tree once for each time the reduction root is visited in the DAG. While conceptually problematic for compile time, this is probably fine in practice as we should only visit the root once per pass of DAGCombine. I don't really see a better solution - suggestions welcome. --------- Co-authored-by: Luke Lau <luke_lau@icloud.com>
1 parent 20d6def commit 7e64ade

File tree

2 files changed

+143
-47
lines changed

2 files changed

+143
-47
lines changed

llvm/lib/Target/RISCV/RISCVISelLowering.cpp

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18054,6 +18054,27 @@ static MVT getQDOTXResultType(MVT OpVT) {
1805418054
return MVT::getVectorVT(MVT::i32, OpEC.divideCoefficientBy(4));
1805518055
}
1805618056

18057+
/// Given fixed length vectors A and B with equal element types, but possibly
18058+
/// different number of elements, return A + B where either A or B is zero
18059+
/// padded to the larger number of elements. The result has the type of the operand with more elements.
18060+
static SDValue getZeroPaddedAdd(const SDLoc &DL, SDValue A, SDValue B,
18061+
SelectionDAG &DAG) {
18062+
// NOTE: Manually doing the extract/add/insert scheme produces
18063+
// significantly better codegen than the naive pad with zeros
18064+
// and add scheme.
18065+
EVT AVT = A.getValueType();
18066+
EVT BVT = B.getValueType();
18067+
assert(AVT.getVectorElementType() == BVT.getVectorElementType());
18068+
if (AVT.getVectorNumElements() > BVT.getVectorNumElements()) { // Canonicalize: A must not have more elements than B.
18069+
std::swap(A, B);
18070+
std::swap(AVT, BVT);
18071+
}
18072+
18073+
SDValue BPart = DAG.getExtractSubvector(DL, AVT, B, 0); // Leading part of B (starting at index 0), with A's type.
18074+
SDValue Res = DAG.getNode(ISD::ADD, DL, AVT, A, BPart); // The add is performed only in the narrower type.
18075+
return DAG.getInsertSubvector(DL, B, Res, 0); // Write the sum back at index 0; the rest of B is kept as-is (i.e. B + 0).
18076+
}
18077+
1805718078
static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
1805818079
SelectionDAG &DAG,
1805918080
const RISCVSubtarget &Subtarget,
@@ -18065,6 +18086,26 @@ static SDValue foldReduceOperandViaVQDOT(SDValue InVec, const SDLoc &DL,
1806518086
!InVec.getValueType().getVectorElementCount().isKnownMultipleOf(4))
1806618087
return SDValue();
1806718088

18089+
// Recurse through adds (since generic dag canonicalizes to that
18090+
// form). TODO: Handle disjoint or here.
18091+
if (InVec->getOpcode() == ISD::ADD) {
18092+
SDValue A = InVec.getOperand(0);
18093+
SDValue B = InVec.getOperand(1);
18094+
SDValue AOpt = foldReduceOperandViaVQDOT(A, DL, DAG, Subtarget, TLI);
18095+
SDValue BOpt = foldReduceOperandViaVQDOT(B, DL, DAG, Subtarget, TLI);
18096+
if (AOpt || BOpt) {
18097+
if (AOpt)
18098+
A = AOpt;
18099+
if (BOpt)
18100+
B = BOpt;
18101+
// From here, we're doing A + B with mixed types, implicitly zero
18102+
// padded to the wider type. Note that we *don't* need the result
18103+
// type to be the original VT, and in fact prefer narrower ones
18104+
// if possible.
18105+
return getZeroPaddedAdd(DL, A, B, DAG);
18106+
}
18107+
}
18108+
1806818109
// reduce (zext a) <--> reduce (mul zext a. zext 1)
1806918110
// reduce (sext a) <--> reduce (mul sext a. sext 1)
1807018111
if (InVec.getOpcode() == ISD::ZERO_EXTEND ||

llvm/test/CodeGen/RISCV/rvv/fixed-vectors-zvqdotq.ll

Lines changed: 102 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -299,17 +299,31 @@ entry:
299299
}
300300

301301
define i32 @vqdot_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
302-
; CHECK-LABEL: vqdot_vv_accum:
303-
; CHECK: # %bb.0: # %entry
304-
; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
305-
; CHECK-NEXT: vsext.vf2 v10, v8
306-
; CHECK-NEXT: vsext.vf2 v16, v9
307-
; CHECK-NEXT: vwmacc.vv v12, v10, v16
308-
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
309-
; CHECK-NEXT: vmv.s.x v8, zero
310-
; CHECK-NEXT: vredsum.vs v8, v12, v8
311-
; CHECK-NEXT: vmv.x.s a0, v8
312-
; CHECK-NEXT: ret
302+
; NODOT-LABEL: vqdot_vv_accum:
303+
; NODOT: # %bb.0: # %entry
304+
; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
305+
; NODOT-NEXT: vsext.vf2 v10, v8
306+
; NODOT-NEXT: vsext.vf2 v16, v9
307+
; NODOT-NEXT: vwmacc.vv v12, v10, v16
308+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
309+
; NODOT-NEXT: vmv.s.x v8, zero
310+
; NODOT-NEXT: vredsum.vs v8, v12, v8
311+
; NODOT-NEXT: vmv.x.s a0, v8
312+
; NODOT-NEXT: ret
313+
;
314+
; DOT-LABEL: vqdot_vv_accum:
315+
; DOT: # %bb.0: # %entry
316+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
317+
; DOT-NEXT: vmv.v.i v10, 0
318+
; DOT-NEXT: vqdot.vv v10, v8, v9
319+
; DOT-NEXT: vadd.vv v8, v10, v12
320+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
321+
; DOT-NEXT: vmv.v.v v12, v8
322+
; DOT-NEXT: vmv.s.x v8, zero
323+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
324+
; DOT-NEXT: vredsum.vs v8, v12, v8
325+
; DOT-NEXT: vmv.x.s a0, v8
326+
; DOT-NEXT: ret
313327
entry:
314328
%a.sext = sext <16 x i8> %a to <16 x i32>
315329
%b.sext = sext <16 x i8> %b to <16 x i32>
@@ -320,17 +334,31 @@ entry:
320334
}
321335

322336
define i32 @vqdotu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
323-
; CHECK-LABEL: vqdotu_vv_accum:
324-
; CHECK: # %bb.0: # %entry
325-
; CHECK-NEXT: vsetivli zero, 16, e8, m1, ta, ma
326-
; CHECK-NEXT: vwmulu.vv v10, v8, v9
327-
; CHECK-NEXT: vsetvli zero, zero, e16, m2, ta, ma
328-
; CHECK-NEXT: vwaddu.wv v12, v12, v10
329-
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
330-
; CHECK-NEXT: vmv.s.x v8, zero
331-
; CHECK-NEXT: vredsum.vs v8, v12, v8
332-
; CHECK-NEXT: vmv.x.s a0, v8
333-
; CHECK-NEXT: ret
337+
; NODOT-LABEL: vqdotu_vv_accum:
338+
; NODOT: # %bb.0: # %entry
339+
; NODOT-NEXT: vsetivli zero, 16, e8, m1, ta, ma
340+
; NODOT-NEXT: vwmulu.vv v10, v8, v9
341+
; NODOT-NEXT: vsetvli zero, zero, e16, m2, ta, ma
342+
; NODOT-NEXT: vwaddu.wv v12, v12, v10
343+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
344+
; NODOT-NEXT: vmv.s.x v8, zero
345+
; NODOT-NEXT: vredsum.vs v8, v12, v8
346+
; NODOT-NEXT: vmv.x.s a0, v8
347+
; NODOT-NEXT: ret
348+
;
349+
; DOT-LABEL: vqdotu_vv_accum:
350+
; DOT: # %bb.0: # %entry
351+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
352+
; DOT-NEXT: vmv.v.i v10, 0
353+
; DOT-NEXT: vqdotu.vv v10, v8, v9
354+
; DOT-NEXT: vadd.vv v8, v10, v12
355+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
356+
; DOT-NEXT: vmv.v.v v12, v8
357+
; DOT-NEXT: vmv.s.x v8, zero
358+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
359+
; DOT-NEXT: vredsum.vs v8, v12, v8
360+
; DOT-NEXT: vmv.x.s a0, v8
361+
; DOT-NEXT: ret
334362
entry:
335363
%a.zext = zext <16 x i8> %a to <16 x i32>
336364
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -341,17 +369,31 @@ entry:
341369
}
342370

343371
define i32 @vqdotsu_vv_accum(<16 x i8> %a, <16 x i8> %b, <16 x i32> %x) {
344-
; CHECK-LABEL: vqdotsu_vv_accum:
345-
; CHECK: # %bb.0: # %entry
346-
; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
347-
; CHECK-NEXT: vsext.vf2 v10, v8
348-
; CHECK-NEXT: vzext.vf2 v16, v9
349-
; CHECK-NEXT: vwmaccsu.vv v12, v10, v16
350-
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
351-
; CHECK-NEXT: vmv.s.x v8, zero
352-
; CHECK-NEXT: vredsum.vs v8, v12, v8
353-
; CHECK-NEXT: vmv.x.s a0, v8
354-
; CHECK-NEXT: ret
372+
; NODOT-LABEL: vqdotsu_vv_accum:
373+
; NODOT: # %bb.0: # %entry
374+
; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
375+
; NODOT-NEXT: vsext.vf2 v10, v8
376+
; NODOT-NEXT: vzext.vf2 v16, v9
377+
; NODOT-NEXT: vwmaccsu.vv v12, v10, v16
378+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
379+
; NODOT-NEXT: vmv.s.x v8, zero
380+
; NODOT-NEXT: vredsum.vs v8, v12, v8
381+
; NODOT-NEXT: vmv.x.s a0, v8
382+
; NODOT-NEXT: ret
383+
;
384+
; DOT-LABEL: vqdotsu_vv_accum:
385+
; DOT: # %bb.0: # %entry
386+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
387+
; DOT-NEXT: vmv.v.i v10, 0
388+
; DOT-NEXT: vqdotsu.vv v10, v8, v9
389+
; DOT-NEXT: vadd.vv v8, v10, v12
390+
; DOT-NEXT: vsetivli zero, 4, e32, m4, tu, ma
391+
; DOT-NEXT: vmv.v.v v12, v8
392+
; DOT-NEXT: vmv.s.x v8, zero
393+
; DOT-NEXT: vsetivli zero, 16, e32, m4, ta, ma
394+
; DOT-NEXT: vredsum.vs v8, v12, v8
395+
; DOT-NEXT: vmv.x.s a0, v8
396+
; DOT-NEXT: ret
355397
entry:
356398
%a.sext = sext <16 x i8> %a to <16 x i32>
357399
%b.zext = zext <16 x i8> %b to <16 x i32>
@@ -455,20 +497,33 @@ entry:
455497
}
456498

457499
define i32 @vqdot_vv_split(<16 x i8> %a, <16 x i8> %b, <16 x i8> %c, <16 x i8> %d) {
458-
; CHECK-LABEL: vqdot_vv_split:
459-
; CHECK: # %bb.0: # %entry
460-
; CHECK-NEXT: vsetivli zero, 16, e16, m2, ta, ma
461-
; CHECK-NEXT: vsext.vf2 v12, v8
462-
; CHECK-NEXT: vsext.vf2 v14, v9
463-
; CHECK-NEXT: vsext.vf2 v16, v10
464-
; CHECK-NEXT: vsext.vf2 v18, v11
465-
; CHECK-NEXT: vwmul.vv v8, v12, v14
466-
; CHECK-NEXT: vwmacc.vv v8, v16, v18
467-
; CHECK-NEXT: vsetvli zero, zero, e32, m4, ta, ma
468-
; CHECK-NEXT: vmv.s.x v12, zero
469-
; CHECK-NEXT: vredsum.vs v8, v8, v12
470-
; CHECK-NEXT: vmv.x.s a0, v8
471-
; CHECK-NEXT: ret
500+
; NODOT-LABEL: vqdot_vv_split:
501+
; NODOT: # %bb.0: # %entry
502+
; NODOT-NEXT: vsetivli zero, 16, e16, m2, ta, ma
503+
; NODOT-NEXT: vsext.vf2 v12, v8
504+
; NODOT-NEXT: vsext.vf2 v14, v9
505+
; NODOT-NEXT: vsext.vf2 v16, v10
506+
; NODOT-NEXT: vsext.vf2 v18, v11
507+
; NODOT-NEXT: vwmul.vv v8, v12, v14
508+
; NODOT-NEXT: vwmacc.vv v8, v16, v18
509+
; NODOT-NEXT: vsetvli zero, zero, e32, m4, ta, ma
510+
; NODOT-NEXT: vmv.s.x v12, zero
511+
; NODOT-NEXT: vredsum.vs v8, v8, v12
512+
; NODOT-NEXT: vmv.x.s a0, v8
513+
; NODOT-NEXT: ret
514+
;
515+
; DOT-LABEL: vqdot_vv_split:
516+
; DOT: # %bb.0: # %entry
517+
; DOT-NEXT: vsetivli zero, 4, e32, m1, ta, ma
518+
; DOT-NEXT: vmv.v.i v12, 0
519+
; DOT-NEXT: vmv.v.i v13, 0
520+
; DOT-NEXT: vqdot.vv v12, v8, v9
521+
; DOT-NEXT: vqdot.vv v13, v10, v11
522+
; DOT-NEXT: vadd.vv v8, v12, v13
523+
; DOT-NEXT: vmv.s.x v9, zero
524+
; DOT-NEXT: vredsum.vs v8, v8, v9
525+
; DOT-NEXT: vmv.x.s a0, v8
526+
; DOT-NEXT: ret
472527
entry:
473528
%a.sext = sext <16 x i8> %a to <16 x i32>
474529
%b.sext = sext <16 x i8> %b to <16 x i32>

0 commit comments

Comments
 (0)