Skip to content

Commit cb48505

Browse files
author
Chris Daley
committed
optimize gemv forwarding on ARM64 systems
1 parent 72461f1 commit cb48505

File tree

2 files changed

+23
-4
lines changed

2 files changed

+23
-4
lines changed

CONTRIBUTORS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -226,3 +226,6 @@ In chronological order:
226226

227227
* Dirreke <https://github.com/mseminatore>
228228
* [2024-01-16] Add basic support for the CSKY architecture
229+
230+
* Christopher Daley <https://github.com/cdaley>
231+
* [2024-01-24] Optimize GEMV forwarding on ARM64 systems

interface/gemm.c

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939

4040
#include <stdio.h>
4141
#include <stdlib.h>
42+
#include <stdbool.h>
4243
#include "common.h"
4344
#ifdef FUNCTION_PROFILE
4445
#include "functable.h"
@@ -499,6 +500,15 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
499500
#endif
500501

501502
#if defined(GEMM_GEMV_FORWARD) && !defined(GEMM3M) && !defined(COMPLEX) && (!defined(BFLOAT16) || defined(GEMM_GEMV_FORWARD_BF16))
503+
#if defined(ARCH_ARM64)
504+
// The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
505+
// perform poorly in certain circumstances. We use the following boolean
506+
// variable along with the gemv argument values to avoid these inefficient
507+
// gemv cases; see GitHub issue #4951.
508+
bool have_tuned_gemv = false;
509+
#else
510+
bool have_tuned_gemv = true;
511+
#endif
502512
// Check if we can convert GEMM -> GEMV
503513
if (args.k != 0) {
504514
if (args.n == 1) {
@@ -518,8 +528,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
518528
if (transb & 1) {
519529
inc_x = args.ldb;
520530
}
521-
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
522-
return;
531+
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N') || (NT == 'T' && inc_x == 1));
532+
if (is_efficient_gemv) {
533+
GEMV(&NT, &m, &n, args.alpha, args.a, &lda, args.b, &inc_x, args.beta, args.c, &inc_y);
534+
return;
535+
}
523536
}
524537
if (args.m == 1) {
525538
blasint inc_x = args.lda;
@@ -538,8 +551,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
538551
m = args.n;
539552
n = args.k;
540553
}
541-
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
542-
return;
554+
bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1) || (NT == 'T' && inc_x == 1));
555+
if (is_efficient_gemv) {
556+
GEMV(&NT, &m, &n, args.alpha, args.b, &ldb, args.a, &inc_x, args.beta, args.c, &inc_y);
557+
return;
558+
}
543559
}
544560
}
545561
#endif

0 commit comments

Comments
 (0)