Commit 853974a

vchuravy authored and JamesWrigley committed

Make worker state variable threadsafe

1 parent 15f6afb · commit 853974a

7 files changed: +122 additions, −27 deletions
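The core of the change: `Worker.state` was previously a plain field paired with a plain (task-local) `Condition`, which is unsafe once the field is read and written from multiple threads. The commit makes the field `@atomic` and switches to a `Threads.Condition`, whose lock also serializes state transitions against waiters. A minimal sketch of the same pattern, using illustrative names (`ConnState`, `Peer`) rather than anything from the patch:

    using Base.Threads

    @enum ConnState CREATED CONNECTED TERMINATED

    mutable struct Peer
        @atomic state::ConnState
        c_state::Threads.Condition  # its lock also guards state transitions
        Peer() = new(CREATED, Threads.Condition())
    end

    # Writer: update the field and notify waiters while holding the lock.
    function set_state!(p::Peer, s::ConnState)
        lock(p.c_state) do
            @atomic p.state = s
            notify(p.c_state; all=true)
        end
    end

    # Any thread may take a lock-free snapshot of the state...
    is_connected(p::Peer) = (@atomic p.state) === CONNECTED

    # ...but must re-check it under the lock before blocking, so a
    # notification cannot fire between the check and the wait.
    function wait_connected(p::Peer)
        lock(p.c_state) do
            while (@atomic p.state) !== CONNECTED
                wait(p.c_state)
            end
        end
    end

The check-under-lock idiom is exactly what the `create_worker` and `wait_for_conn` hunks below adopt; the old code tested `w.state` and then called `wait(w.c_state)` without holding any lock, so a notification could be lost in between.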

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion

@@ -54,7 +54,7 @@ jobs:
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
         env:
-          JULIA_DISTRIBUTED_TESTING_STANDALONE: 1
+          JULIA_NUM_THREADS: 4
       - uses: julia-actions/julia-processcoverage@v1
       - uses: codecov/codecov-action@v4
         with:
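With `JULIA_NUM_THREADS: 4` exported for the test step, CI runs the suite with multiple threads and can actually hit the new locking paths. A local approximation (a sketch: `julia_args` is a standard `Pkg.test` keyword, and the package name assumes a DistributedNext checkout):

    using Pkg
    # Start the test subprocess with 4 threads, mirroring the
    # JULIA_NUM_THREADS=4 that CI exports for the julia-runtest step.
    Pkg.test("DistributedNext"; julia_args=`--threads=4`)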

src/cluster.jl

Lines changed: 50 additions & 22 deletions

@@ -99,10 +99,10 @@ mutable struct Worker
     del_msgs::Array{Any,1} # XXX: Could del_msgs and add_msgs be Channels?
     add_msgs::Array{Any,1}
     @atomic gcflag::Bool
-    state::WorkerState
-    c_state::Condition # wait for state changes
-    ct_time::Float64 # creation time
-    conn_func::Any # used to setup connections lazily
+    @atomic state::WorkerState
+    c_state::Threads.Condition # wait for state changes, lock for state
+    ct_time::Float64 # creation time
+    conn_func::Any # used to setup connections lazily
 
     r_stream::IO
     w_stream::IO

@@ -134,7 +134,7 @@ mutable struct Worker
         if haskey(map_pid_wrkr, id)
             return map_pid_wrkr[id]
         end
-        w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
+        w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
         w.initialized = Event()
         register_worker(w)
         w

@@ -144,12 +144,14 @@ mutable struct Worker
 end
 
 function set_worker_state(w, state)
-    w.state = state
-    notify(w.c_state; all=true)
+    lock(w.c_state) do
+        @atomic w.state = state
+        notify(w.c_state; all=true)
+    end
 end
 
 function check_worker_state(w::Worker)
-    if w.state === W_CREATED
+    if (@atomic w.state) === W_CREATED
         if !isclusterlazy()
             if PGRP.topology === :all_to_all
                 # Since higher pids connect with lower pids, the remote worker

@@ -170,6 +172,7 @@ function check_worker_state(w::Worker)
             wait_for_conn(w)
         end
     end
+    return nothing
 end
 
 exec_conn_func(id::Int) = exec_conn_func(worker_from_id(id)::Worker)

@@ -187,13 +190,21 @@ function exec_conn_func(w::Worker)
 end
 
 function wait_for_conn(w)
-    if w.state === W_CREATED
+    if (@atomic w.state) === W_CREATED
         timeout = worker_timeout() - (time() - w.ct_time)
         timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")
 
-        @async (sleep(timeout); notify(w.c_state; all=true))
-        wait(w.c_state)
-        w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
+        T = Threads.@spawn begin
+            sleep($timeout)
+            lock(w.c_state) do
+                notify(w.c_state; all=true)
+            end
+        end
+        errormonitor(T)
+        lock(w.c_state) do
+            wait(w.c_state)
+            (@atomic w.state) === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
+        end
     end
     nothing
 end

@@ -491,7 +502,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
     while true
         if isempty(launched)
             istaskdone(t_launch) && break
-            @async (sleep(1); notify(launch_ntfy))
+            @async begin
+                sleep(1)
+                notify(launch_ntfy)
+            end
             wait(launch_ntfy)
         end

@@ -645,7 +659,12 @@ function create_worker(manager, wconfig)
     # require the value of config.connect_at which is set only upon connection completion
     for jw in PGRP.workers
         if (jw.id != 1) && (jw.id < w.id)
-            (jw.state === W_CREATED) && wait(jw.c_state)
+            lock(jw.c_state) do
+                # wait for jw to join
+                if (@atomic jw.state) === W_CREATED
+                    wait(jw.c_state)
+                end
+            end
             push!(join_list, jw)
         end
     end

@@ -668,7 +687,12 @@ function create_worker(manager, wconfig)
         end
 
         for wl in wlist
-            (wl.state === W_CREATED) && wait(wl.c_state)
+            lock(wl.c_state) do
+                if (@atomic wl.state) === W_CREATED
+                    # wait for wl to join
+                    wait(wl.c_state)
+                end
+            end
             push!(join_list, wl)
         end
     end

@@ -685,7 +709,11 @@ function create_worker(manager, wconfig)
     @async manage(w.manager, w.id, w.config, :register)
     # wait for rr_ntfy_join with timeout
     timedout = false
-    @async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
+    @async begin
+        sleep($timeout)
+        timedout = true
+        put!(rr_ntfy_join, 1)
+    end
     wait(rr_ntfy_join)
     if timedout
         error("worker did not connect within $timeout seconds")

@@ -870,7 +898,7 @@ function nprocs()
         n = length(PGRP.workers)
         # filter out workers in the process of being setup/shutdown.
         for jw in PGRP.workers
-            if !isa(jw, LocalProcess) && (jw.state !== W_CONNECTED)
+            if !isa(jw, LocalProcess) && ((@atomic jw.state) !== W_CONNECTED)
                 n = n - 1
             end
         end

@@ -921,7 +949,7 @@ julia> procs()
 function procs()
     if myid() == 1 || (PGRP.topology === :all_to_all && !isclusterlazy())
         # filter out workers in the process of being setup/shutdown.
-        return Int[x.id for x in PGRP.workers if isa(x, LocalProcess) || (x.state === W_CONNECTED)]
+        return Int[x.id for x in PGRP.workers if isa(x, LocalProcess) || ((@atomic x.state) === W_CONNECTED)]
     else
         return Int[x.id for x in PGRP.workers]
     end

@@ -930,7 +958,7 @@ end
 function id_in_procs(id)  # faster version of `id in procs()`
     if myid() == 1 || (PGRP.topology === :all_to_all && !isclusterlazy())
         for x in PGRP.workers
-            if (x.id::Int) == id && (isa(x, LocalProcess) || (x::Worker).state === W_CONNECTED)
+            if (x.id::Int) == id && (isa(x, LocalProcess) || (@atomic (x::Worker).state) === W_CONNECTED)
                 return true
             end
         end

@@ -952,7 +980,7 @@ Specifically all workers bound to the same ip-address as `pid` are returned.
 """
 function procs(pid::Integer)
     if myid() == 1
-        all_workers = [x for x in PGRP.workers if isa(x, LocalProcess) || (x.state === W_CONNECTED)]
+        all_workers = [x for x in PGRP.workers if isa(x, LocalProcess) || ((@atomic x.state) === W_CONNECTED)]
         if (pid == 1) || (isa(map_pid_wrkr[pid].manager, LocalManager))
             Int[x.id for x in filter(w -> (w.id==1) || (isa(w.manager, LocalManager)), all_workers)]
         else

@@ -1059,11 +1087,11 @@ function _rmprocs(pids, waitfor)
 
     start = time_ns()
     while (time_ns() - start) < waitfor*1e9
-        all(w -> w.state === W_TERMINATED, rmprocset) && break
+        all(w -> (@atomic w.state) === W_TERMINATED, rmprocset) && break
        sleep(min(0.1, waitfor - (time_ns() - start)/1e9))
     end
 
-    unremoved = [wrkr.id for wrkr in filter(w -> w.state !== W_TERMINATED, rmprocset)]
+    unremoved = [wrkr.id for wrkr in filter(w -> (@atomic w.state) !== W_TERMINATED, rmprocset)]
     if length(unremoved) > 0
         estr = string("rmprocs: pids ", unremoved, " not terminated after ", waitfor, " seconds.")
         throw(ErrorException(estr))
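The rewritten `wait_for_conn` combines three ingredients: a timer task spawned with `Threads.@spawn` (note the `$timeout` interpolation, which captures the value at spawn time), `errormonitor` so a failure inside the timer surfaces instead of being silently dropped, and a locked `wait` that re-checks the atomic state afterwards. A self-contained sketch of the same shape, with illustrative names (`wait_ready`, `ready`) not taken from the patch:

    using Base.Threads

    # Block until `cond` is notified, then fail if `ready` was never set.
    # Same timer-plus-locked-wait shape as the new wait_for_conn.
    function wait_ready(ready::Threads.Atomic{Bool}, cond::Threads.Condition, timeout)
        timer = Threads.@spawn begin
            sleep($timeout)             # $ captures timeout at spawn time
            lock(cond) do
                notify(cond; all=true)  # wake the waiter even on timeout
            end
        end
        errormonitor(timer)             # surface errors from the timer task
        lock(cond) do
            wait(cond)
            ready[] || error("not ready after $timeout seconds")
        end
    end

    # Usage: a producer flips the flag and notifies under the same lock.
    ready = Threads.Atomic{Bool}(false)
    cond = Threads.Condition()
    producer = Threads.@spawn begin
        sleep(0.1)
        ready[] = true
        lock(cond) do
            notify(cond; all=true)
        end
    end
    errormonitor(producer)
    wait_ready(ready, cond, 5.0)        # returns once the producer notifies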

src/managers.jl

Lines changed: 1 addition & 1 deletion

@@ -183,7 +183,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
     # Wait for all launches to complete.
     @sync for (i, (machine, cnt)) in enumerate(manager.machines)
         let machine=machine, cnt=cnt
-            @async try
+            @async try
                 launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
             catch e
                 print(stderr, "exception launching on machine $(machine) : $(e)\n")

src/messages.jl

Lines changed: 1 addition & 1 deletion

@@ -194,7 +194,7 @@ end
 function flush_gc_msgs()
     try
         for w in (PGRP::ProcessGroup).workers
-            if isa(w,Worker) && (w.state == W_CONNECTED) && w.gcflag
+            if isa(w,Worker) && ((@atomic w.state) == W_CONNECTED) && w.gcflag
                 flush_gc_msgs(w)
             end
         end

src/process_messages.jl

Lines changed: 1 addition & 1 deletion

@@ -222,7 +222,7 @@ function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
                 println(stderr, "Process($(myid())) - Unknown remote, closing connection.")
             elseif !(wpid in map_del_wrkr)
                 werr = worker_from_id(wpid)
-                oldstate = werr.state
+                oldstate = @atomic werr.state
                 set_worker_state(werr, W_TERMINATED)
 
                 # If unhandleable error occurred talking to pid 1, exit

test/distributed_exec.jl

Lines changed: 4 additions & 1 deletion

@@ -1991,5 +1991,8 @@ end
 
 # Run topology tests last after removing all workers, since a given
 # cluster at any time only supports a single topology.
-nprocs() > 1 && rmprocs(workers())
+if nprocs() > 1
+    rmprocs(workers())
+end
+include("threads.jl")
 include("topology.jl")

test/threads.jl

Lines changed: 64 additions & 0 deletions (new file)

@@ -0,0 +1,64 @@
+using Test
+using DistributedNext, Base.Threads
+using Base.Iterators: product
+
+exeflags = ("--startup-file=no",
+            "--check-bounds=yes",
+            "--depwarn=error",
+            "--threads=2")
+
+function call_on(f, wid, tid)
+    remotecall(wid) do
+        t = Task(f)
+        ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid - 1)
+        schedule(t)
+        @assert threadid(t) == tid
+        t
+    end
+end
+
+# Run function on process holding the data to only serialize the result of f.
+# This becomes useful for things that cannot be serialized (e.g. running tasks)
+# or that would be unnecessarily big if serialized.
+fetch_from_owner(f, rr) = remotecall_fetch(f ∘ fetch, rr.where, rr)
+
+isdone(rr) = fetch_from_owner(istaskdone, rr)
+isfailed(rr) = fetch_from_owner(istaskfailed, rr)
+
+@testset "RemoteChannel allows put!/take! from thread other than 1" begin
+    ws = ts = product(1:2, 1:2)
+    @testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
+        @testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
+            # We want (the default) laziness, so that we wait for `Worker.c_state`!
+            procs_added = addprocs(2; exeflags, lazy=true)
+            @everywhere procs_added using Base.Threads
+
+            p1 = procs_added[w1]
+            p2 = procs_added[w2]
+            chan_id = first(procs_added)
+            chan = RemoteChannel(chan_id)
+            send = call_on(p1, t1) do
+                put!(chan, nothing)
+            end
+            recv = call_on(p2, t2) do
+                take!(chan)
+            end
+
+            # Wait on the spawned tasks on the owner. Note that we use
+            # timedwait() instead of @sync to avoid deadlocks.
+            t1 = Threads.@spawn fetch_from_owner(wait, recv)
+            t2 = Threads.@spawn fetch_from_owner(wait, send)
+            @test timedwait(() -> istaskdone(t1), 5) == :ok
+            @test timedwait(() -> istaskdone(t2), 5) == :ok
+
+            # Check the tasks
+            @test isdone(send)
+            @test isdone(recv)
+
+            @test !isfailed(send)
+            @test !isfailed(recv)
+
+            rmprocs(procs_added)
+        end
+    end
+end
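The `call_on` helper above pins each task to a specific thread through the internal, unexported `jl_set_task_tid` runtime call (thread ids are 0-based at that level), which is how the testset drives `put!`/`take!` from threads other than 1. A minimal standalone sketch of that pinning trick (assumes a session started with at least `--threads=2`; being an internal API, `jl_set_task_tid` carries no stability guarantees):

    using Base.Threads

    t = Task(() -> threadid())                        # report the thread it ran on
    ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, 1) # pin to thread 2 (0-based)
    schedule(t)
    @assert fetch(t) == 2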
