Skip to content

Commit dc8a11e

Browse files
authored
Merge pull request #1491 from JuliaGPU/tb/cudnn_hotfix
Limit time held by CUDNN locks.
2 parents 554dcc4 + 92873e4 commit dc8a11e

File tree

6 files changed, with 79 additions and 84 deletions.

lib/cudadrv/memory.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -663,7 +663,7 @@ function __pin(ptr::Ptr{Nothing}, sz::Int)
663663
ctx = context()
664664
key = (ctx,ptr)
665665

666-
@lock __pin_lock begin
666+
Base.@lock __pin_lock begin
667667
pin_count = if haskey(__pin_count, key)
668668
__pin_count[key] += 1
669669
else
@@ -687,7 +687,7 @@ end
687687
function __unpin(ptr::Ptr{Nothing}, ctx::CuContext)
688688
key = (ctx,ptr)
689689

690-
@spinlock __pin_lock begin
690+
Base.@lock __pin_lock begin
691691
@assert haskey(__pin_count, key) "Cannot unpin unmanaged pointer $ptr."
692692
pin_count = __pin_count[key] -= 1
693693
@assert pin_count >= 0 "Double unpin for $ptr"

lib/cudnn/CUDNN.jl

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -121,7 +121,7 @@ function log_message(sev, udata, dbg_ptr, ptr)
121121
str = unsafe_string(ptr, len) # XXX: can this yield?
122122

123123
# print asynchronously
124-
@spinlock log_lock begin
124+
Base.@lock log_lock begin
125125
push!(log_messages, (; sev, dbg, str))
126126
end
127127
ccall(:uv_async_send, Cint, (Ptr{Cvoid},), udata)
@@ -153,7 +153,7 @@ function __runtime_init__()
153153
if (isdebug(:init, CUDNN) || Base.JLOptions().debug_level >= 2) &&
154154
version() >= v"8.2" # NVIDIA bug #3256123
155155
log_cond[] = Base.AsyncCondition() do async_cond
156-
message = @lock log_lock popfirst!(log_messages)
156+
message = Base.@lock log_lock popfirst!(log_messages)
157157
_log_message(message...)
158158
end
159159

lib/cudnn/convolution.jl

Lines changed: 48 additions & 30 deletions
Original file line number | Diff line number | Diff line change
@@ -172,52 +172,70 @@ end
172172
const cudnnConvolutionFwdAlgoPerfCache = Dict{Tuple,cudnnConvolutionFwdAlgoPerf_t}()
173173
const cudnnConvolutionFwdAlgoPerfCacheLock = ReentrantLock()
174174
function cudnnConvolutionFwdAlgoPerf(xDesc, x, wDesc, w, convDesc, yDesc, y, biasDesc, activation)
175-
lock(cudnnConvolutionFwdAlgoPerfCacheLock) do
176-
get!(cudnnConvolutionFwdAlgoPerfCache, (xDesc, wDesc, convDesc, biasDesc, activation)) do
177-
requestedAlgoCount = Int(CUDNN_CONVOLUTION_FWD_ALGO_COUNT)
178-
returnedAlgoCount = Cint[0]
179-
perfResults = Array{cudnnConvolutionFwdAlgoPerf_t}(undef,requestedAlgoCount)
180-
workspaceSize() = cudnnFindConvolutionAlgorithmWorkspaceSize(x)
181-
with_workspace(workspaceSize) do workspace
182-
cudnnFindConvolutionForwardAlgorithmEx(handle(),xDesc,x,wDesc,w,convDesc,yDesc,y,requestedAlgoCount,returnedAlgoCount,perfResults,workspace,sizeof(workspace))
183-
end
184-
cudnnConvolutionAlgoPerfChoose(perfResults, returnedAlgoCount[1])
175+
key = (xDesc, wDesc, convDesc, biasDesc, activation)
176+
val = lock(cudnnConvolutionFwdAlgoPerfCacheLock) do
177+
get(cudnnConvolutionFwdAlgoPerfCache, key, nothing)
178+
end
179+
if val === nothing
180+
requestedAlgoCount = Int(CUDNN_CONVOLUTION_FWD_ALGO_COUNT)
181+
returnedAlgoCount = Cint[0]
182+
perfResults = Array{cudnnConvolutionFwdAlgoPerf_t}(undef,requestedAlgoCount)
183+
workspaceSize() = cudnnFindConvolutionAlgorithmWorkspaceSize(x)
184+
with_workspace(workspaceSize) do workspace
185+
cudnnFindConvolutionForwardAlgorithmEx(handle(),xDesc,x,wDesc,w,convDesc,yDesc,y,requestedAlgoCount,returnedAlgoCount,perfResults,workspace,sizeof(workspace))
186+
end
187+
val = cudnnConvolutionAlgoPerfChoose(perfResults, returnedAlgoCount[1])
188+
lock(cudnnConvolutionFwdAlgoPerfCacheLock) do
189+
cudnnConvolutionFwdAlgoPerfCache[key] = val
185190
end
186191
end
192+
return val
187193
end
188194

