Skip to content

Commit 1ace9fa

Browse files
[LLVM][CodeGen][SVE] Enable Bfloat fma contraction. (#147941)
1 parent 44481f5 commit 1ace9fa

File tree

3 files changed

+25
-33
lines changed

3 files changed

+25
-33
lines changed

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17834,17 +17834,19 @@ bool AArch64TargetLowering::shouldConsiderGEPOffsetSplit() const {
1783417834

1783517835
bool AArch64TargetLowering::isFMAFasterThanFMulAndFAdd(
1783617836
const MachineFunction &MF, EVT VT) const {
17837-
VT = VT.getScalarType();
17837+
EVT ScalarVT = VT.getScalarType();
1783817838

17839-
if (!VT.isSimple())
17839+
if (!ScalarVT.isSimple())
1784017840
return false;
1784117841

17842-
switch (VT.getSimpleVT().SimpleTy) {
17842+
switch (ScalarVT.getSimpleVT().SimpleTy) {
1784317843
case MVT::f16:
1784417844
return Subtarget->hasFullFP16();
1784517845
case MVT::f32:
1784617846
case MVT::f64:
1784717847
return true;
17848+
case MVT::bf16:
17849+
return VT.isScalableVector() && Subtarget->hasSVEB16B16();
1784817850
default:
1784917851
break;
1785017852
}

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2490,6 +2490,8 @@ multiclass sve_fp_3op_p_zds_a_bfloat<bits<2> opc, string asm, string Ps,
24902490
SVEPseudo2Instr<Ps, 1>, SVEInstr2Rev<NAME, "", 0>;
24912491

24922492
def : SVE_4_Op_Pat<nxv8bf16, op, nxv8i1, nxv8bf16, nxv8bf16, nxv8bf16, !cast<Instruction>(NAME)>;
2493+
def : SVE_4_Op_Pat<nxv4bf16, op, nxv4i1, nxv4bf16, nxv4bf16, nxv4bf16, !cast<Instruction>(NAME)>;
2494+
def : SVE_4_Op_Pat<nxv2bf16, op, nxv2i1, nxv2bf16, nxv2bf16, nxv2bf16, !cast<Instruction>(NAME)>;
24932495
}
24942496

24952497
class sve_fp_3op_p_zds_b<bits<2> sz, bits<2> opc, string asm,

llvm/test/CodeGen/AArch64/sve-bf16-combines.ll

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ target triple = "aarch64-unknown-linux-gnu"
66
define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %acc, <vscale x 8 x bfloat> %m1, <vscale x 8 x bfloat> %m2) {
77
; CHECK-LABEL: fmla_nxv8bf16:
88
; CHECK: // %bb.0:
9-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
10-
; CHECK-NEXT: bfadd z0.h, z0.h, z1.h
9+
; CHECK-NEXT: ptrue p0.h
10+
; CHECK-NEXT: bfmla z0.h, p0/m, z1.h, z2.h
1111
; CHECK-NEXT: ret
1212
%mul = fmul contract <vscale x 8 x bfloat> %m1, %m2
1313
%res = fadd contract <vscale x 8 x bfloat> %acc, %mul
@@ -17,8 +17,8 @@ define <vscale x 8 x bfloat> @fmla_nxv8bf16(<vscale x 8 x bfloat> %acc, <vscale
1717
define <vscale x 4 x bfloat> @fmla_nxv4bf16(<vscale x 4 x bfloat> %acc, <vscale x 4 x bfloat> %m1, <vscale x 4 x bfloat> %m2) {
1818
; CHECK-LABEL: fmla_nxv4bf16:
1919
; CHECK: // %bb.0:
20-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
21-
; CHECK-NEXT: bfadd z0.h, z0.h, z1.h
20+
; CHECK-NEXT: ptrue p0.s
21+
; CHECK-NEXT: bfmla z0.h, p0/m, z1.h, z2.h
2222
; CHECK-NEXT: ret
2323
%mul = fmul contract <vscale x 4 x bfloat> %m1, %m2
2424
%res = fadd contract <vscale x 4 x bfloat> %acc, %mul
@@ -28,8 +28,8 @@ define <vscale x 4 x bfloat> @fmla_nxv4bf16(<vscale x 4 x bfloat> %acc, <vscale
2828
define <vscale x 2 x bfloat> @fmla_nxv2bf16(<vscale x 2 x bfloat> %acc, <vscale x 2 x bfloat> %m1, <vscale x 2 x bfloat> %m2) {
2929
; CHECK-LABEL: fmla_nxv2bf16:
3030
; CHECK: // %bb.0:
31-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
32-
; CHECK-NEXT: bfadd z0.h, z0.h, z1.h
31+
; CHECK-NEXT: ptrue p0.d
32+
; CHECK-NEXT: bfmla z0.h, p0/m, z1.h, z2.h
3333
; CHECK-NEXT: ret
3434
%mul = fmul contract <vscale x 2 x bfloat> %m1, %m2
3535
%res = fadd contract <vscale x 2 x bfloat> %acc, %mul
@@ -39,8 +39,8 @@ define <vscale x 2 x bfloat> @fmla_nxv2bf16(<vscale x 2 x bfloat> %acc, <vscale
3939
define <vscale x 8 x bfloat> @fmls_nxv8bf16(<vscale x 8 x bfloat> %acc, <vscale x 8 x bfloat> %m1, <vscale x 8 x bfloat> %m2) {
4040
; CHECK-LABEL: fmls_nxv8bf16:
4141
; CHECK: // %bb.0:
42-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
43-
; CHECK-NEXT: bfsub z0.h, z0.h, z1.h
42+
; CHECK-NEXT: ptrue p0.h
43+
; CHECK-NEXT: bfmls z0.h, p0/m, z1.h, z2.h
4444
; CHECK-NEXT: ret
4545
%mul = fmul contract <vscale x 8 x bfloat> %m1, %m2
4646
%res = fsub contract <vscale x 8 x bfloat> %acc, %mul
@@ -50,8 +50,8 @@ define <vscale x 8 x bfloat> @fmls_nxv8bf16(<vscale x 8 x bfloat> %acc, <vscale
5050
define <vscale x 4 x bfloat> @fmls_nxv4bf16(<vscale x 4 x bfloat> %acc, <vscale x 4 x bfloat> %m1, <vscale x 4 x bfloat> %m2) {
5151
; CHECK-LABEL: fmls_nxv4bf16:
5252
; CHECK: // %bb.0:
53-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
54-
; CHECK-NEXT: bfsub z0.h, z0.h, z1.h
53+
; CHECK-NEXT: ptrue p0.s
54+
; CHECK-NEXT: bfmls z0.h, p0/m, z1.h, z2.h
5555
; CHECK-NEXT: ret
5656
%mul = fmul contract <vscale x 4 x bfloat> %m1, %m2
5757
%res = fsub contract <vscale x 4 x bfloat> %acc, %mul
@@ -61,8 +61,8 @@ define <vscale x 4 x bfloat> @fmls_nxv4bf16(<vscale x 4 x bfloat> %acc, <vscale
6161
define <vscale x 2 x bfloat> @fmls_nxv2bf16(<vscale x 2 x bfloat> %acc, <vscale x 2 x bfloat> %m1, <vscale x 2 x bfloat> %m2) {
6262
; CHECK-LABEL: fmls_nxv2bf16:
6363
; CHECK: // %bb.0:
64-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
65-
; CHECK-NEXT: bfsub z0.h, z0.h, z1.h
64+
; CHECK-NEXT: ptrue p0.d
65+
; CHECK-NEXT: bfmls z0.h, p0/m, z1.h, z2.h
6666
; CHECK-NEXT: ret
6767
%mul = fmul contract <vscale x 2 x bfloat> %m1, %m2
6868
%res = fsub contract <vscale x 2 x bfloat> %acc, %mul
@@ -72,9 +72,7 @@ define <vscale x 2 x bfloat> @fmls_nxv2bf16(<vscale x 2 x bfloat> %acc, <vscale
7272
define <vscale x 8 x bfloat> @fmla_sel_nxv8bf16(<vscale x 8 x i1> %pred, <vscale x 8 x bfloat> %acc, <vscale x 8 x bfloat> %m1, <vscale x 8 x bfloat> %m2) {
7373
; CHECK-LABEL: fmla_sel_nxv8bf16:
7474
; CHECK: // %bb.0:
75-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
76-
; CHECK-NEXT: bfadd z1.h, z0.h, z1.h
77-
; CHECK-NEXT: mov z0.h, p0/m, z1.h
75+
; CHECK-NEXT: bfmla z0.h, p0/m, z1.h, z2.h
7876
; CHECK-NEXT: ret
7977
%mul = fmul contract <vscale x 8 x bfloat> %m1, %m2
8078
%add = fadd contract <vscale x 8 x bfloat> %acc, %mul
@@ -85,9 +83,7 @@ define <vscale x 8 x bfloat> @fmla_sel_nxv8bf16(<vscale x 8 x i1> %pred, <vscale
8583
define <vscale x 4 x bfloat> @fmla_sel_nxv4bf16(<vscale x 4 x i1> %pred, <vscale x 4 x bfloat> %acc, <vscale x 4 x bfloat> %m1, <vscale x 4 x bfloat> %m2) {
8684
; CHECK-LABEL: fmla_sel_nxv4bf16:
8785
; CHECK: // %bb.0:
88-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
89-
; CHECK-NEXT: bfadd z1.h, z0.h, z1.h
90-
; CHECK-NEXT: mov z0.s, p0/m, z1.s
86+
; CHECK-NEXT: bfmla z0.h, p0/m, z1.h, z2.h
9187
; CHECK-NEXT: ret
9288
%mul = fmul contract <vscale x 4 x bfloat> %m1, %m2
9389
%add = fadd contract <vscale x 4 x bfloat> %acc, %mul
@@ -98,9 +94,7 @@ define <vscale x 4 x bfloat> @fmla_sel_nxv4bf16(<vscale x 4 x i1> %pred, <vscale
9894
define <vscale x 2 x bfloat> @fmla_sel_nxv2bf16(<vscale x 2 x i1> %pred, <vscale x 2 x bfloat> %acc, <vscale x 2 x bfloat> %m1, <vscale x 2 x bfloat> %m2) {
9995
; CHECK-LABEL: fmla_sel_nxv2bf16:
10096
; CHECK: // %bb.0:
101-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
102-
; CHECK-NEXT: bfadd z1.h, z0.h, z1.h
103-
; CHECK-NEXT: mov z0.d, p0/m, z1.d
97+
; CHECK-NEXT: bfmla z0.h, p0/m, z1.h, z2.h
10498
; CHECK-NEXT: ret
10599
%mul = fmul contract <vscale x 2 x bfloat> %m1, %m2
106100
%add = fadd contract <vscale x 2 x bfloat> %acc, %mul
@@ -111,9 +105,7 @@ define <vscale x 2 x bfloat> @fmla_sel_nxv2bf16(<vscale x 2 x i1> %pred, <vscale
111105
define <vscale x 8 x bfloat> @fmls_sel_nxv8bf16(<vscale x 8 x i1> %pred, <vscale x 8 x bfloat> %acc, <vscale x 8 x bfloat> %m1, <vscale x 8 x bfloat> %m2) {
112106
; CHECK-LABEL: fmls_sel_nxv8bf16:
113107
; CHECK: // %bb.0:
114-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
115-
; CHECK-NEXT: bfsub z1.h, z0.h, z1.h
116-
; CHECK-NEXT: mov z0.h, p0/m, z1.h
108+
; CHECK-NEXT: bfmls z0.h, p0/m, z1.h, z2.h
117109
; CHECK-NEXT: ret
118110
%mul = fmul contract <vscale x 8 x bfloat> %m1, %m2
119111
%sub = fsub contract <vscale x 8 x bfloat> %acc, %mul
@@ -124,9 +116,7 @@ define <vscale x 8 x bfloat> @fmls_sel_nxv8bf16(<vscale x 8 x i1> %pred, <vscale
124116
define <vscale x 4 x bfloat> @fmls_sel_nxv4bf16(<vscale x 4 x i1> %pred, <vscale x 4 x bfloat> %acc, <vscale x 4 x bfloat> %m1, <vscale x 4 x bfloat> %m2) {
125117
; CHECK-LABEL: fmls_sel_nxv4bf16:
126118
; CHECK: // %bb.0:
127-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
128-
; CHECK-NEXT: bfsub z1.h, z0.h, z1.h
129-
; CHECK-NEXT: mov z0.s, p0/m, z1.s
119+
; CHECK-NEXT: bfmls z0.h, p0/m, z1.h, z2.h
130120
; CHECK-NEXT: ret
131121
%mul = fmul contract <vscale x 4 x bfloat> %m1, %m2
132122
%sub = fsub contract <vscale x 4 x bfloat> %acc, %mul
@@ -137,9 +127,7 @@ define <vscale x 4 x bfloat> @fmls_sel_nxv4bf16(<vscale x 4 x i1> %pred, <vscale
137127
define <vscale x 2 x bfloat> @fmls_sel_nxv2bf16(<vscale x 2 x i1> %pred, <vscale x 2 x bfloat> %acc, <vscale x 2 x bfloat> %m1, <vscale x 2 x bfloat> %m2) {
138128
; CHECK-LABEL: fmls_sel_nxv2bf16:
139129
; CHECK: // %bb.0:
140-
; CHECK-NEXT: bfmul z1.h, z1.h, z2.h
141-
; CHECK-NEXT: bfsub z1.h, z0.h, z1.h
142-
; CHECK-NEXT: mov z0.d, p0/m, z1.d
130+
; CHECK-NEXT: bfmls z0.h, p0/m, z1.h, z2.h
143131
; CHECK-NEXT: ret
144132
%mul = fmul contract <vscale x 2 x bfloat> %m1, %m2
145133
%sub = fsub contract <vscale x 2 x bfloat> %acc, %mul

0 commit comments

Comments
 (0)