Skip to content

Commit 6674283

Browse files
vchuravyJamesWrigley
authored andcommitted
Make worker state variable threadsafe
1 parent 5664661 commit 6674283

File tree

4 files changed

+117
-15
lines changed

4 files changed

+117
-15
lines changed

src/cluster.jl

Lines changed: 48 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,9 @@ mutable struct Worker
100100
add_msgs::Array{Any,1}
101101
@atomic gcflag::Bool
102102
state::WorkerState
103-
c_state::Condition # wait for state changes
104-
ct_time::Float64 # creation time
105-
conn_func::Any # used to setup connections lazily
103+
c_state::Threads.Condition # wait for state changes, lock for state
104+
ct_time::Float64 # creation time
105+
conn_func::Any # used to setup connections lazily
106106

107107
r_stream::IO
108108
w_stream::IO
@@ -134,7 +134,7 @@ mutable struct Worker
134134
if haskey(map_pid_wrkr, id)
135135
return map_pid_wrkr[id]
136136
end
137-
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
137+
w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
138138
w.initialized = Event()
139139
register_worker(w)
140140
w
@@ -144,12 +144,16 @@ mutable struct Worker
144144
end
145145

146146
function set_worker_state(w, state)
147-
w.state = state
148-
notify(w.c_state; all=true)
147+
lock(w.c_state) do
148+
w.state = state
149+
notify(w.c_state; all=true)
150+
end
149151
end
150152

151153
function check_worker_state(w::Worker)
154+
lock(w.c_state)
152155
if w.state === W_CREATED
156+
unlock(w.c_state)
153157
if !isclusterlazy()
154158
if PGRP.topology === :all_to_all
155159
# Since higher pids connect with lower pids, the remote worker
@@ -169,6 +173,8 @@ function check_worker_state(w::Worker)
169173
errormonitor(t)
170174
wait_for_conn(w)
171175
end
176+
else
177+
unlock(w.c_state)
172178
end
173179
end
174180

@@ -187,13 +193,25 @@ function exec_conn_func(w::Worker)
187193
end
188194

189195
function wait_for_conn(w)
196+
lock(w.c_state)
190197
if w.state === W_CREATED
198+
unlock(w.c_state)
191199
timeout = worker_timeout() - (time() - w.ct_time)
192200
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")
193201

194-
@async (sleep(timeout); notify(w.c_state; all=true))
195-
wait(w.c_state)
196-
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
202+
T = Threads.@spawn begin
203+
sleep($timeout)
204+
lock(w.c_state) do
205+
notify(w.c_state; all=true)
206+
end
207+
end
208+
errormonitor(T)
209+
lock(w.c_state) do
210+
wait(w.c_state)
211+
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
212+
end
213+
else
214+
unlock(w.c_state)
197215
end
198216
nothing
199217
end
@@ -491,7 +509,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
491509
while true
492510
if isempty(launched)
493511
istaskdone(t_launch) && break
494-
@async (sleep(1); notify(launch_ntfy))
512+
@async begin
513+
sleep(1)
514+
notify(launch_ntfy)
515+
end
495516
wait(launch_ntfy)
496517
end
497518

@@ -645,7 +666,12 @@ function create_worker(manager, wconfig)
645666
# require the value of config.connect_at which is set only upon connection completion
646667
for jw in PGRP.workers
647668
if (jw.id != 1) && (jw.id < w.id)
648-
(jw.state === W_CREATED) && wait(jw.c_state)
669+
# wait for wl to join
670+
lock(jw.c_state) do
671+
if jw.state === W_CREATED
672+
wait(jw.c_state)
673+
end
674+
end
649675
push!(join_list, jw)
650676
end
651677
end
@@ -668,7 +694,12 @@ function create_worker(manager, wconfig)
668694
end
669695

670696
for wl in wlist
671-
(wl.state === W_CREATED) && wait(wl.c_state)
697+
lock(wl.c_state) do
698+
if wl.state === W_CREATED
699+
# wait for wl to join
700+
wait(wl.c_state)
701+
end
702+
end
672703
push!(join_list, wl)
673704
end
674705
end
@@ -685,7 +716,11 @@ function create_worker(manager, wconfig)
685716
@async manage(w.manager, w.id, w.config, :register)
686717
# wait for rr_ntfy_join with timeout
687718
timedout = false
688-
@async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
719+
@async begin
720+
sleep($timeout)
721+
timedout = true
722+
put!(rr_ntfy_join, 1)
723+
end
689724
wait(rr_ntfy_join)
690725
if timedout
691726
error("worker did not connect within $timeout seconds")

src/managers.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
176176
# Wait for all launches to complete.
177177
@sync for (i, (machine, cnt)) in enumerate(manager.machines)
178178
let machine=machine, cnt=cnt
179-
@async try
179+
@async try
180180
launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
181181
catch e
182182
print(stderr, "exception launching on machine $(machine) : $(e)\n")

test/distributed_exec.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1937,5 +1937,8 @@ end
19371937

19381938
# Run topology tests last after removing all workers, since a given
19391939
# cluster at any time only supports a single topology.
1940-
nprocs() > 1 && rmprocs(workers())
1940+
if nprocs() > 1
1941+
rmprocs(workers())
1942+
end
1943+
include("threads.jl")
19411944
include("topology.jl")

test/threads.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
using Test
2+
using Distributed, Base.Threads
3+
using Base.Iterators: product
4+
5+
exeflags = ("--startup-file=no",
6+
"--check-bounds=yes",
7+
"--depwarn=error",
8+
"--threads=2")
9+
10+
function call_on(f, wid, tid)
11+
remotecall(wid) do
12+
t = Task(f)
13+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1)
14+
schedule(t)
15+
@assert threadid(t) == tid
16+
t
17+
end
18+
end
19+
20+
# Run function on process holding the data to only serialize the result of f.
21+
# This becomes useful for things that cannot be serialized (e.g. running tasks)
22+
# or that would be unnecessarily big if serialized.
23+
fetch_from_owner(f, rr) = remotecall_fetch(f fetch, rr.where, rr)
24+
25+
isdone(rr) = fetch_from_owner(istaskdone, rr)
26+
isfailed(rr) = fetch_from_owner(istaskfailed, rr)
27+
28+
@testset "RemoteChannel allows put!/take! from thread other than 1" begin
29+
ws = ts = product(1:2, 1:2)
30+
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
31+
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
32+
# We want (the default) laziness, so that we wait for `Worker.c_state`!
33+
procs_added = addprocs(2; exeflags, lazy=true)
34+
@everywhere procs_added using Base.Threads
35+
36+
p1 = procs_added[w1]
37+
p2 = procs_added[w2]
38+
chan_id = first(procs_added)
39+
chan = RemoteChannel(chan_id)
40+
send = call_on(p1, t1) do
41+
put!(chan, nothing)
42+
end
43+
recv = call_on(p2, t2) do
44+
take!(chan)
45+
end
46+
47+
# Wait on the spawned tasks on the owner. Note that we use
48+
# timedwait() instead of @sync to avoid deadlocks.
49+
t1 = Threads.@spawn fetch_from_owner(wait, recv)
50+
t2 = Threads.@spawn fetch_from_owner(wait, send)
51+
@test timedwait(() -> istaskdone(t1), 5) == :ok
52+
@test timedwait(() -> istaskdone(t2), 5) == :ok
53+
54+
# Check the tasks
55+
@test isdone(send)
56+
@test isdone(recv)
57+
58+
@test !isfailed(send)
59+
@test !isfailed(recv)
60+
61+
rmprocs(procs_added)
62+
end
63+
end
64+
end

0 commit comments

Comments
 (0)