Skip to content

Commit 53d20a8

Browse files
authored
Merge pull request #5089 from annop-w/gemv_t
Simplify gemv_t_sve_v1x3 kernel
2 parents 9b11fd5 + 6e393a5 commit 53d20a8

File tree

2 files changed

+57
-39
lines changed

2 files changed

+57
-39
lines changed

CONTRIBUTORS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,8 @@ In chronological order:
235235

236236
* Annop Wongwathanarat <annop.wongwathanarat@arm.com>
237237
* [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1
238+
* [2025-01-21] Optimize gemv_t_sve_v1x3 kernel
238239

239240
* Marek Michalowski <https://github.com/michalowski-arm>
240241
* [2025-01-21] Add thread throttling profile for SGEMV on `NEOVERSEV1`
242+

kernel/arm64/gemv_t_sve_v1x3.c

Lines changed: 55 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/***************************************************************************
2-
Copyright (c) 2024, The OpenBLAS Project
2+
Copyright (c) 2024, 2025 The OpenBLAS Project
33
All rights reserved.
44
55
Redistribution and use in source and binary forms, with or without
@@ -56,12 +56,16 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
5656
BLASLONG ix,iy;
5757
BLASLONG j;
5858
FLOAT *a_ptr;
59+
FLOAT *y_ptr;
5960
FLOAT temp;
6061

6162
iy = 0;
6263

6364
if (inc_x == 1) {
64-
BLASLONG width = (n + 3 - 1) / 3;
65+
BLASLONG width = n / 3;
66+
BLASLONG sve_size = SV_COUNT();
67+
svbool_t pg_true = SV_TRUE();
68+
svbool_t pg = SV_WHILE(0, m % sve_size);
6569

6670
FLOAT *a0_ptr = a + lda * width * 0;
6771
FLOAT *a1_ptr = a + lda * width * 1;
@@ -72,67 +76,79 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
7276
FLOAT *y2_ptr = y + inc_y * width * 2;
7377

7478
for (j = 0; j < width; j++) {
75-
svbool_t pg00 = ((j + width * 0) < n) ? SV_TRUE() : svpfalse();
76-
svbool_t pg01 = ((j + width * 1) < n) ? SV_TRUE() : svpfalse();
77-
svbool_t pg02 = ((j + width * 2) < n) ? SV_TRUE() : svpfalse();
78-
7979
SV_TYPE temp00_vec = SV_DUP(0.0);
8080
SV_TYPE temp01_vec = SV_DUP(0.0);
8181
SV_TYPE temp02_vec = SV_DUP(0.0);
8282

8383
i = 0;
84-
BLASLONG sve_size = SV_COUNT();
8584
while ((i + sve_size * 1 - 1) < m) {
86-
SV_TYPE x0_vec = svld1_vnum(SV_TRUE(), x + i, 0);
85+
SV_TYPE x0_vec = svld1(pg_true, x + i);
8786

88-
SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
89-
SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
90-
SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
87+
SV_TYPE a00_vec = svld1(pg_true, a0_ptr + i);
88+
SV_TYPE a01_vec = svld1(pg_true, a1_ptr + i);
89+
SV_TYPE a02_vec = svld1(pg_true, a2_ptr + i);
9190

92-
temp00_vec = svmla_m(pg00, temp00_vec, a00_vec, x0_vec);
93-
temp01_vec = svmla_m(pg01, temp01_vec, a01_vec, x0_vec);
94-
temp02_vec = svmla_m(pg02, temp02_vec, a02_vec, x0_vec);
91+
temp00_vec = svmla_x(pg_true, temp00_vec, a00_vec, x0_vec);
92+
temp01_vec = svmla_x(pg_true, temp01_vec, a01_vec, x0_vec);
93+
temp02_vec = svmla_x(pg_true, temp02_vec, a02_vec, x0_vec);
9594

9695
i += sve_size * 1;
9796
}
9897

9998
if (i < m) {
100-
svbool_t pg0 = SV_WHILE(i + sve_size * 0, m);
101-
102-
pg00 = svand_z(SV_TRUE(), pg0, pg00);
103-
pg01 = svand_z(SV_TRUE(), pg0, pg01);
104-
pg02 = svand_z(SV_TRUE(), pg0, pg02);
99+
SV_TYPE x0_vec = svld1(pg, x + i);
105100

106-
SV_TYPE x0_vec = svld1_vnum(pg0, x + i, 0);
101+
SV_TYPE a00_vec = svld1(pg, a0_ptr + i);
102+
SV_TYPE a01_vec = svld1(pg, a1_ptr + i);
103+
SV_TYPE a02_vec = svld1(pg, a2_ptr + i);
107104

108-
SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
109-
SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
110-
SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
111-
112-
temp00_vec = svmla_m(pg00, temp00_vec, a00_vec, x0_vec);
113-
temp01_vec = svmla_m(pg01, temp01_vec, a01_vec, x0_vec);
114-
temp02_vec = svmla_m(pg02, temp02_vec, a02_vec, x0_vec);
105+
temp00_vec = svmla_m(pg, temp00_vec, a00_vec, x0_vec);
106+
temp01_vec = svmla_m(pg, temp01_vec, a01_vec, x0_vec);
107+
temp02_vec = svmla_m(pg, temp02_vec, a02_vec, x0_vec);
115108
}
116109

117-
if ((j + width * 0) < n) {
118-
temp = svaddv(SV_TRUE(), temp00_vec);
119-
y0_ptr[iy] += alpha * temp;
120-
}
121-
if ((j + width * 1) < n) {
122-
temp = svaddv(SV_TRUE(), temp01_vec);
123-
y1_ptr[iy] += alpha * temp;
124-
}
125-
if ((j + width * 2) < n) {
126-
temp = svaddv(SV_TRUE(), temp02_vec);
127-
y2_ptr[iy] += alpha * temp;
128-
}
110+
y0_ptr[iy] += alpha * svaddv(pg_true, temp00_vec);
111+
y1_ptr[iy] += alpha * svaddv(pg_true, temp01_vec);
112+
y2_ptr[iy] += alpha * svaddv(pg_true, temp02_vec);
113+
129114
iy += inc_y;
130115

131116
a0_ptr += lda;
132117
a1_ptr += lda;
133118
a2_ptr += lda;
134119
}
135120

121+
a_ptr = a2_ptr;
122+
y_ptr = y2_ptr;
123+
for (j = width * 3; j < n; j++) {
124+
SV_TYPE temp_vec = SV_DUP(0.0);
125+
126+
i = 0;
127+
while ((i + sve_size * 1 - 1) < m) {
128+
SV_TYPE x_vec = svld1(pg_true, x + i);
129+
130+
SV_TYPE a_vec = svld1(pg_true, a_ptr + i);
131+
132+
temp_vec = svmla_x(pg_true, temp_vec, a_vec, x_vec);
133+
134+
i += sve_size * 1;
135+
}
136+
137+
if (i < m) {
138+
SV_TYPE x_vec = svld1(pg, x + i);
139+
140+
SV_TYPE a_vec = svld1(pg, a_ptr + i);
141+
142+
temp_vec = svmla_m(pg, temp_vec, a_vec, x_vec);
143+
}
144+
145+
y_ptr[iy] += alpha * svaddv(pg_true, temp_vec);
146+
147+
iy += inc_y;
148+
149+
a_ptr += lda;
150+
}
151+
136152
return(0);
137153
}
138154

0 commit comments

Comments
 (0)