Skip to content

Commit 3db5dbc

Browse files
martin-frbgMousius
authored andcommitted
forward to GEMV when one argument is actually a vector
1 parent 136a4ed commit 3db5dbc

File tree

1 file changed

+45
-4
lines changed

1 file changed

+45
-4
lines changed

interface/gemm.c

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,22 +47,29 @@
4747
#define SMP_THRESHOLD_MIN 65536.0
4848
#ifdef XDOUBLE
4949
#define ERROR_NAME "QGEMM "
50+
#define GEMV BLASFUNC(qgemv)
5051
#elif defined(DOUBLE)
5152
#define ERROR_NAME "DGEMM "
53+
#define GEMV BLASFUNC(dgemv)
5254
#elif defined(BFLOAT16)
5355
#define ERROR_NAME "SBGEMM "
56+
#define GEMV BLASFUNC(sbgemv)
5457
#else
5558
#define ERROR_NAME "SGEMM "
59+
#define GEMV BLASFUNC(sgemv)
5660
#endif
5761
#else
5862
#define SMP_THRESHOLD_MIN 8192.0
5963
#ifndef GEMM3M
6064
#ifdef XDOUBLE
6165
#define ERROR_NAME "XGEMM "
66+
#define GEMV BLASFUNC(xgemv)
6267
#elif defined(DOUBLE)
6368
#define ERROR_NAME "ZGEMM "
69+
#define GEMV BLASFUNC(zgemv)
6470
#else
6571
#define ERROR_NAME "CGEMM "
72+
#define GEMV BLASFUNC(cgemv)
6673
#endif
6774
#else
6875
#ifdef XDOUBLE
@@ -485,9 +492,38 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
485492
}
486493
#endif
487494
#endif // defined(__linux__) && defined(__x86_64__) && defined(BFLOAT16)
495+
// fprintf(stderr,"G E M M interface m n k %d %d %d\n",args.m,args.n,args.k);
488496

489497
if ((args.m == 0) || (args.n == 0)) return;
490498

499+
#if 1
500+
#ifndef GEMM3M
501+
if (args.m == 1) {
502+
char *NT=(char*)malloc(2*sizeof(char));
503+
if (transb&1)strcpy(NT,"T");
504+
else NT="N";
505+
// fprintf(stderr,"G E M V\n");
506+
GEMV(NT, &args.n ,&args.k, args.alpha, args.b, &args.ldb, args.a, &args.m, args.beta, args.c, &args.m);
507+
//SUBROUTINE SGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
508+
return;
509+
} else {
510+
if (args.n == 1) {
511+
#ifndef CBLAS
512+
char *NT=(char*)malloc(2*sizeof(char));
513+
strcpy(NT,"N");
514+
#else
515+
char *NT=(char*)malloc(2*sizeof(char));
516+
if (transb&1)strcpy(NT,"T");
517+
else strcpy(NT,"N");
518+
#endif
519+
// fprintf(stderr,"G E M V ! ! ! lda=%d ldb=%d ldc=%d\n",args.lda,args.ldb,args.ldc);
520+
GEMV(NT, &args.m ,&args.k, args.alpha, args.a, &args.lda, args.b, &args.n, args.beta, args.c, &args.n);
521+
//SUBROUTINE SGEMV(TRANS,M,N,ALPHA,A,LDA,X,INCX,BETA,Y,INCY)
522+
return;
523+
}
524+
}
525+
#endif
526+
#endif
491527
#if 0
492528
fprintf(stderr, "m = %4d n = %d k = %d lda = %4d ldb = %4d ldc = %4d\n",
493529
args.m, args.n, args.k, args.lda, args.ldb, args.ldc);
@@ -521,10 +557,15 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
521557

522558
buffer = (XFLOAT *)blas_memory_alloc(0);
523559

524-
//For target LOONGSON3R5, applying an offset to the buffer is essential
525-
//for minimizing cache conflicts and optimizing performance.
526-
#if defined(ARCH_LOONGARCH64) && !defined(NO_AFFINITY)
527-
sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A);
560+
//For Loongson servers, like the 3C5000 (featuring 16 cores), applying an
561+
//offset to the buffer is essential for minimizing cache conflicts and optimizing performance.
562+
#if defined(LOONGSON3R5) && !defined(NO_AFFINITY)
563+
char model_name[128];
564+
get_cpu_model(model_name);
565+
if ((strstr(model_name, "3C5000") != NULL) || (strstr(model_name, "3D5000") != NULL))
566+
sa = (XFLOAT *)((BLASLONG)buffer + (WhereAmI() & 0xf) * GEMM_OFFSET_A);
567+
else
568+
sa = (XFLOAT *)((BLASLONG)buffer + GEMM_OFFSET_A);
528569
#else
529570
sa = (XFLOAT *)((BLASLONG)buffer +GEMM_OFFSET_A);
530571
#endif

0 commit comments

Comments
 (0)