From 8279e68805e779cfe7f599ebc9ab1e9a427b5eaf Mon Sep 17 00:00:00 2001
From: Sharif Inamdar
Date: Mon, 9 Jun 2025 09:25:45 +0000
Subject: [PATCH] Optimize gemv_n_sve_v1x3 kernel

- Calculate the full-vector and tail predicates once, outside the column
  loop, instead of rebuilding them on every iteration
- Divide the matrix into blocks of 3 full columns and handle the
  remaining n % 3 columns in a separate predicated loop (see the sketch
  below)
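The following sketch is illustrative only and not part of the patch: it
applies the same predicate-hoisting structure to a plain single-column
AXPY, written directly in ACLE SVE intrinsics for double precision. The
kernel's SV_COUNT/SV_TRUE/SV_WHILE/SV_DUP/SV_TYPE macros presumably
expand to intrinsics of this kind, and the function name
axpy_sve_sketch is hypothetical.

    #include <arm_sve.h>
    #include <stdint.h>

    /* y += alpha * a, for m doubles; build with e.g.
       gcc -O2 -march=armv8-a+sve */
    void axpy_sve_sketch(int64_t m, double alpha, const double *a,
                         double *y)
    {
        int64_t vl = (int64_t)svcntd();   /* double lanes per vector */

        /* Hoisted out of the loop: an all-true predicate for full
           vectors and a single tail predicate covering the last
           m % vl elements. */
        svbool_t pg_full = svptrue_b64();
        svbool_t pg_tail = svwhilelt_b64((int64_t)0, m % vl);

        svfloat64_t alpha_vec = svdup_f64(alpha);

        int64_t i = 0;
        while (i + vl - 1 < m) {          /* full vectors only */
            svfloat64_t y_vec = svld1(pg_full, y + i);
            svfloat64_t a_vec = svld1(pg_full, a + i);
            y_vec = svmla_x(pg_full, y_vec, a_vec, alpha_vec);
            svst1(pg_full, y + i, y_vec);
            i += vl;
        }
        if (i < m) {                      /* at most one masked pass */
            svfloat64_t y_vec = svld1(pg_tail, y + i);
            svfloat64_t a_vec = svld1(pg_tail, a + i);
            y_vec = svmla_m(pg_tail, y_vec, a_vec, alpha_vec);
            svst1(pg_tail, y + i, y_vec);
        }
    }

Hoisting the predicates keeps svwhilelt and predicate arithmetic out of
the hot loop entirely: the full-vector path issues effectively
unpredicated loads, stores and multiply-adds, and the masked tail runs
at most once per column.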
---
 CONTRIBUTORS.md                |  5 +-
 kernel/arm64/gemv_n_sve_v1x3.c | 96 ++++++++++++++++++++--------------
 2 files changed, 61 insertions(+), 40 deletions(-)

diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md
index d8f57ef60a..35519e9001 100644
--- a/CONTRIBUTORS.md
+++ b/CONTRIBUTORS.md
@@ -253,4 +253,7 @@ In chronological order:
   * [2025-02-27] Add sbgemv_n_neon kernel
 
 * Abhishek Kumar
-  * [2025-04-22] Optimise dot kernel for NEOVERSE V1
\ No newline at end of file
+  * [2025-04-22] Optimise dot kernel for NEOVERSE V1
+
+* Sharif Inamdar
+  * [2025-06-05] Optimize gemv_n_sve_v1x3 kernel
\ No newline at end of file

diff --git a/kernel/arm64/gemv_n_sve_v1x3.c b/kernel/arm64/gemv_n_sve_v1x3.c
index d6aa3d3894..5ab8d3a166 100644
--- a/kernel/arm64/gemv_n_sve_v1x3.c
+++ b/kernel/arm64/gemv_n_sve_v1x3.c
@@ -52,17 +52,17 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
           BLASLONG lda, FLOAT *x, BLASLONG inc_x,
           FLOAT *y, BLASLONG inc_y, FLOAT *buffer)
 {
-  BLASLONG i;
-  BLASLONG ix,iy;
-  BLASLONG j;
-  FLOAT *a_ptr;
+  BLASLONG i, j;
+  BLASLONG ix = 0;
+  BLASLONG iy;
+  FLOAT *a_ptr = a;
   FLOAT temp;
 
-  ix = 0;
-  a_ptr = a;
-
   if (inc_y == 1) {
-    BLASLONG width = (n + 3 - 1) / 3;
+    BLASLONG width = n / 3; // Only process full 3-column blocks
+    BLASLONG sve_size = SV_COUNT();
+    svbool_t pg_full = SV_TRUE();
+    svbool_t pg_tail = SV_WHILE(0, m % sve_size);
 
     FLOAT *a0_ptr = a_ptr + lda * width * 0;
     FLOAT *a1_ptr = a_ptr + lda * width * 1;
@@ -73,57 +73,75 @@ int CNAME(BLASLONG m, BLASLONG n, BLASLONG dummy1, FLOAT alpha, FLOAT *a,
     FLOAT *x2_ptr = x + inc_x * width * 2;
 
     for (j = 0; j < width; j++) {
-      svbool_t pg00 = ((j + width * 0) < n) ? SV_TRUE() : svpfalse();
-      svbool_t pg01 = ((j + width * 1) < n) ? SV_TRUE() : svpfalse();
-      svbool_t pg02 = ((j + width * 2) < n) ? SV_TRUE() : svpfalse();
+      SV_TYPE temp0_vec = SV_DUP(alpha * x0_ptr[ix]);
+      SV_TYPE temp1_vec = SV_DUP(alpha * x1_ptr[ix]);
+      SV_TYPE temp2_vec = SV_DUP(alpha * x2_ptr[ix]);
 
-      SV_TYPE temp0_vec = ((j + width * 0) < n) ? SV_DUP(alpha * x0_ptr[ix]) : SV_DUP(0.0);
-      SV_TYPE temp1_vec = ((j + width * 1) < n) ? SV_DUP(alpha * x1_ptr[ix]) : SV_DUP(0.0);
-      SV_TYPE temp2_vec = ((j + width * 2) < n) ? SV_DUP(alpha * x2_ptr[ix]) : SV_DUP(0.0);
-
       i = 0;
-      BLASLONG sve_size = SV_COUNT();
-      while ((i + sve_size * 1 - 1) < m) {
-        SV_TYPE y0_vec = svld1_vnum(SV_TRUE(), y + i, 0);
+      while ((i + sve_size - 1) < m) {
+        SV_TYPE y0_vec = svld1(pg_full, y + i);
 
-        SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
-        SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
-        SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
+        SV_TYPE a00_vec = svld1(pg_full, a0_ptr + i);
+        SV_TYPE a01_vec = svld1(pg_full, a1_ptr + i);
+        SV_TYPE a02_vec = svld1(pg_full, a2_ptr + i);
 
-        y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
-        y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
-        y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);
+        y0_vec = svmla_x(pg_full, y0_vec, temp0_vec, a00_vec);
+        y0_vec = svmla_x(pg_full, y0_vec, temp1_vec, a01_vec);
+        y0_vec = svmla_x(pg_full, y0_vec, temp2_vec, a02_vec);
 
-        svst1_vnum(SV_TRUE(), y + i, 0, y0_vec);
-        i += sve_size * 1;
+        svst1(pg_full, y + i, y0_vec);
+        i += sve_size;
       }
       if (i < m) {
-        svbool_t pg0 = SV_WHILE(i + sve_size * 0, m);
-
-        pg00 = svand_z(SV_TRUE(), pg0, pg00);
-        pg01 = svand_z(SV_TRUE(), pg0, pg01);
-        pg02 = svand_z(SV_TRUE(), pg0, pg02);
+        SV_TYPE y0_vec = svld1(pg_tail, y + i);
 
-        SV_TYPE y0_vec = svld1_vnum(pg0, y + i, 0);
+        SV_TYPE a00_vec = svld1(pg_tail, a0_ptr + i);
+        SV_TYPE a01_vec = svld1(pg_tail, a1_ptr + i);
+        SV_TYPE a02_vec = svld1(pg_tail, a2_ptr + i);
 
-        SV_TYPE a00_vec = svld1_vnum(pg00, a0_ptr + i, 0);
-        SV_TYPE a01_vec = svld1_vnum(pg01, a1_ptr + i, 0);
-        SV_TYPE a02_vec = svld1_vnum(pg02, a2_ptr + i, 0);
+        y0_vec = svmla_m(pg_tail, y0_vec, temp0_vec, a00_vec);
+        y0_vec = svmla_m(pg_tail, y0_vec, temp1_vec, a01_vec);
+        y0_vec = svmla_m(pg_tail, y0_vec, temp2_vec, a02_vec);
 
-        y0_vec = svmla_m(pg00, y0_vec, temp0_vec, a00_vec);
-        y0_vec = svmla_m(pg01, y0_vec, temp1_vec, a01_vec);
-        y0_vec = svmla_m(pg02, y0_vec, temp2_vec, a02_vec);
-
-        svst1_vnum(pg0, y + i, 0, y0_vec);
+        svst1(pg_tail, y + i, y0_vec);
      }
       a0_ptr += lda;
       a1_ptr += lda;
       a2_ptr += lda;
       ix += inc_x;
     }
+    // Handle remaining n % 3 columns
+    for (j = width * 3; j < n; j++) {
+      FLOAT *a_col = a + j * lda;
+      temp = alpha * x[j * inc_x];
+      SV_TYPE temp_vec = SV_DUP(temp);
+
+      i = 0;
+      while ((i + sve_size - 1) < m) {
+        SV_TYPE y_vec = svld1(pg_full, y + i);
+
+        SV_TYPE a_vec = svld1(pg_full, a_col + i);
+
+        y_vec = svmla_x(pg_full, y_vec, temp_vec, a_vec);
+
+        svst1(pg_full, y + i, y_vec);
+        i += sve_size;
+      }
+      if (i < m) {
+        SV_TYPE y_vec = svld1(pg_tail, y + i);
+
+        SV_TYPE a_vec = svld1(pg_tail, a_col + i);
+
+        y_vec = svmla_m(pg_tail, y_vec, temp_vec, a_vec);
+
+        svst1(pg_tail, y + i, y_vec);
+      }
+    }
     return(0);
   }
 
+  // Fallback scalar loop
   for (j = 0; j < n; j++) {
     temp = alpha * x[ix];
     iy = 0;