Skip to content

Commit b8110f8

Browse files
jw3126OkonSamueltkf
authored
add BLAS.get_num_threads (#36360)
Co-authored-by: Okon Samuel <39421418+OkonSamuel@users.noreply.github.com> Co-authored-by: Takafumi Arakaki <aka.tkf@gmail.com>
1 parent 39c278b commit b8110f8

File tree

3 files changed

+115
-13
lines changed

3 files changed

+115
-13
lines changed

NEWS.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ Standard library changes
6161
#### LinearAlgebra
6262
* New method `LinearAlgebra.issuccess(::CholeskyPivoted)` for checking whether pivoted Cholesky factorization was successful ([#36002]).
6363
* `UniformScaling` can now be indexed into using ranges to return dense matrices and vectors ([#24359]).
64+
* New function `LinearAlgebra.BLAS.get_num_threads()` for getting the number of BLAS threads ([#36360]).
6465

6566
#### Markdown
6667

stdlib/LinearAlgebra/src/blas.jl

Lines changed: 87 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -106,27 +106,101 @@ end
106106

107107
# Return the configuration string reported by OpenBLAS' `openblas_get_config`,
# with surrounding whitespace stripped.
function openblas_get_config()
    cfg_ptr = ccall((@blasfunc(openblas_get_config), libblas), Ptr{UInt8}, ())
    return strip(unsafe_string(cfg_ptr))
end
108108

109+
function guess_vendor()
    # Start from whatever `vendor()` reports. On Apple systems an `:unknown`
    # answer is replaced by the guess `:osxblas`; every other result is
    # passed through unchanged.
    detected = vendor()
    (Sys.isapple() && detected == :unknown) && return :osxblas
    return detected
end
118+
119+
109120
"""
110-
set_num_threads(n)
121+
set_num_threads(n::Integer)
122+
set_num_threads(::Nothing)
111123
112-
Set the number of threads the BLAS library should use.
124+
Set the number of threads the BLAS library should use equal to `n::Integer`.
125+
126+
Also accepts `nothing`, in which case julia tries to guess the default number of threads.
127+
Passing `nothing` is discouraged and mainly exists for the following reason:
128+
129+
On exotic variants of BLAS, `nothing` may be returned by `get_num_threads()`.
130+
Thus on exotic variants of BLAS, the following pattern may fail to set the number of threads:
131+
132+
```julia
133+
old = get_num_threads()
134+
set_num_threads(1)
135+
@threads for i in 1:10
136+
# single-threaded BLAS calls
137+
end
138+
set_num_threads(old)
139+
```
140+
Because `set_num_threads` accepts `nothing`, this code can still run
141+
on exotic variants of BLAS without error. Warnings will be raised instead.
142+
143+
!!! compat "Julia 1.6"
144+
`set_num_threads(::Nothing)` requires at least Julia 1.6.
113145
"""
114-
function set_num_threads(n::Integer)
115-
blas = vendor()
116-
if blas === :openblas
117-
return ccall((:openblas_set_num_threads, libblas), Cvoid, (Int32,), n)
118-
elseif blas === :openblas64
119-
return ccall((:openblas_set_num_threads64_, libblas), Cvoid, (Int32,), n)
120-
elseif blas === :mkl
146+
function set_num_threads(n)::Nothing
    return _set_num_threads(n)
end
147+
148+
# Set the BLAS thread count to `n` for the library identified by `_blas`
# (defaults to `guess_vendor()`).
#
# Returns `nothing` in every branch. For an `:unknown` vendor nothing is set
# and a warning is logged instead.
function _set_num_threads(n::Integer; _blas = guess_vendor())
    if _blas === :openblas || _blas === :openblas64
        return ccall((@blasfunc(openblas_set_num_threads), libblas), Cvoid, (Cint,), n)
    elseif _blas === :mkl
        # MKL may let us set the number of threads in several ways
        return ccall((:MKL_Set_Num_Threads, libblas), Cvoid, (Cint,), n)
    elseif _blas === :osxblas
        # OSX BLAS looks at an environment variable
        ENV["VECLIB_MAXIMUM_THREADS"] = n
    else
        # Internal invariant: the branches above cover every known vendor.
        @assert _blas === :unknown
        @warn "Failed to set number of BLAS threads." maxlog=1
    end
    return nothing
end
163+
164+
# Read environment variable `key` and parse it as an `Int`.
# Returns `nothing` when the variable is unset or not a valid integer.
function _tryparse_env_int(key)
    raw = get(ENV, key, "")
    return tryparse(Int, raw)
end
165+
166+
# Called when no explicit thread count is given. The count is taken from the
# first of these that yields an integer: `OPENBLAS_NUM_THREADS`,
# `OMP_NUM_THREADS`, then half of `Sys.CPU_THREADS` (but at least 1).
function _set_num_threads(::Nothing; _blas = guess_vendor())
    from_openblas = _tryparse_env_int("OPENBLAS_NUM_THREADS")
    from_omp = _tryparse_env_int("OMP_NUM_THREADS")
    half_cpus = max(1, Sys.CPU_THREADS ÷ 2)
    nthreads = something(from_openblas, from_omp, half_cpus)
    return _set_num_threads(nthreads; _blas = _blas)
end
174+
175+
"""
176+
get_num_threads()
129177
178+
Get the number of threads the BLAS library is using.
179+
180+
On exotic variants of `BLAS` this function can fail, which is indicated by returning `nothing`.
181+
182+
!!! compat "Julia 1.6"
183+
`get_num_threads` requires at least Julia 1.6.
184+
"""
185+
get_num_threads(; _blas = guess_vendor())::Union{Int, Nothing} =
    _get_num_threads(; _blas = _blas)  # forward the vendor override instead of re-guessing it
186+
187+
# Query the BLAS thread count for the library identified by `_blas`
# (defaults to `guess_vendor()`).
#
# Returns an `Int` on success. Returns `nothing`, after logging a warning,
# when the vendor is `:unknown` or when the `:osxblas` environment variable
# cannot be parsed as an integer.
function _get_num_threads(; _blas = guess_vendor())::Union{Int, Nothing}
    if _blas === :openblas || _blas === :openblas64
        return Int(ccall((@blasfunc(openblas_get_num_threads), libblas), Cint, ()))
    end
    if _blas === :mkl
        return Int(ccall((:mkl_get_max_threads, libblas), Cint, ()))
    end
    if _blas === :osxblas
        # The thread count is read back from the same environment variable
        # that `_set_num_threads` writes for this vendor.
        key = "VECLIB_MAXIMUM_THREADS"
        nt = _tryparse_env_int(key)
        nt === nothing || return nt
        @warn "Failed to read environment variable $key" maxlog=1
    else
        # Internal invariant: the branches above cover every known vendor.
        @assert _blas === :unknown
    end
    @warn "Could not get number of BLAS threads. Returning `nothing` instead." maxlog=1
    return nothing
end
132206

stdlib/LinearAlgebra/test/blas.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -553,4 +553,31 @@ Base.stride(A::WrappedArray, i::Int) = stride(A.A, i)
553553
end
554554
end
555555

556+
@testset "get_set_num_threads" begin
    # Round trip through the public API: read, change, restore.
    initial = BLAS.get_num_threads()
    @test initial isa Int
    @test initial > 0
    BLAS.set_num_threads(1)
    @test BLAS.get_num_threads() === 1
    BLAS.set_num_threads(initial)
    @test BLAS.get_num_threads() === initial

    # An unknown vendor must only warn.
    @test_logs (:warn,) match_mode=:any BLAS._set_num_threads(1, _blas=:unknown)
    if BLAS.guess_vendor() !== :osxblas
        # test osxblas which is not covered by CI
        withenv("VECLIB_MAXIMUM_THREADS" => nothing) do
            # With the variable unset, reading yields `nothing` plus two warnings.
            @test @test_logs(
                (:warn,),
                (:warn,),
                match_mode=:any,
                BLAS._get_num_threads(_blas=:osxblas),
            ) === nothing
            # Setting then getting must round-trip without logging anything.
            for nt in (1, 2)
                @test_logs BLAS._set_num_threads(nt, _blas=:osxblas)
                @test @test_logs(BLAS._get_num_threads(_blas=:osxblas)) === nt
            end
        end
    end
end
582+
556583
end # module TestBLAS

0 commit comments

Comments
 (0)