39
39
40
40
#include <stdio.h>
41
41
#include <stdlib.h>
42
+ #include <stdbool.h>
42
43
#include "common.h"
43
44
#ifdef FUNCTION_PROFILE
44
45
#include "functable.h"
@@ -499,6 +500,15 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
499
500
#endif
500
501
501
502
#if defined(GEMM_GEMV_FORWARD ) && !defined(GEMM3M ) && !defined(COMPLEX ) && (!defined(BFLOAT16 ) || defined(GEMM_GEMV_FORWARD_BF16 ))
503
+ #if defined(ARCH_ARM64 )
504
+ // The gemv kernels in arm64/{gemv_n.S,gemv_n_sve.c,gemv_t.S,gemv_t_sve.c}
505
+ // perform poorly in certain circumstances. We use the following boolean
506
+ // variable along with the gemv argument values to avoid these inefficient
507
+ // gemv cases, see github issue#4951.
508
+ bool have_tuned_gemv = false;
509
+ #else
510
+ bool have_tuned_gemv = true;
511
+ #endif
502
512
// Check if we can convert GEMM -> GEMV
503
513
if (args .k != 0 ) {
504
514
if (args .n == 1 ) {
@@ -518,8 +528,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
518
528
if (transb & 1 ) {
519
529
inc_x = args .ldb ;
520
530
}
521
- GEMV (& NT , & m , & n , args .alpha , args .a , & lda , args .b , & inc_x , args .beta , args .c , & inc_y );
522
- return ;
531
+ bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' ) || (NT == 'T' && inc_x == 1 ));
532
+ if (is_efficient_gemv ) {
533
+ GEMV (& NT , & m , & n , args .alpha , args .a , & lda , args .b , & inc_x , args .beta , args .c , & inc_y );
534
+ return ;
535
+ }
523
536
}
524
537
if (args .m == 1 ) {
525
538
blasint inc_x = args .lda ;
@@ -538,8 +551,11 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
538
551
m = args .n ;
539
552
n = args .k ;
540
553
}
541
- GEMV (& NT , & m , & n , args .alpha , args .b , & ldb , args .a , & inc_x , args .beta , args .c , & inc_y );
542
- return ;
554
+ bool is_efficient_gemv = have_tuned_gemv || ((NT == 'N' && inc_y == 1 ) || (NT == 'T' && inc_x == 1 ));
555
+ if (is_efficient_gemv ) {
556
+ GEMV (& NT , & m , & n , args .alpha , args .b , & ldb , args .a , & inc_x , args .beta , args .c , & inc_y );
557
+ return ;
558
+ }
543
559
}
544
560
}
545
561
#endif
0 commit comments