
Commit 1ed7eb6

Optimize gemv_n_sve_v1x3 kernel
- Calculate predicate outside the loop
- Divide matrix in blocks of 3 (see the sketch below)
1 parent 02267d8 commit 1ed7eb6
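
The two changes above follow a standard SVE pattern: the governing all-true predicate and the vector length are computed once before the loops, the main loop runs only on whole vectors, and a whilelt predicate is built solely for the final partial vector; columns are consumed three at a time, with the n % 3 leftover columns handled one by one. Below is a minimal, self-contained float32 sketch of that pattern, not the OpenBLAS kernel itself (which is written against the SV_* macros so it serves both precisions); the function name gemv_n_block3_sketch and the assumption of a column-major matrix with column stride lda and unit x/y increments are illustrative only, and the code assumes a compiler with SVE enabled (e.g. -march=armv8-a+sve).

#include <arm_sve.h>
#include <stdint.h>

/* Hypothetical sketch of the pattern used by this commit, not the OpenBLAS
 * kernel: single-precision y += alpha * A * x for a column-major m x n
 * matrix with column stride lda and unit x/y strides. */
static void gemv_n_block3_sketch(int64_t m, int64_t n, float alpha,
                                 const float *a, int64_t lda,
                                 const float *x, float *y)
{
    const int64_t vl = (int64_t)svcntw();    /* 32-bit lanes per SVE vector   */
    const svbool_t pg_full = svptrue_b32();  /* predicate computed only once  */
    const int64_t width = n / 3;             /* number of full 3-column blocks */

    for (int64_t j = 0; j < width; j++) {
        /* Broadcast one scaled x element per column of the block. */
        svfloat32_t t0 = svdup_n_f32(alpha * x[j]);
        svfloat32_t t1 = svdup_n_f32(alpha * x[j + width]);
        svfloat32_t t2 = svdup_n_f32(alpha * x[j + 2 * width]);
        const float *a0 = a + (j            ) * lda;
        const float *a1 = a + (j +     width) * lda;
        const float *a2 = a + (j + 2 * width) * lda;

        int64_t i = 0;
        /* Main loop: whole vectors, so the hoisted all-true predicate suffices. */
        for (; i + vl <= m; i += vl) {
            svfloat32_t yv = svld1_f32(pg_full, y + i);
            yv = svmla_f32_x(pg_full, yv, t0, svld1_f32(pg_full, a0 + i));
            yv = svmla_f32_x(pg_full, yv, t1, svld1_f32(pg_full, a1 + i));
            yv = svmla_f32_x(pg_full, yv, t2, svld1_f32(pg_full, a2 + i));
            svst1_f32(pg_full, y + i, yv);
        }
        /* Tail: the only place a while-predicate has to be built. */
        if (i < m) {
            svbool_t pg = svwhilelt_b32_s64(i, m);
            svfloat32_t yv = svld1_f32(pg, y + i);
            yv = svmla_f32_m(pg, yv, t0, svld1_f32(pg, a0 + i));
            yv = svmla_f32_m(pg, yv, t1, svld1_f32(pg, a1 + i));
            yv = svmla_f32_m(pg, yv, t2, svld1_f32(pg, a2 + i));
            svst1_f32(pg, y + i, yv);
        }
    }

    /* Leftover n % 3 columns are handled one column at a time. */
    for (int64_t j = width * 3; j < n; j++) {
        svfloat32_t t = svdup_n_f32(alpha * x[j]);
        const float *ac = a + j * lda;
        int64_t i = 0;
        for (; i + vl <= m; i += vl) {
            svfloat32_t yv = svld1_f32(pg_full, y + i);
            yv = svmla_f32_x(pg_full, yv, t, svld1_f32(pg_full, ac + i));
            svst1_f32(pg_full, y + i, yv);
        }
        if (i < m) {
            svbool_t pg = svwhilelt_b32_s64(i, m);
            svfloat32_t yv = svld1_f32(pg, y + i);
            yv = svmla_f32_m(pg, yv, t, svld1_f32(pg, ac + i));
            svst1_f32(pg, y + i, yv);
        }
    }
}

Compared with the previous version of the kernel, the per-iteration ternaries and the svand_z predicate merges disappear from the hot loop, and the "don't care" svmla_x form can be used wherever the predicate is known to be all-true, as the diff below shows.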

2 files changed: +63 −39 lines changed

CONTRIBUTORS.md

Lines changed: 4 additions & 1 deletion
@@ -253,4 +253,7 @@ In chronological order:
   * [2025-02-27] Add sbgemv_n_neon kernel
 
 * Abhishek Kumar <https://github.com/abhishek-iitmadras>
-  * [2025-04-22] Optimise dot kernel for NEOVERSE V1
+  * [2025-04-22] Optimise dot kernel for NEOVERSE V1
+
+* Sharif Inamdar <sharif.inamdar@arm.com>
+  * [2025-06-05] Optimize gemv_n_sve_v1x3 kernel

kernel/arm64/gemv_n_sve_v1x3.c

Lines changed: 59 additions & 38 deletions
@@ -52,17 +52,16 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
            BLASLONG lda, FLOAT *x, BLASLONG inc_x, FLOAT *y, BLASLONG inc_y,
            FLOAT *buffer)
 {
-    BLASLONG i;
-    BLASLONG ix,iy;
-    BLASLONG j;
-    FLOAT *a_ptr;
+    BLASLONG i, j;
+    BLASLONG ix = 0;
+    BLASLONG iy;
+    FLOAT *a_ptr = a;
     FLOAT temp;

-    ix = 0;
-    a_ptr = a;
-
     if (inc_y == 1) {
-        BLASLONG width = (n + 3 - 1) / 3;
+        BLASLONG width = n / 3; // Only process full 3-column blocks
+        BLASLONG sve_size = SV_COUNT();
+        svbool_t pg_full = SV_TRUE();

         FLOAT *a0_ptr = a_ptr + lda * width * 0;
         FLOAT *a1_ptr = a_ptr + lda * width * 1;
@@ -73,57 +72,79 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
         FLOAT *x2_ptr = x + inc_x * width * 2;

         for (j = 0; j < width; j++) {
-            svbool_t pg00 = ((j + width * 0) < n) ? SV_TRUE() : svpfalse();
-            svbool_t pg01 = ((j + width * 1) < n) ? SV_TRUE() : svpfalse();
-            svbool_t pg02 = ((j + width * 2) < n) ? SV_TRUE() : svpfalse();
+            SV_TYPE temp0_vec = SV_DUP(alpha * x0_ptr[ix]);
+            SV_TYPE temp1_vec = SV_DUP(alpha * x1_ptr[ix]);
+            SV_TYPE temp2_vec = SV_DUP(alpha * x2_ptr[ix]);

-            SV_TYPE temp0_vec = ((j + width * 0) < n) ? SV_DUP(alpha * x0_ptr[ix]) : SV_DUP(0.0);
-            SV_TYPE temp1_vec = ((j + width * 1) < n) ? SV_DUP(alpha * x1_ptr[ix]) : SV_DUP(0.0);
-            SV_TYPE temp2_vec = ((j + width * 2) < n) ? SV_DUP(alpha * x2_ptr[ix]) : SV_DUP(0.0);
             i = 0;
-            BLASLONG sve_size = SV_COUNT();
-            while ((i + sve_size * 1 - 1) < m) {
-                SV_TYPE y0_vec = svld1_vnum(SV_TRUE(), y + i, 0);
+            while ((i + sve_size - 1) < m) {
+                SV_TYPE y0_vec = svld1(pg_full, y + i);

-                SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
-                SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
-                SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
+                SV_TYPE a00_vec = svld1(pg_full, a0_ptr + i);
+                SV_TYPE a01_vec = svld1(pg_full, a1_ptr + i);
+                SV_TYPE a02_vec = svld1(pg_full, a2_ptr + i);

-                y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
-                y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
-                y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);
+                y0_vec = svmla_x(pg_full, y0_vec, temp0_vec, a00_vec);
+                y0_vec = svmla_x(pg_full, y0_vec, temp1_vec, a01_vec);
+                y0_vec = svmla_x(pg_full, y0_vec, temp2_vec, a02_vec);

-                svst1_vnum(SV_TRUE(), y + i, 0, y0_vec);
-                i += sve_size * 1;
+                svst1(pg_full, y + i, y0_vec);
+                i += sve_size;
             }

             if (i < m) {
-                svbool_t pg0 = SV_WHILE(i + sve_size * 0, m);
-
-                pg00 = svand_z(SV_TRUE(), pg0, pg00);
-                pg01 = svand_z(SV_TRUE(), pg0, pg01);
-                pg02 = svand_z(SV_TRUE(), pg0, pg02);
+                svbool_t pg_tail = SV_WHILE(i, m);

-                SV_TYPE y0_vec = svld1_vnum(pg0, y + i, 0);
+                SV_TYPE y0_vec = svld1(pg_tail, y + i);

-                SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
-                SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
-                SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
+                SV_TYPE a00_vec = svld1(pg_tail, a0_ptr + i);
+                SV_TYPE a01_vec = svld1(pg_tail, a1_ptr + i);
+                SV_TYPE a02_vec = svld1(pg_tail, a2_ptr + i);

-                y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
-                y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
-                y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);
+                y0_vec = svmla_m(pg_tail, y0_vec, temp0_vec, a00_vec);
+                y0_vec = svmla_m(pg_tail, y0_vec, temp1_vec, a01_vec);
+                y0_vec = svmla_m(pg_tail, y0_vec, temp2_vec, a02_vec);

-                svst1_vnum(pg0, y + i, 0, y0_vec);
+                svst1(pg_tail, y + i, y0_vec);
             }
             a0_ptr += lda;
             a1_ptr += lda;
             a2_ptr += lda;
             ix += inc_x;
         }
+        // Handle remaining n % 3 columns
+        for (j = width * 3; j < n; j++) {
+            FLOAT *a_col = a + j * lda;
+            temp = alpha * x[j * inc_x];
+            SV_TYPE temp_vec = SV_DUP(temp);
+
+            i = 0;
+            while ((i + sve_size - 1) < m) {
+                SV_TYPE y_vec = svld1(pg_full, y + i);
+
+                SV_TYPE a_vec = svld1(pg_full, a_col + i);
+
+                y_vec = svmla_x(pg_full, y_vec, temp_vec, a_vec);
+
+                svst1(pg_full, y + i, y_vec);
+                i += sve_size;
+            }
+            if (i < m) {
+                svbool_t pg_tail = SV_WHILE(i, m);
+
+                SV_TYPE y_vec = svld1(pg_tail, y + i);
+
+                SV_TYPE a_vec = svld1(pg_tail, a_col + i);
+
+                y_vec = svmla_m(pg_tail, y_vec, temp_vec, a_vec);
+
+                svst1(pg_tail, y + i, y_vec);
+            }
+        }
         return(0);
     }

+    // Fallback scalar loop
     for (j = 0; j < n; j++) {
         temp = alpha * x[ix];
         iy = 0;
