@@ -106,27 +106,101 @@ end
106
106
107
107
openblas_get_config () = strip (unsafe_string (ccall ((@blasfunc (openblas_get_config), libblas), Ptr{UInt8}, () )))
108
108
109
+ function guess_vendor ()
110
+ # like determine_vendor, but guesses blas in some cases
111
+ # where determine_vendor returns :unknown
112
+ ret = vendor ()
113
+ if Sys. isapple () && (ret == :unknown )
114
+ ret = :osxblas
115
+ end
116
+ ret
117
+ end
118
+
119
+
109
120
"""
110
- set_num_threads(n)
121
+ set_num_threads(n::Integer)
122
+ set_num_threads(::Nothing)
111
123
112
- Set the number of threads the BLAS library should use.
124
+ Set the number of threads the BLAS library should use equal to `n::Integer`.
125
+
126
+ Also accepts `nothing`, in which case julia tries to guess the default number of threads.
127
+ Passing `nothing` is discouraged and mainly exists for the following reason:
128
+
129
+ On exotic variants of BLAS, `nothing` may be returned by `get_num_threads()`.
130
+ Thus on exotic variants of BLAS, the following pattern may fail to set the number of threads:
131
+
132
+ ```julia
133
+ old = get_num_threads()
134
+ set_num_threads(1)
135
+ @threads for i in 1:10
136
+ # single-threaded BLAS calls
137
+ end
138
+ set_num_threads(old)
139
+ ```
140
+ Because `set_num_threads` accepts `nothing`, this code can still run
141
+ on exotic variants of BLAS without error. Warnings will be raised instead.
142
+
143
+ !!! compat "Julia 1.6"
144
+ `set_num_threads(::Nothing)` requires at least Julia 1.6.
113
145
"""
114
- function set_num_threads (n:: Integer )
115
- blas = vendor ()
116
- if blas === :openblas
117
- return ccall ((:openblas_set_num_threads , libblas), Cvoid, (Int32,), n)
118
- elseif blas === :openblas64
119
- return ccall ((:openblas_set_num_threads64_ , libblas), Cvoid, (Int32,), n)
120
- elseif blas === :mkl
146
+ set_num_threads (n):: Nothing = _set_num_threads (n)
147
+
148
+ function _set_num_threads (n:: Integer ; _blas = guess_vendor ())
149
+ if _blas === :openblas || _blas == :openblas64
150
+ return ccall ((@blasfunc (openblas_set_num_threads), libblas), Cvoid, (Cint,), n)
151
+ elseif _blas === :mkl
121
152
# MKL may let us set the number of threads in several ways
122
153
return ccall ((:MKL_Set_Num_Threads , libblas), Cvoid, (Cint,), n)
123
- end
124
-
125
- # OSX BLAS looks at an environment variable
126
- @static if Sys. isapple ()
154
+ elseif _blas === :osxblas
155
+ # OSX BLAS looks at an environment variable
127
156
ENV [" VECLIB_MAXIMUM_THREADS" ] = n
157
+ else
158
+ @assert _blas === :unknown
159
+ @warn " Failed to set number of BLAS threads." maxlog= 1
128
160
end
161
+ return nothing
162
+ end
163
+
164
+ _tryparse_env_int (key) = tryparse (Int, get (ENV , key, " " ))
165
+
166
+ function _set_num_threads (:: Nothing ; _blas = guess_vendor ())
167
+ n = something (
168
+ _tryparse_env_int (" OPENBLAS_NUM_THREADS" ),
169
+ _tryparse_env_int (" OMP_NUM_THREADS" ),
170
+ max (1 , Sys. CPU_THREADS ÷ 2 ),
171
+ )
172
+ _set_num_threads (n; _blas)
173
+ end
174
+
175
+ """
176
+ get_num_threads()
129
177
178
+ Get the number of threads the BLAS library is using.
179
+
180
+ On exotic variants of `BLAS` this function can fail, which is indicated by returning `nothing`.
181
+
182
+ !!! compat "Julia 1.6"
183
+ `get_num_threads` requires at least Julia 1.6.
184
+ """
185
+ get_num_threads (;_blas= guess_vendor ()):: Union{Int, Nothing} = _get_num_threads ()
186
+
187
+ function _get_num_threads (; _blas = guess_vendor ()):: Union{Int, Nothing}
188
+ if _blas === :openblas || _blas === :openblas64
189
+ return Int (ccall ((@blasfunc (openblas_get_num_threads), libblas), Cint, ()))
190
+ elseif _blas === :mkl
191
+ return Int (ccall ((:mkl_get_max_threads , libblas), Cint, ()))
192
+ elseif _blas === :osxblas
193
+ key = " VECLIB_MAXIMUM_THREADS"
194
+ nt = _tryparse_env_int (key)
195
+ if nt === nothing
196
+ @warn " Failed to read environment variable $key " maxlog= 1
197
+ else
198
+ return nt
199
+ end
200
+ else
201
+ @assert _blas === :unknown
202
+ end
203
+ @warn " Could not get number of BLAS threads. Returning `nothing` instead." maxlog= 1
130
204
return nothing
131
205
end
132
206
0 commit comments