Skip to content

Commit 7338a47

Browse files
authored
Merge pull request #5150 from Harishmcw/WoA-Experiments
Redefined threading logic for GESV and GEMV on WoA
2 parents 5f200dc + 030ae1f commit 7338a47

File tree

3 files changed

+21
-9
lines changed

3 files changed

+21
-9
lines changed

interface/gemv.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,11 @@ static inline int get_gemv_optimal_nthreads_neoversev2(BLASLONG MN, int ncpu) {
9191

9292
static inline int get_gemv_optimal_nthreads(BLASLONG MN) {
9393
int ncpu = num_cpu_avail(3);
94+
#if defined(_WIN64) && defined(_M_ARM64)
95+
if (MN > 100000000L)
96+
return num_cpu_avail(4);
97+
return 1;
98+
#endif
9499
#if defined(NEOVERSEV1) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)
95100
return get_gemv_optimal_nthreads_neoversev1(MN, ncpu);
96101
#elif defined(NEOVERSEV2) && !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16)

interface/lapack/gesv.c

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -117,13 +117,15 @@ int NAME(blasint *N, blasint *NRHS, FLOAT *a, blasint *ldA, blasint *ipiv,
117117

118118
#if defined(_WIN64) && defined(_M_ARM64)
119119
#ifdef COMPLEX
120-
if (args.m * args.n > 600)
120+
if (args.m * args.n <= 300)
121121
#else
122-
if (args.m * args.n > 1000)
122+
if (args.m * args.n <= 500)
123123
#endif
124-
args.nthreads = num_cpu_avail(4);
125-
else
126124
args.nthreads = 1;
125+
else if (args.m * args.n <= 1000)
126+
args.nthreads = 4;
127+
else
128+
args.nthreads = num_cpu_avail(4);
127129
#else
128130
#ifndef DOUBLE
129131
if (args.m * args.n < 40000)

interface/zgemv.c

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,25 +252,30 @@ void CNAME(enum CBLAS_ORDER order,
252252

253253
#ifdef SMP
254254

255-
if ( 1L * m * n < 1024L * GEMM_MULTITHREAD_THRESHOLD )
255+
#if defined(_WIN64) && defined(_M_ARM64)
256+
if (m*n > 25000000L)
257+
nthreads = num_cpu_avail(4);
258+
else
259+
nthreads = 1;
260+
#else
261+
if (1L * m * n < 1024L * GEMM_MULTITHREAD_THRESHOLD)
256262
nthreads = 1;
257263
else
258264
nthreads = num_cpu_avail(2);
265+
#endif
259266

260267
if (nthreads == 1) {
261-
#endif
268+
#endif
262269

263270
(gemv[(int)trans])(m, n, 0, alpha_r, alpha_i, a, lda, x, incx, y, incy, buffer);
264271

265272
#ifdef SMP
266-
267273
} else {
268-
269274
(gemv_thread[(int)trans])(m, n, ALPHA, a, lda, x, incx, y, incy, buffer, nthreads);
270-
271275
}
272276
#endif
273277

278+
274279
STACK_FREE(buffer);
275280

276281
FUNCTION_PROFILE_END(4, m * n + m + n, 2 * m * n);

0 commit comments

Comments
 (0)