Skip to content

Commit a54f9a9

Browse files
authored
Merge pull request #5071 from annop-w/sgemm_throttling
Add thread throttling profile for SGEMM on NEOVERSEV1
2 parents 9f2319b + c8cd8da commit a54f9a9

File tree

2 files changed

+50
-10
lines changed

2 files changed

+50
-10
lines changed

CONTRIBUTORS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,3 +232,6 @@ In chronological order:
232232

233233
* Aniket P. Garade <https://github.com/garadeaniket> Sushil Pratap Singh <https://github.com/SushilPratap04> Juliya James <https://github.com/Juliya32>
234234
* [2024-12-13] Optimized swap and rot Level-1 BLAS routines with ARM SVE
235+
236+
* Annop Wongwathanarat <annop.wongwathanarat@arm.com>
237+
* [2025-01-10] Add thread throttling profile for SGEMM on NEOVERSEV1

interface/gemm.c

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*********************************************************************/
2-
/* Copyright 2024 The OpenBLAS Project */
2+
/* Copyright 2024, 2025 The OpenBLAS Project */
33
/* Copyright 2009, 2010 The University of Texas at Austin. */
44
/* All rights reserved. */
55
/* */
@@ -177,6 +177,49 @@ static int init_amxtile_permission() {
177177
}
178178
#endif
179179

180+
#ifdef DYNAMIC_ARCH
181+
extern char* gotoblas_corename(void);
182+
#endif
183+
184+
#if defined(DYNAMIC_ARCH) || defined(NEOVERSEV1)
185+
/*
 * Thread-count profile applied when the target core is "neoversev1".
 *
 * MNK  - problem size as a double; upstream in this file it is computed
 *        as (double)m * (double)n * (double)k.
 * ncpu - number of threads available; it caps every result except the
 *        single-thread case below.
 *
 * Returns 1 when MNK < 262144, ncpu when MNK >= 729000000, and otherwise
 * MIN(ncpu, cap) where cap steps through 6, 12, 16, 20, 24, 32, 40, 48, 56
 * as MNK crosses the successive thresholds listed in the expression.
 *
 * Every threshold is an exact cube (64^3, 104^3, 199^3, 258^3, 322^3,
 * 385^3, 451^3, 643^3, 771^3, 900^3), i.e. the edge length of an
 * m = n = k problem at which the cap changes.
 *
 * NOTE(review): the thresholds and caps look empirically tuned; their
 * origin is not recorded here - confirm against the PR before changing.
 */
static inline int get_gemm_optimal_nthreads_neoversev1(double MNK, int ncpu) {
186+
return
187+
MNK < 262144L ? 1
188+
: MNK < 1124864L ? MIN(ncpu, 6)
189+
: MNK < 7880599L ? MIN(ncpu, 12)
190+
: MNK < 17173512L ? MIN(ncpu, 16)
191+
: MNK < 33386248L ? MIN(ncpu, 20)
192+
: MNK < 57066625L ? MIN(ncpu, 24)
193+
: MNK < 91733851L ? MIN(ncpu, 32)
194+
: MNK < 265847707L ? MIN(ncpu, 40)
195+
: MNK < 458314011L ? MIN(ncpu, 48)
196+
: MNK < 729000000L ? MIN(ncpu, 56)
197+
/* At or above 900^3 no cap is applied: use every available thread. */
: ncpu;
198+
}
199+
#endif
200+
201+
/*
 * Chooses how many threads a GEMM call of size MNK should use.
 *
 * MNK - problem size as a double; the call site in this file computes it
 *       as (double)m * (double)n * (double)k.
 *
 * Selection order:
 *  1. Built with NEOVERSEV1 and none of COMPLEX / DOUBLE / BFLOAT16
 *     defined: always return the NEOVERSEV1 profile.
 *  2. Built with DYNAMIC_ARCH under the same precision guard: return the
 *     NEOVERSEV1 profile when gotoblas_corename() compares equal to
 *     "neoversev1"; otherwise fall through to the generic rule.
 *  3. Generic rule, with T = SMP_THRESHOLD_MIN * GEMM_MULTITHREAD_THRESHOLD:
 *        MNK <= T        -> 1
 *        MNK / ncpu < T  -> MNK / T (converted to int, fraction dropped)
 *        otherwise       -> ncpu
 *
 * NOTE(review): num_cpu_avail(3) is evaluated before any size check, so it
 * runs even for problems that end up single-threaded; the inline code this
 * helper replaces only called it after the MNK <= T test failed.
 * NOTE(review): in case 1 everything after the first return is unreachable.
 */
static inline int get_gemm_optimal_nthreads(double MNK) {
202+
/* The meaning of the argument 3 is defined by num_cpu_avail, which is not
   visible in this file section. */
int ncpu = num_cpu_avail(3);
203+
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
204+
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
205+
#elif defined(DYNAMIC_ARCH) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
206+
/* Runtime dispatch: the core name is only known once the library has
   selected a kernel set, so compare it by string here. */
if (strcmp(gotoblas_corename(), "neoversev1") == 0) {
207+
return get_gemm_optimal_nthreads_neoversev1(MNK, ncpu);
208+
}
209+
#endif
210+
/* Generic rule: at or below the threshold, stay single-threaded. */
if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) ) {
211+
return 1;
212+
}
213+
else {
214+
/* Not enough work to give each of ncpu threads a threshold-sized share:
   use one thread per threshold-sized share instead. */
if (MNK/ncpu < SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD) {
215+
/* MNK > threshold on this path, so the quotient is > 1 before the
   implicit double-to-int conversion. */
return MNK/(SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD);
216+
}
217+
else {
218+
return ncpu;
219+
}
220+
}
221+
}
222+
180223
#ifndef CBLAS
181224

182225
void NAME(char *TRANSA, char *TRANSB,
@@ -310,7 +353,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
310353
FLOAT *beta = (FLOAT*) vbeta;
311354
FLOAT *a = (FLOAT*) va;
312355
FLOAT *b = (FLOAT*) vb;
313-
FLOAT *c = (FLOAT*) vc;
356+
FLOAT *c = (FLOAT*) vc;
314357
#endif
315358

316359
blas_arg_t args;
@@ -352,7 +395,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
352395
#if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && defined(USE_SGEMM_KERNEL_DIRECT)
353396
#ifdef DYNAMIC_ARCH
354397
if (support_avx512() )
355-
#endif
398+
#endif
356399
if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
357400
SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
358401
return;
@@ -604,13 +647,7 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
604647
#endif
605648

606649
MNK = (double) args.m * (double) args.n * (double) args.k;
607-
if ( MNK <= (SMP_THRESHOLD_MIN * (double) GEMM_MULTITHREAD_THRESHOLD) )
608-
args.nthreads = 1;
609-
else {
610-
args.nthreads = num_cpu_avail(3);
611-
if (MNK/args.nthreads < SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD)
612-
args.nthreads = MNK/(SMP_THRESHOLD_MIN*(double)GEMM_MULTITHREAD_THRESHOLD);
613-
}
650+
args.nthreads = get_gemm_optimal_nthreads(MNK);
614651

615652
args.common = NULL;
616653

0 commit comments

Comments
 (0)