This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 451794a

Clean-up init code.

1 parent 8d0554c

2 files changed: +107, -99 lines

src/CuArrays.jl: 74 additions & 3 deletions
@@ -16,6 +16,41 @@ using Libdl
 using Requires


+## deferred initialization
+
+# CUDA packages require complex initialization (discover CUDA, download artifacts, etc)
+# that can't happen at module load time, so defer that to run time upon actual use.
+
+const configured = Ref{Union{Nothing,Bool}}(nothing)
+
+"""
+    functional(show_reason=false)
+
+Check if the package has been configured successfully and is ready to use.
+
+This call is intended for packages that support conditionally using an available GPU. If you
+fail to check whether CUDA is functional, actual use of functionality might warn and error.
+"""
+function functional(show_reason::Bool=false)
+    if configured[] === nothing
+        configured[] = false
+        if __configure__(show_reason)
+            configured[] = true
+            __runtime_init__()
+        end
+    end
+    configured[]
+end
+
+# macro to guard code that only can run after the package has successfully initialized
+macro after_init(ex)
+    quote
+        @assert functional(true) "CuArrays.jl did not successfully initialize, and is not usable."
+        $(esc(ex))
+    end
+end
+
+
 ## source code includes

 include("bindeps.jl")

@@ -54,10 +89,46 @@ function __init__()
     @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")

     __init_memory__()
+end
+
+function __configure__(show_reason::Bool)
+    # if any dependent GPU package failed, expect it to have logged an error and bail out
+    if !CUDAdrv.functional(show_reason) || !CUDAnative.functional(show_reason)
+        show_reason && @warn "CuArrays.jl did not initialize because CUDAdrv.jl or CUDAnative.jl failed to"
+        return
+    end
+
+    return __configure_dependencies__(show_reason)
+end

-    # NOTE: we only perform minimal initialization here that does not require CUDA or a GPU.
-    # most of the actual initialization is deferred to run time:
-    # see bindeps.jl for initialization of CUDA binary dependencies.
+function __runtime_init__()
+    cuda = version()
+
+    if has_cutensor()
+        cutensor = CUTENSOR.version()
+        if cutensor < v"1"
+            @warn("CuArrays.jl only supports CUTENSOR 1.0 or higher")
+        end
+
+        cutensor_cuda = CUTENSOR.cuda_version()
+        if cutensor_cuda.major != cuda.major || cutensor_cuda.minor != cuda.minor
+            @warn("You are using CUTENSOR $cutensor for CUDA $cutensor_cuda with CUDA toolkit $cuda; these might be incompatible.")
+        end
+    end
+
+    if has_cudnn()
+        cudnn = CUDNN.version()
+        if cudnn < v"7.6"
+            @warn("CuArrays.jl only supports CUDNN v7.6 or higher")
+        end
+
+        cudnn_cuda = CUDNN.cuda_version()
+        if cudnn_cuda.major != cuda.major || cudnn_cuda.minor != cuda.minor
+            @warn("You are using CUDNN $cudnn for CUDA $cudnn_cuda with CUDA toolkit $cuda; these might be incompatible.")
+        end
+    end
+
+    return
 end

 end # module
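
The new functional() entry point is what downstream packages are meant to call before touching the GPU, per its docstring. As a minimal sketch (not part of this commit), a package supporting optional GPU acceleration could pick its storage type based on it; array_type is a hypothetical helper name:

using CuArrays

# Choose a storage type depending on whether the CUDA stack configured
# successfully; the first call to functional() performs the deferred setup
# and caches the result in the `configured` Ref.
array_type() = CuArrays.functional() ? CuArrays.CuArray : Array

xs = array_type()(rand(Float32, 1024))  # CuArray on a working GPU, Array otherwise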

src/bindeps.jl: 33 additions & 96 deletions
@@ -8,6 +8,20 @@ using Libdl

 const __version = Ref{VersionNumber}()

+"""
+    version()
+
+Returns the version of the CUDA toolkit in use.
+"""
+version() = @after_init(__version[])
+
+"""
+    release()
+
+Returns the CUDA release part of the version as returned by [`version`](@ref).
+"""
+release() = @after_init(VersionNumber(__version[].major, __version[].minor))
+
 const __libcublas = Ref{String}()
 const __libcusparse = Ref{String}()
 const __libcusolver = Ref{String}()

@@ -16,6 +30,18 @@ const __libcurand = Ref{String}()
 const __libcudnn = Ref{Union{Nothing,String}}(nothing)
 const __libcutensor = Ref{Union{Nothing,String}}(nothing)

+libcublas() = @after_init(__libcublas[])
+libcusparse() = @after_init(__libcusparse[])
+libcusolver() = @after_init(__libcusolver[])
+libcufft() = @after_init(__libcufft[])
+libcurand() = @after_init(__libcurand[])
+libcudnn() = @after_init(__libcudnn[])
+libcutensor() = @after_init(__libcutensor[])
+
+export has_cudnn, has_cutensor
+has_cudnn() = libcudnn() !== nothing
+has_cutensor() = libcutensor() !== nothing
+

 ## discovery

@@ -186,109 +212,20 @@ function use_local_cutensor(cuda_dirs)
     return true
 end

-
-## initialization
-
-const __initialized__ = Ref{Union{Nothing,Bool}}(nothing)
-
-"""
-    functional(show_reason=false)
-
-Check if the package has been initialized successfully and is ready to use.
-
-This call is intended for packages that support conditionally using an available GPU. If you
-fail to check whether CUDA is functional, actual use of functionality might warn and error.
-"""
-function functional(show_reason::Bool=false)
-    if __initialized__[] === nothing
-        __runtime_init__(show_reason)
-    end
-    __initialized__[]
-end
-
-function __runtime_init__(show_reason::Bool)
-    __initialized__[] = false
-
-    # if any dependent GPU package failed, expect it to have logged an error and bail out
-    if !CUDAdrv.functional(show_reason) || !CUDAnative.functional(show_reason)
-        show_reason && @warn "CuArrays.jl did not initialize because CUDAdrv.jl or CUDAnative.jl failed to"
-        return
-    end
-
-
-    # CUDA toolkit
+function __configure_dependencies__(show_reason::Bool)
+    found = false

     if parse(Bool, get(ENV, "JULIA_CUDA_USE_BINARYBUILDER", "true"))
-        __initialized__[] = use_artifact_cuda()
+        found = use_artifact_cuda()
     end

-    if !__initialized__[]
-        __initialized__[] = use_local_cuda()
+    if !found
+        found = use_local_cuda()
     end

-    if !__initialized__[]
+    if !found
         show_reason && @error "Could not find a suitable CUDA installation"
-        return
     end

-    # library compatibility
-    cuda = version()
-    if has_cutensor()
-        cutensor = CUTENSOR.version()
-        if cutensor < v"1"
-            @warn("CuArrays.jl only supports CUTENSOR 1.0 or higher")
-        end
-
-        cutensor_cuda = CUTENSOR.cuda_version()
-        if cutensor_cuda.major != cuda.major || cutensor_cuda.minor != cuda.minor
-            @warn("You are using CUTENSOR $cutensor for CUDA $cutensor_cuda with CUDA toolkit $cuda; these might be incompatible.")
-        end
-    end
-    if has_cudnn()
-        cudnn = CUDNN.version()
-        if cudnn < v"7.6"
-            @warn("CuArrays.jl only supports CUDNN v7.6 or higher")
-        end
-
-        cudnn_cuda = CUDNN.cuda_version()
-        if cudnn_cuda.major != cuda.major || cudnn_cuda.minor != cuda.minor
-            @warn("You are using CUDNN $cudnn for CUDA $cudnn_cuda with CUDA toolkit $cuda; these might be incompatible.")
-        end
-    end
+    return found
 end
-
-
-## getters
-
-macro initialized(ex)
-    quote
-        @assert functional(true) "CuArrays.jl is not functional"
-        $(esc(ex))
-    end
-end
-
-"""
-    version()
-
-Returns the version of the CUDA toolkit in use.
-"""
-version() = @initialized(__version[])
-
-"""
-    release()
-
-Returns the CUDA release part of the version as returned by [`version`](@ref).
-"""
-release() = @initialized(VersionNumber(__version[].major, __version[].minor))
-
-libcublas() = @initialized(__libcublas[])
-libcusparse() = @initialized(__libcusparse[])
-libcusolver() = @initialized(__libcusolver[])
-libcufft() = @initialized(__libcufft[])
-libcurand() = @initialized(__libcurand[])
-libcudnn() = @initialized(__libcudnn[])
-libcutensor() = @initialized(__libcutensor[])
-
-export has_cudnn, has_cutensor
-has_cudnn() = libcudnn() !== nothing
-has_cutensor() = libcutensor() !== nothing
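
Since every getter now routes through @after_init, merely asking for a library path or the toolkit version triggers (and caches) the whole configuration step, or raises the assertion if it failed. A brief illustrative sketch of this behavior, assuming a machine with a working CUDA toolkit (the version numbers shown are made up, not actual output):

using CuArrays

CuArrays.version()   # e.g. v"10.1.243"; the first call runs __configure__ and __runtime_init__
CuArrays.release()   # e.g. v"10.1", just the major.minor part of version()

if has_cudnn()       # exported by the diff above; false when no CUDNN library was found
    @info "CUDNN found at $(CuArrays.libcudnn())"
end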
