Skip to content

Commit 444aa87

Browse files
Make Distributed.Worker threadsafe (#37905)
1 parent db8e940 commit 444aa87

File tree

3 files changed

+59
-4
lines changed

3 files changed

+59
-4
lines changed

stdlib/Distributed/src/cluster.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ mutable struct Worker
9999
add_msgs::Array{Any,1}
100100
gcflag::Bool
101101
state::WorkerState
102-
c_state::Condition # wait for state changes
102+
c_state::Event # wait for state changes
103103
ct_time::Float64 # creation time
104104
conn_func::Any # used to setup connections lazily
105105

@@ -133,7 +133,7 @@ mutable struct Worker
133133
if haskey(map_pid_wrkr, id)
134134
return map_pid_wrkr[id]
135135
end
136-
w=new(id, [], [], false, W_CREATED, Condition(), time(), conn_func)
136+
w=new(id, [], [], false, W_CREATED, Event(), time(), conn_func)
137137
w.initialized = Event()
138138
register_worker(w)
139139
w
@@ -144,7 +144,7 @@ end
144144

145145
function set_worker_state(w, state)
146146
w.state = state
147-
notify(w.c_state; all=true)
147+
notify(w.c_state)
148148
end
149149

150150
function check_worker_state(w::Worker)
@@ -189,7 +189,7 @@ function wait_for_conn(w)
189189
timeout = worker_timeout() - (time() - w.ct_time)
190190
timeout <= 0 && error("peer $(w.id) has not connected to $(myid())")
191191

192-
@async (sleep(timeout); notify(w.c_state; all=true))
192+
@async (sleep(timeout); notify(w.c_state))
193193
wait(w.c_state)
194194
w.state === W_CREATED && error("peer $(w.id) didn't connect to $(myid()) within $timeout seconds")
195195
end

stdlib/Distributed/test/distributed_exec.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1711,4 +1711,5 @@ include("splitrange.jl")
17111711
# Run topology tests last after removing all workers, since a given
17121712
# cluster at any time only supports a single topology.
17131713
rmprocs(workers())
1714+
include("threads.jl")
17141715
include("topology.jl")

stdlib/Distributed/test/threads.jl

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
using Test
2+
using Distributed, Base.Threads
3+
using Base.Iterators: product
4+
5+
exeflags = ("--startup-file=no",
6+
"--check-bounds=yes",
7+
"--depwarn=error",
8+
"--threads=2")
9+
10+
function call_on(f, wid, tid)
11+
remotecall(wid) do
12+
t = Task(f)
13+
ccall(:jl_set_task_tid, Cvoid, (Any, Cint), t, tid-1)
14+
schedule(t)
15+
@assert threadid(t) == tid
16+
t
17+
end
18+
end
19+
20+
# Run function on process holding the data to only serialize the result of f.
21+
# This becomes useful for things that cannot be serialized (e.g. running tasks)
22+
# or that would be unnecessarily big if serialized.
23+
fetch_from_owner(f, rr) = remotecall_fetch(ffetch, rr.where, rr)
24+
25+
isdone(rr) = fetch_from_owner(istaskdone, rr)
26+
isfailed(rr) = fetch_from_owner(istaskfailed, rr)
27+
28+
@testset "RemoteChannel allows put!/take! from thread other than 1" begin
29+
ws = ts = product(1:2, 1:2)
30+
timeout = 10.0
31+
@testset "from worker $w1 to $w2 via 1" for (w1, w2) in ws
32+
@testset "from thread $w1.$t1 to $w2.$t2" for (t1, t2) in ts
33+
# We want (the default) lazyness, so that we wait for `Worker.c_state`!
34+
procs_added = addprocs(2; exeflags, lazy=true)
35+
@everywhere procs_added using Base.Threads
36+
p1 = procs_added[w1]
37+
p2 = procs_added[w2]
38+
chan_id = first(procs_added)
39+
chan = RemoteChannel(chan_id)
40+
send = call_on(p1, t1) do
41+
put!(chan, nothing)
42+
end
43+
recv = call_on(p2, t2) do
44+
take!(chan)
45+
end
46+
timedwait(() -> isdone(send) && isdone(recv), timeout)
47+
@test isdone(send)
48+
@test isdone(recv)
49+
@test !isfailed(send)
50+
@test !isfailed(recv)
51+
rmprocs(procs_added)
52+
end
53+
end
54+
end

0 commit comments

Comments
 (0)