Skip to content

Commit a782103

Browse files
authored
Merge pull request #4425 from martin-frbg/issue2392
Add BLAS extension openblas_set_num_threads_local()
2 parents b334152 + 152a6c4 commit a782103

File tree

6 files changed

+23
-3
lines changed

6 files changed

+23
-3
lines changed

cblas.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ extern "C" {
1212
/* Set the number of threads at runtime. */
1313
void openblas_set_num_threads(int num_threads);
1414
void goto_set_num_threads(int num_threads);
15+
int openblas_set_num_threads_local(int num_threads);
1516

1617
/* Get the number of threads at runtime. */
1718
int openblas_get_num_threads(void);

common_thread.h

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,19 +137,20 @@ typedef struct blas_queue {
137137

138138
extern int blas_server_avail;
139139
extern int blas_omp_number_max;
140+
extern int blas_omp_threads_local;
140141

141142
static __inline int num_cpu_avail(int level) {
142143

143144
#ifdef USE_OPENMP
144145
int openmp_nthreads;
145146
openmp_nthreads=omp_get_max_threads();
147+
if (omp_in_parallel()) openmp_nthreads = blas_omp_threads_local;
146148
#endif
147149

148150
#ifndef USE_OPENMP
149151
if (blas_cpu_number == 1
150-
#endif
151-
#ifdef USE_OPENMP
152-
if (openmp_nthreads == 1 || omp_in_parallel()
152+
#else
153+
if (openmp_nthreads == 1
153154
#endif
154155
) return 1;
155156

driver/others/blas_server.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,8 @@ extern unsigned int openblas_thread_timeout(void);
113113
/* We need this global for checking if initialization is finished. */
114114
int blas_server_avail __attribute__((aligned(ATTRIBUTE_SIZE))) = 0;
115115

116+
int blas_omp_threads_local = 1;
117+
116118
/* Local Variables */
117119
#if defined(USE_PTHREAD_LOCK)
118120
static pthread_mutex_t server_lock = PTHREAD_MUTEX_INITIALIZER;

driver/others/blas_server_omp.c

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969

7070
int blas_server_avail = 0;
7171
int blas_omp_number_max = 0;
72+
int blas_omp_threads_local = 1;
7273

7374
extern int openblas_omp_adaptive_env(void);
7475

driver/others/blas_server_win32.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@ static CRITICAL_SECTION queue_lock;
5959
/* We need this global for checking if initialization is finished. */
6060
int blas_server_avail = 0;
6161

62+
int blas_omp_threads_local = 1;
63+
6264
/* Local Variables */
6365
static BLASULONG server_lock = 0;
6466

driver/others/openblas_set_num_threads.c

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,20 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3636
#ifdef SMP_SERVER
3737

3838
extern void openblas_set_num_threads(int num_threads) ;
39+
extern int openblas_get_num_threads(void) ;
3940

4041
void openblas_set_num_threads_(int* num_threads){
4142
openblas_set_num_threads(*num_threads);
4243
}
4344

45+
int openblas_set_num_threads_local(int num_threads){
46+
int ret = openblas_get_num_threads();
47+
openblas_set_num_threads(num_threads);
48+
blas_omp_threads_local=num_threads;
49+
return ret;
50+
}
51+
52+
4453
#else
4554
//Single thread
4655

@@ -50,4 +59,8 @@ void openblas_set_num_threads(int num_threads) {
5059
void openblas_set_num_threads_(int* num_threads){
5160

5261
}
62+
63+
int openblas_set_num_threads_local(int num_threads){
64+
return 1;
65+
}
5366
#endif

0 commit comments

Comments
 (0)