Commit b9a8000

Merge pull request #4 from JuliaParallel/jps/threadsafe_workerstate
Such threadsafe, much wow
2 parents: 76df474 + c1a3be8

14 files changed: +1907 -1767 lines


.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ jobs:
       - uses: julia-actions/julia-buildpkg@v1
       - uses: julia-actions/julia-runtest@v1
         env:
-          JULIA_DISTRIBUTED_TESTING_STANDALONE: 1
+          JULIA_NUM_THREADS: 4
       - uses: julia-actions/julia-processcoverage@v1
       - uses: codecov/codecov-action@v4
         with:

.gitignore

Lines changed: 2 additions & 0 deletions
@@ -1 +1,3 @@
 docs/src/changelog.md
+Manifest.toml
+*.swp

docs/src/_changelog.md

Lines changed: 2 additions & 0 deletions
@@ -12,6 +12,8 @@ This documents notable changes in DistributedNext.jl. The format is based on
 ### Fixed
 - Fixed behaviour of `isempty(::RemoteChannel)`, which previously had the
   side-effect of taking an element from the channel ([#3]).
+- Improved thread-safety, such that it should be safe to start workers with
+  multiple threads and send messages between them ([#4]).

 ### Changed
 - Added a `project` argument to [`addprocs(::AbstractVector)`](@ref) to specify
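
The entry above is the user-visible change. A minimal usage sketch of what it
permits, assuming DistributedNext keeps the stdlib Distributed API surface
(`addprocs`, `workers`, `remotecall_fetch`); the worker count and thread count
here are illustrative:

    using DistributedNext

    # Start two workers that each run four threads.
    addprocs(2; exeflags="--threads=4")

    # Message the workers concurrently from several tasks on the master process.
    results = map(fetch, [Threads.@spawn remotecall_fetch(() -> Threads.nthreads(), w)
                          for w in workers()])
    @show results   # expected: [4, 4]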

src/cluster.jl

Lines changed: 74 additions & 43 deletions
@@ -99,10 +99,10 @@ mutable struct Worker
     del_msgs::Array{Any,1} # XXX: Could del_msgs and add_msgs be Channels?
     add_msgs::Array{Any,1}
     @atomic gcflag::Bool
-    state::WorkerState
-    c_state::Condition # wait for state changes
-    ct_time::Float64 # creation time
-    conn_func::Any # used to setup connections lazily
+    @atomic state::WorkerState
+    c_state::Threads.Condition # wait for state changes, lock for state
+    ct_time::Float64 # creation time
+    conn_func::Any # used to setup connections lazily

     r_stream::IO
     w_stream::IO

@@ -134,7 +134,7 @@ mutable struct Worker
         if haskey(map_pid_wrkr, id)
             return map_pid_wrkr[id]
         end
-        w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Condition(), time(), conn_func)
+        w=new(id, Threads.ReentrantLock(), [], [], false, W_CREATED, Threads.Condition(), time(), conn_func)
         w.initialized = Event()
         register_worker(w)
         w

@@ -144,12 +144,14 @@ mutable struct Worker
 end

 function set_worker_state(w, state)
-    w.state = state
-    notify(w.c_state; all=true)
+    lock(w.c_state) do
+        @atomic w.state = state
+        notify(w.c_state; all=true)
+    end
 end

 function check_worker_state(w::Worker)
-    if w.state === W_CREATED
+    if (@atomic w.state) === W_CREATED
         if !isclusterlazy()
             if PGRP.topology === :all_to_all
                 # Since higher pids connect with lower pids, the remote worker
@@ -170,6 +172,7 @@ function check_worker_state(w::Worker)
             wait_for_conn(w)
         end
     end
+    return nothing
 end

 exec_conn_func(id::Int) = exec_conn_func(worker_from_id(id)::Worker)
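
The two hunks above establish the locking discipline used throughout this
commit: `state` becomes an `@atomic` field so it can be read without a lock,
and every write happens while holding the lock of the `Threads.Condition`, so
a concurrent waiter cannot miss the `notify`. A stand-alone sketch of the
pattern (the `Cell` type and its symbols are illustrative, not package API):

    mutable struct Cell
        @atomic state::Symbol
        cond::Threads.Condition
        Cell() = new(:created, Threads.Condition())
    end

    function set_state!(c::Cell, s::Symbol)
        lock(c.cond) do               # hold the condition's lock while mutating...
            @atomic c.state = s
            notify(c.cond; all=true)  # ...so no waiter can miss the transition
        end
    end

    function wait_connected(c::Cell)
        lock(c.cond) do
            while (@atomic c.state) === :created
                wait(c.cond)          # releases the lock while sleeping
            end
        end
    end
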
@@ -187,13 +190,21 @@ function exec_conn_func(w::Worker)
 end

 function wait_for_conn(w)
-    if w.state === W_CREATED
+    if (@atomic w.state) === W_CREATED
         timeout = worker_timeout() - (time() - w.ct_time)
         timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")

-        @async (sleep(timeout); notify(w.c_state; all=true))
-        wait(w.c_state)
-        w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
+        T = Threads.@spawn begin
+            sleep($timeout)
+            lock(w.c_state) do
+                notify(w.c_state; all=true)
+            end
+        end
+        errormonitor(T)
+        lock(w.c_state) do
+            wait(w.c_state)
+            (@atomic w.state) === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
+        end
     end
     nothing
 end
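
Two idioms recur in this hunk and those below: the wake-up task takes the
condition's lock before calling `notify`, and `errormonitor` wraps spawned
tasks so a failure inside them is logged rather than silently dropped. In
isolation (the names `cond`, `timeout`, and `waker` are illustrative):

    cond = Threads.Condition()
    timeout = 5.0

    # A task that wakes any waiter once the timeout elapses.
    waker = Threads.@spawn begin
        sleep(timeout)
        lock(cond) do
            notify(cond; all=true)
        end
    end
    errormonitor(waker)   # log the exception if the waker task fails

    lock(cond) do
        wait(cond)        # woken by the waker (or by any earlier notify)
    end
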
@@ -491,7 +502,10 @@ function addprocs_locked(manager::ClusterManager; kwargs...)
     while true
         if isempty(launched)
             istaskdone(t_launch) && break
-            @async (sleep(1); notify(launch_ntfy))
+            @async begin
+                sleep(1)
+                notify(launch_ntfy)
+            end
             wait(launch_ntfy)
         end

@@ -645,7 +659,12 @@ function create_worker(manager, wconfig)
         # require the value of config.connect_at which is set only upon connection completion
         for jw in PGRP.workers
             if (jw.id != 1) && (jw.id < w.id)
-                (jw.state === W_CREATED) && wait(jw.c_state)
+                lock(jw.c_state) do
+                    # wait for jw to join
+                    if (@atomic jw.state) === W_CREATED
+                        wait(jw.c_state)
+                    end
+                end
                 push!(join_list, jw)
             end
         end

@@ -668,7 +687,12 @@ function create_worker(manager, wconfig)
             end

             for wl in wlist
-                (wl.state === W_CREATED) && wait(wl.c_state)
+                lock(wl.c_state) do
+                    if (@atomic wl.state) === W_CREATED
+                        # wait for wl to join
+                        wait(wl.c_state)
+                    end
+                end
                 push!(join_list, wl)
             end
         end

@@ -682,10 +706,16 @@ function create_worker(manager, wconfig)
     join_message = JoinPGRPMsg(w.id, all_locs, PGRP.topology, enable_threaded_blas, isclusterlazy())
     send_msg_now(w, MsgHeader(RRID(0,0), ntfy_oid), join_message)

-    @async manage(w.manager, w.id, w.config, :register)
+    errormonitor(@async manage(w.manager, w.id, w.config, :register))
     # wait for rr_ntfy_join with timeout
     timedout = false
-    @async (sleep($timeout); timedout = true; put!(rr_ntfy_join, 1))
+    errormonitor(
+        @async begin
+            sleep($timeout)
+            timedout = true
+            put!(rr_ntfy_join, 1)
+        end
+    )
     wait(rr_ntfy_join)
     if timedout
         error("worker did not connect within $timeout seconds")

@@ -735,17 +765,20 @@ function check_master_connect()
     if ccall(:jl_running_on_valgrind,Cint,()) != 0
         return
     end
-    @async begin
-        start = time_ns()
-        while !haskey(map_pid_wrkr, 1) && (time_ns() - start) < timeout
-            sleep(1.0)
-        end

-        if !haskey(map_pid_wrkr, 1)
-            print(stderr, "Master process (id 1) could not connect within $(timeout/1e9) seconds.\nexiting.\n")
-            exit(1)
+    errormonitor(
+        @async begin
+            start = time_ns()
+            while !haskey(map_pid_wrkr, 1) && (time_ns() - start) < timeout
+                sleep(1.0)
+            end
+
+            if !haskey(map_pid_wrkr, 1)
+                print(stderr, "Master process (id 1) could not connect within $(timeout/1e9) seconds.\nexiting.\n")
+                exit(1)
+            end
         end
-    end
+    )
 end

@@ -870,7 +903,7 @@ function nprocs()
     n = length(PGRP.workers)
     # filter out workers in the process of being setup/shutdown.
     for jw in PGRP.workers
-        if !isa(jw, LocalProcess) && (jw.state !== W_CONNECTED)
+        if !isa(jw, LocalProcess) && ((@atomic jw.state) !== W_CONNECTED)
             n = n - 1
         end
     end

@@ -921,7 +954,7 @@ julia> procs()
 function procs()
     if myid() == 1 || (PGRP.topology === :all_to_all && !isclusterlazy())
         # filter out workers in the process of being setup/shutdown.
-        return Int[x.id for x in PGRP.workers if isa(x, LocalProcess) || (x.state === W_CONNECTED)]
+        return Int[x.id for x in PGRP.workers if isa(x, LocalProcess) || ((@atomic x.state) === W_CONNECTED)]
     else
         return Int[x.id for x in PGRP.workers]
     end

@@ -930,7 +963,7 @@ end
 function id_in_procs(id)  # faster version of `id in procs()`
     if myid() == 1 || (PGRP.topology === :all_to_all && !isclusterlazy())
         for x in PGRP.workers
-            if (x.id::Int) == id && (isa(x, LocalProcess) || (x::Worker).state === W_CONNECTED)
+            if (x.id::Int) == id && (isa(x, LocalProcess) || (@atomic (x::Worker).state) === W_CONNECTED)
                 return true
             end
         end

@@ -952,7 +985,7 @@ Specifically all workers bound to the same ip-address as `pid` are returned.
 """
 function procs(pid::Integer)
     if myid() == 1
-        all_workers = [x for x in PGRP.workers if isa(x, LocalProcess) || (x.state === W_CONNECTED)]
+        all_workers = [x for x in PGRP.workers if isa(x, LocalProcess) || ((@atomic x.state) === W_CONNECTED)]
         if (pid == 1) || (isa(map_pid_wrkr[pid].manager, LocalManager))
             Int[x.id for x in filter(w -> (w.id==1) || (isa(w.manager, LocalManager)), all_workers)]
         else

@@ -1059,11 +1092,11 @@ function _rmprocs(pids, waitfor)

     start = time_ns()
     while (time_ns() - start) < waitfor*1e9
-        all(w -> w.state === W_TERMINATED, rmprocset) && break
+        all(w -> (@atomic w.state) === W_TERMINATED, rmprocset) && break
         sleep(min(0.1, waitfor - (time_ns() - start)/1e9))
     end

-    unremoved = [wrkr.id for wrkr in filter(w -> w.state !== W_TERMINATED, rmprocset)]
+    unremoved = [wrkr.id for wrkr in filter(w -> (@atomic w.state) !== W_TERMINATED, rmprocset)]
     if length(unremoved) > 0
         estr = string("rmprocs: pids ", unremoved, " not terminated after ", waitfor, " seconds.")
         throw(ErrorException(estr))

@@ -1290,18 +1323,16 @@ end

 using Random: randstring

-let inited = false
-    # do initialization that's only needed when there is more than 1 processor
-    global function init_multi()
-        if !inited
-            inited = true
-            push!(Base.package_callbacks, _require_callback)
-            atexit(terminate_all_workers)
-            init_bind_addr()
-            cluster_cookie(randstring(HDR_COOKIE_LEN))
-        end
-        return nothing
+# do initialization that's only needed when there is more than 1 processor
+const inited = Threads.Atomic{Bool}(false)
+function init_multi()
+    if !Threads.atomic_cas!(inited, false, true)
+        push!(Base.package_callbacks, _require_callback)
+        atexit(terminate_all_workers)
+        init_bind_addr()
+        cluster_cookie(randstring(HDR_COOKIE_LEN))
     end
+    return nothing
 end

 function init_parallel()
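
`Threads.atomic_cas!` returns the value the atomic held before the swap, so
exactly one caller observes `false` and performs the one-time setup; every
later or racing caller sees `true` and skips it. A minimal sketch of the
idiom (`once` and `init_once` are illustrative names):

    const once = Threads.Atomic{Bool}(false)

    function init_once()
        # atomic_cas!(x, cmp, newval) returns the *old* value; only the first
        # caller to flip false -> true wins the race and runs the setup.
        if !Threads.atomic_cas!(once, false, true)
            @info "one-time initialization runs here"
        end
        return nothing
    end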

src/managers.jl

Lines changed: 1 addition & 1 deletion
@@ -183,7 +183,7 @@ function launch(manager::SSHManager, params::Dict, launched::Array, launch_ntfy:
     # Wait for all launches to complete.
     @sync for (i, (machine, cnt)) in enumerate(manager.machines)
         let machine=machine, cnt=cnt
-            @async try
+            @async try
                 launch_on_machine(manager, $machine, $cnt, params, launched, launch_ntfy)
             catch e
                 print(stderr, "exception launching on machine $(machine) : $(e)\n")

src/messages.jl

Lines changed: 1 addition & 1 deletion
@@ -194,7 +194,7 @@ end
 function flush_gc_msgs()
     try
         for w in (PGRP::ProcessGroup).workers
-            if isa(w,Worker) && (w.state == W_CONNECTED) && w.gcflag
+            if isa(w,Worker) && ((@atomic w.state) == W_CONNECTED) && w.gcflag
                 flush_gc_msgs(w)
             end
         end

src/process_messages.jl

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ function message_handler_loop(r_stream::IO, w_stream::IO, incoming::Bool)
             println(stderr, "Process($(myid())) - Unknown remote, closing connection.")
         elseif !(wpid in map_del_wrkr)
             werr = worker_from_id(wpid)
-            oldstate = werr.state
+            oldstate = @atomic werr.state
             set_worker_state(werr, W_TERMINATED)

             # If unhandleable error occurred talking to pid 1, exit