189195
const cudnnConvolutionBwdDataAlgoPerfCache = Dict{Tuple,cudnnConvolutionBwdDataAlgoPerf_t}()
190196
const cudnnConvolutionBwdDataAlgoPerfCacheLock = ReentrantLock()
191197
function cudnnConvolutionBwdDataAlgoPerf(wDesc, w, dyDesc, dy, convDesc, dxDesc, dx)
192-
lock(cudnnConvolutionBwdDataAlgoPerfCacheLock) do
193-
get!(cudnnConvolutionBwdDataAlgoPerfCache, (wDesc, dyDesc, convDesc)) do
194-
requestedAlgoCount = Int(CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT)
195-
returnedAlgoCount = Cint[0]
196-
perfResults = Array{cudnnConvolutionBwdDataAlgoPerf_t}(undef,requestedAlgoCount)
197-
workspaceSize() = cudnnFindConvolutionAlgorithmWorkspaceSize(dx)
198-
with_workspace(workspaceSize) do workspace
199-
cudnnFindConvolutionBackwardDataAlgorithmEx(handle(),wDesc,w,dyDesc,dy,convDesc,dxDesc,dx,requestedAlgoCount,returnedAlgoCount,perfResults,workspace,sizeof(workspace))
200-
end
201-
cudnnConvolutionAlgoPerfChoose(perfResults, returnedAlgoCount[1])
198+
key = (wDesc, dyDesc, convDesc)
199+
val = lock(cudnnConvolutionBwdDataAlgoPerfCacheLock) do
200+
get(cudnnConvolutionBwdDataAlgoPerfCache, key, nothing)
201+
end
202+
if val === nothing
203+
requestedAlgoCount = Int(CUDNN_CONVOLUTION_BWD_DATA_ALGO_COUNT)
204+
returnedAlgoCount = Cint[0]
205+
perfResults = Array{cudnnConvolutionBwdDataAlgoPerf_t}(undef,requestedAlgoCount)
206+
workspaceSize() = cudnnFindConvolutionAlgorithmWorkspaceSize(dx)
207+
with_workspace(workspaceSize) do workspace
208+
cudnnFindConvolutionBackwardDataAlgorithmEx(handle(),wDesc,w,dyDesc,dy,convDesc,dxDesc,dx,requestedAlgoCount,returnedAlgoCount,perfResults,workspace,sizeof(workspace))
209+
end
210+
val = cudnnConvolutionAlgoPerfChoose(perfResults, returnedAlgoCount[1])
211+
lock(cudnnConvolutionBwdDataAlgoPerfCacheLock) do
212+
cudnnConvolutionBwdDataAlgoPerfCache[key] = val
202213
end
203214
end
215+
val
204216
end
205217

206218
const cudnnConvolutionBwdFilterAlgoPerfCache = Dict{Tuple,cudnnConvolutionBwdFilterAlgoPerf_t}()
207219
const cudnnConvolutionBwdFilterAlgoPerfCacheLock = ReentrantLock()
208220
function cudnnConvolutionBwdFilterAlgoPerf(xDesc, x, dyDesc, dy, convDesc, dwDesc, dw)
209-
lock(cudnnConvolutionBwdFilterAlgoPerfCacheLock) do
210-
get!(cudnnConvolutionBwdFilterAlgoPerfCache, (xDesc, dyDesc, convDesc)) do
211-
requestedAlgoCount = Int(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT)
212-
returnedAlgoCount = Cint[0]
213-
perfResults = Array{cudnnConvolutionBwdFilterAlgoPerf_t}(undef,requestedAlgoCount)
214-
workspaceSize() = cudnnFindConvolutionAlgorithmWorkspaceSize(x)
215-
with_workspace(workspaceSize) do workspace
216-
cudnnFindConvolutionBackwardFilterAlgorithmEx(handle(),xDesc,x,dyDesc,dy,convDesc,dwDesc,dw,requestedAlgoCount,returnedAlgoCount,perfResults,workspace,sizeof(workspace))
217-
end
218-
cudnnConvolutionAlgoPerfChoose(perfResults, returnedAlgoCount[1])
221+
key = (xDesc, dyDesc, convDesc)
222+
val = lock(cudnnConvolutionBwdFilterAlgoPerfCacheLock) do
223+
get(cudnnConvolutionBwdFilterAlgoPerfCache, (xDesc, dyDesc, convDesc), nothing)
224+
end
225+
if val === nothing
226+
requestedAlgoCount = Int(CUDNN_CONVOLUTION_BWD_FILTER_ALGO_COUNT)
227+
returnedAlgoCount = Cint[0]
228+
perfResults = Array{cudnnConvolutionBwdFilterAlgoPerf_t}(undef,requestedAlgoCount)
229+
workspaceSize() = cudnnFindConvolutionAlgorithmWorkspaceSize(x)
230+
with_workspace(workspaceSize) do workspace
231+
cudnnFindConvolutionBackwardFilterAlgorithmEx(handle(),xDesc,x,dyDesc,dy,convDesc,dwDesc,dw,requestedAlgoCount,returnedAlgoCount,perfResults,workspace,sizeof(workspace))
232+
end
233+
val = cudnnConvolutionAlgoPerfChoose(perfResults, returnedAlgoCount[1])
234+
lock(cudnnConvolutionBwdFilterAlgoPerfCacheLock) do
235+
cudnnConvolutionBwdFilterAlgoPerfCache[key] = val
219236
end
220237
end
238+
val
221239
end
222240

