Skip to content

Commit 9251554

Browse files
authored
Minor improvements to nonblocking synchronization. (#2272)
1 parent 887cc47 commit 9251554

File tree

1 file changed

+18
-7
lines changed

1 file changed

+18
-7
lines changed

lib/cudadrv/synchronization.jl

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ const SyncObject = Union{CuContext, CuStream, CuEvent}
110110
const MAX_SYNC_THREADS = 4
111111
const sync_channels = Array{BidirectionalChannel{SyncObject,CUresult}}(undef, MAX_SYNC_THREADS)
112112
const sync_channel_cursor = Threads.Atomic{UInt32}(1)
113+
const sync_channel_lock = Base.ReentrantLock()
113114

114115
function synchronization_worker(data)
115116
i = Int(data)
@@ -133,15 +134,25 @@ function synchronization_worker(data)
133134
end
134135

135136
@noinline function create_synchronization_worker(i)
136-
sync_channels[i] = BidirectionalChannel{SyncObject,CUresult}()
137-
# should be safe to assign before threads are running;
138-
# any user will just submit work that makes it block
137+
lock(sync_channel_lock) do
138+
# test and test-and-set
139+
if isassigned(sync_channels, i)
140+
return
141+
end
142+
143+
# should be safe to assign before threads are running;
144+
# any user will just submit work that makes it block
145+
sync_channels[i] = BidirectionalChannel{SyncObject,CUresult}()
139146

140-
# we don't know what the size of uv_thread_t is, so reserve enough space
141-
tid = Ref{NTuple{32, UInt8}}(ntuple(i -> 0, 32))
147+
# we don't know what the size of uv_thread_t is, so reserve enough space
148+
tid = Ref{NTuple{32, UInt8}}(ntuple(i -> 0, 32))
142149

143-
cb = @cfunction(synchronization_worker, Cvoid, (Ptr{Cvoid},))
144-
@ccall uv_thread_create(tid::Ptr{Cvoid}, cb::Ptr{Cvoid}, Ptr{Cvoid}(i)::Ptr{Cvoid})::Int32
150+
cb = @cfunction(synchronization_worker, Cvoid, (Ptr{Cvoid},))
151+
err = @ccall uv_thread_create(tid::Ptr{Cvoid}, cb::Ptr{Cvoid}, Ptr{Cvoid}(i)::Ptr{Cvoid})::Cint
152+
err == 0 || Base.uv_error("uv_thread_create", err)
153+
@ccall uv_thread_detach(tid::Ptr{Cvoid})::Cint
154+
err == 0 || Base.uv_error("uv_thread_detach", err)
155+
end
145156

146157
return
147158
end

0 commit comments

Comments
 (0)