223241

lib/cudnn/descriptors.jl

Lines changed: 15 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -30,23 +30,27 @@ macro cudnnDescriptor(x, set = Symbol("cudnnSet$(x)Descriptor"))
3030
const $cache = Dict{Tuple,$sname}() # Dict is 3x faster than IdDict!
3131
const $cache_lock = ReentrantLock()
3232
function $sname(args...)
33-
lock($cache_lock) do
34-
get!($cache, args) do
35-
ptr = $tname[C_NULL]
36-
$create(ptr)
37-
$set(ptr[1], args...)
38-
d = $sname(ptr[1])
39-
finalizer(x->$destroy(x.ptr), d)
40-
return d
33+
d = lock($cache_lock) do
34+
get($cache, args, nothing)
35+
end
36+
if d === nothing
37+
ptr = $tname[C_NULL]
38+
$create(ptr)
39+
$set(ptr[1], args...)
40+
d = $sname(ptr[1])
41+
finalizer(x->$destroy(x.ptr), d)
42+
lock($cache_lock) do
43+
$cache[args] = d
4144
end
4245
end
46+
return d
4347
end
4448
end |> esc
4549
end
4650

4751

4852
"""
49-
cudnnActivationDescriptor(mode::cudnnActivationMode_t,
53+
cudnnActivationDescriptor(mode::cudnnActivationMode_t,
5054
reluNanOpt::cudnnNanPropagation_t,
5155
coef::Cfloat)
5256
"""
@@ -116,8 +120,8 @@ cudnnConvolutionDescriptor(pad::Vector{Cint},
116120

117121
"""
118122
cudnnLRNDescriptor(lrnN::Cuint,
119-
lrnAlpha::Cdouble,
120-
lrnBeta::Cdouble,
123+
lrnAlpha::Cdouble,
124+
lrnBeta::Cdouble,
121125
lrnK::Cdouble)
122126
"""
123127
@cudnnDescriptor(LRN)

lib/utils/cache.jl

Lines changed: 11 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -17,23 +17,18 @@ end
1717
# remove a handle from the cache, or create a new one
1818
function Base.pop!(f::Function, cache::HandleCache{K,V}, key) where {K,V}
1919
function check_cache(f::Function=()->nothing)
20-
try
21-
GC.enable_finalizers(false)
22-
lock(cache.lock) do
23-
handle = if !haskey(cache.idle_handles, key) || isempty(cache.idle_handles[key])
24-
f()
25-
else
26-
pop!(cache.idle_handles[key])
27-
end
28-
29-
if handle !== nothing
30-
push!(cache.active_handles, key=>handle)
31-
end
20+
lock(cache.lock) do
21+
handle = if !haskey(cache.idle_handles, key) || isempty(cache.idle_handles[key])
22+
f()
23+
else
24+
pop!(cache.idle_handles[key])
25+
end
3226

33-
return handle
27+
if handle !== nothing
28+
push!(cache.active_handles, key=>handle)
3429
end
35-
finally
36-
GC.enable_finalizers(true)
30+
31+
return handle
3732
end
3833
end
3934

@@ -51,8 +46,7 @@ end
5146

5247
# put a handle in the cache, or destroy it if it doesn't fit
5348
function Base.push!(f::Function, cache::HandleCache{K,V}, key::K, handle::V) where {K,V}
54-
# XXX: take this lock in a normal way once we have JuliaLang/julia#35689
55-
@spinlock cache.lock begin
49+
lock(cache.lock) do
5650
delete!(cache.active_handles, key=>handle)
5751

5852
if haskey(cache.idle_handles, key)

lib/utils/threading.jl

Lines changed: 1 addition & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -1,25 +1,4 @@
1-
export @spinlock, @lock, LazyInitialized
2-
3-
const var"@lock" = Base.var"@lock"
4-
5-
# a safe way to acquire locks from finalizers, where we can't wait (which switches tasks)
6-
macro spinlock(l, ex)
7-
quote
8-
temp = $(esc(l))
9-
while !trylock(temp)
10-
ccall(:jl_cpu_pause, Cvoid, ())
11-
# Temporary solution before we have gc transition support in codegen.
12-
ccall(:jl_gc_safepoint, Cvoid, ())
13-
# we can't yield here
14-
end
15-
try
16-
$(esc(ex))
17-
finally
18-
unlock(temp)
19-
end
20-
end
21-
end
22-
1+
export LazyInitialized
232

243
"""
254
LazyInitialized{T}()

0 commit comments

Comments
 (0)