Commit efe8e4d: Add support for worker state callbacks

1 parent d5fd837 commit efe8e4d

4 files changed: +264, -14 lines

docs/src/_changelog.md (1 addition, 0 deletions)

````diff
@@ -18,6 +18,7 @@ This documents notable changes in DistributedNext.jl. The format is based on
   incompatibilities from both libraries being used simultaneously ([#10]).
 - [`other_workers()`](@ref) and [`other_procs()`](@ref) were implemented and
   exported ([#18]).
+- Implemented callback support for workers being added/removed etc ([#17]).
 
 ## [v1.0.0] - 2024-12-02
 
````

docs/src/index.md (13 additions, 0 deletions)

````diff
@@ -52,6 +52,19 @@ DistributedNext.cluster_cookie()
 DistributedNext.cluster_cookie(::Any)
 ```
 
+## Callbacks
+
+```@docs
+DistributedNext.add_worker_starting_callback
+DistributedNext.remove_worker_starting_callback
+DistributedNext.add_worker_started_callback
+DistributedNext.remove_worker_started_callback
+DistributedNext.add_worker_exiting_callback
+DistributedNext.remove_worker_exiting_callback
+DistributedNext.add_worker_exited_callback
+DistributedNext.remove_worker_exited_callback
+```
+
 ## Cluster Manager Interface
 
 This interface provides a mechanism to launch and manage Julia workers on different cluster environments.
````

src/cluster.jl (191 additions, 14 deletions)

````diff
@@ -472,20 +472,28 @@ end
 ```
 """
 function addprocs(manager::ClusterManager; kwargs...)
+    params = merge(default_addprocs_params(manager), Dict{Symbol, Any}(kwargs))
+
     init_multi()
 
     cluster_mgmt_from_master_check()
 
-    lock(worker_lock)
-    try
-        addprocs_locked(manager::ClusterManager; kwargs...)
-    finally
-        unlock(worker_lock)
-    end
+    # Call worker-starting callbacks
+    warning_interval = params[:callback_warning_interval]
+    _run_callbacks_concurrently("worker-starting", worker_starting_callbacks,
+                                warning_interval, [(manager, params)])
+
+    # Add new workers
+    new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager, params)
+
+    # Call worker-started callbacks
+    _run_callbacks_concurrently("worker-started", worker_started_callbacks,
+                                warning_interval, new_workers)
+
+    return new_workers
 end
 
-function addprocs_locked(manager::ClusterManager; kwargs...)
-    params = merge(default_addprocs_params(manager), Dict{Symbol,Any}(kwargs))
+function addprocs_locked(manager::ClusterManager, params)
     topology(Symbol(params[:topology]))
 
     if PGRP.topology !== :all_to_all
````
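A minimal sketch of the new control flow from the caller's side, assuming only what this hunk and the docstrings further down establish: worker-starting callbacks take `(manager, params)`, and `params` is the merged `default_addprocs_params` dictionary (so entries such as `:lazy` are available). The callback body here is illustrative only:

```julia
using DistributedNext

# Worker-starting callbacks receive the ClusterManager and the params Dict
# (the merge of `default_addprocs_params(manager)` and the `addprocs` kwargs).
key = DistributedNext.add_worker_starting_callback() do manager, params
    @info "Launching workers" manager lazy=params[:lazy]
end

pids = addprocs(2)  # worker-starting callbacks run first, worker-started after

DistributedNext.remove_worker_starting_callback(key)
```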
````diff
@@ -572,7 +580,8 @@ default_addprocs_params() = Dict{Symbol,Any}(
     :exeflags => ``,
     :env => [],
     :enable_threaded_blas => false,
-    :lazy => true)
+    :lazy => true,
+    :callback_warning_interval => 10)
 
 
 function setup_launched_worker(manager, wconfig, launched_q)
````
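The new `:callback_warning_interval` default of 10 seconds controls how often `addprocs` warns about still-running callbacks. Because `addprocs` now merges its kwargs over these defaults, the interval can be overridden per call, as the tests below do with `callback_warning_interval=0.05`; a sketch:

```julia
# Warn about slow worker-starting/worker-started callbacks every 30 seconds
# instead of the default 10. Any `default_addprocs_params` key can be
# passed as a keyword argument to `addprocs`.
addprocs(4; callback_warning_interval=30)
```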
````diff
@@ -870,13 +879,151 @@ const HDR_COOKIE_LEN=16
 const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
 const map_sock_wrkr = IdDict()
 const map_del_wrkr = Set{Int}()
+const worker_starting_callbacks = Dict{Any, Base.Callable}()
+const worker_started_callbacks = Dict{Any, Base.Callable}()
+const worker_exiting_callbacks = Dict{Any, Base.Callable}()
+const worker_exited_callbacks = Dict{Any, Base.Callable}()
 
 # whether process is a master or worker in a distributed setup
 myrole() = LPROCROLE[]
 function myrole!(proctype::Symbol)
     LPROCROLE[] = proctype
 end
 
+# Callbacks
+
+function _run_callbacks_concurrently(callbacks_name, callbacks_dict, warning_interval, arglist)
+    callback_tasks = Dict{Any, Task}()
+    for args in arglist
+        for (name, callback) in callbacks_dict
+            callback_tasks[name] = Threads.@spawn callback(args...)
+        end
+    end
+
+    running_callbacks = () -> ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
+    while timedwait(() -> isempty(running_callbacks()), warning_interval) === :timed_out
+        callbacks_str = join(running_callbacks(), ", ")
+        @warn "Waiting for these $(callbacks_name) callbacks to finish: $(callbacks_str)"
+    end
+
+    # Wait on the tasks so that exceptions bubble up
+    wait.(values(callback_tasks))
+end
+
+function _add_callback(f, key, dict; arg_types=Tuple{Int})
+    desired_signature = "f(" * join(["::$(t)" for t in arg_types.types], ", ") * ")"
+
+    if !hasmethod(f, arg_types)
+        throw(ArgumentError("Callback function is invalid, it must be able to be called with these argument types: $(desired_signature)"))
+    elseif haskey(dict, key)
+        throw(ArgumentError("A callback function with key '$(key)' already exists"))
+    end
+
+    if isnothing(key)
+        key = Symbol(gensym(), nameof(f))
+    end
+
+    dict[key] = f
+    return key
+end
+
+_remove_callback(key, dict) = delete!(dict, key)
+
+"""
+    add_worker_starting_callback(f::Base.Callable; key=nothing)
+
+Register a callback to be called on the master process immediately before new
+workers are started. The callback `f` will be called with the `ClusterManager`
+instance that is being used and a dictionary of parameters related to adding
+workers, i.e. `f(manager, params)`. The `params` dictionary is specific to the
+`manager` type. Note that the `LocalManager` and `SSHManager` cluster managers
+in DistributedNext are not fully documented yet, see the
+[managers.jl](https://github.com/JuliaParallel/DistributedNext.jl/blob/master/src/managers.jl)
+file for their definitions.
+
+!!! warning
+    Adding workers can fail so it is not guaranteed that the workers requested
+    will exist.
+
+The worker-starting callbacks will be executed concurrently. If one throws an
+exception it will not be caught and will bubble up through [`addprocs`](@ref).
+
+Keep in mind that the callbacks will add to the time taken to launch workers; so
+try to either keep the callbacks fast to execute, or do the actual work
+asynchronously by spawning a task in the callback (beware of race conditions if
+you do this).
+"""
+add_worker_starting_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_starting_callbacks;
+                                                                            arg_types=Tuple{ClusterManager, Dict})
+
+remove_worker_starting_callback(key) = _remove_callback(key, worker_starting_callbacks)
+
+"""
+    add_worker_started_callback(f::Base.Callable; key=nothing)
+
+Register a callback to be called on the master process whenever a worker is
+added. The callback will be called with the added worker ID,
+e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
+not specified.
+
+The worker-started callbacks will be executed concurrently. If one throws an
+exception it will not be caught and will bubble up through [`addprocs()`](@ref).
+
+Keep in mind that the callbacks will add to the time taken to launch workers; so
+try to either keep the callbacks fast to execute, or do the actual
+initialization asynchronously by spawning a task in the callback (beware of race
+conditions if you do this).
+"""
+add_worker_started_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_started_callbacks)
+
+"""
+    remove_worker_started_callback(key)
+
+Remove the callback for `key` that was added with [`add_worker_started_callback()`](@ref).
+"""
+remove_worker_started_callback(key) = _remove_callback(key, worker_started_callbacks)
+
+"""
+    add_worker_exiting_callback(f::Base.Callable; key=nothing)
+
+Register a callback to be called on the master process immediately before a
+worker is removed with [`rmprocs()`](@ref). The callback will be called with the
+worker ID, e.g. `f(w::Int)`. Chooses and returns a unique key for the callback
+if `key` is not specified.
+
+All worker-exiting callbacks will be executed concurrently and if they don't
+all finish before the `callback_timeout` passed to `rmprocs()` then the process
+will be removed anyway.
+"""
+add_worker_exiting_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_exiting_callbacks)
+
+"""
+    remove_worker_exiting_callback(key)
+
+Remove the callback for `key` that was added with [`add_worker_exiting_callback()`](@ref).
+"""
+remove_worker_exiting_callback(key) = _remove_callback(key, worker_exiting_callbacks)
+
+"""
+    add_worker_exited_callback(f::Base.Callable; key=nothing)
+
+Register a callback to be called on the master process when a worker has exited
+for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
+segfaulting etc). The callback will be called with the worker ID,
+e.g. `f(w::Int)`. Chooses and returns a unique key for the callback if `key` is
+not specified.
+
+If the callback throws an exception it will be caught and printed.
+"""
+add_worker_exited_callback(f::Base.Callable; key=nothing) = _add_callback(f, key, worker_exited_callbacks)
+
+"""
+    remove_worker_exited_callback(key)
+
+Remove the callback for `key` that was added with [`add_worker_exited_callback()`](@ref).
+"""
+remove_worker_exited_callback(key) = _remove_callback(key, worker_exited_callbacks)
+
 # cluster management related API
 """
     myid()
````
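The `_add_callback`/`_remove_callback` pair gives all four registries the same key discipline: an explicit `key` must be unused, and omitting it returns a gensym-based `Symbol` that the caller must keep in order to deregister. A sketch of both styles (the callback bodies and the `:connect_logger` key are illustrative only):

```julia
using DistributedNext

# Explicit key: readable and stable, but registering the same key twice
# throws an ArgumentError, so remove it before re-registering.
DistributedNext.add_worker_started_callback(pid -> @info("Worker $pid started");
                                            key=:connect_logger)
DistributedNext.remove_worker_started_callback(:connect_logger)

# No key: `_add_callback` generates one, so keep the return value.
auto_key = DistributedNext.add_worker_exited_callback(pid -> @warn("Worker $pid exited"))
DistributedNext.remove_worker_exited_callback(auto_key)
```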
````diff
@@ -1063,7 +1210,7 @@ function cluster_mgmt_from_master_check()
 end
 
 """
-    rmprocs(pids...; waitfor=typemax(Int))
+    rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
 
 Remove the specified workers. Note that only process 1 can add or remove
 workers.
@@ -1077,6 +1224,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
   returned. The user should call [`wait`](@ref) on the task before invoking any other
   parallel calls.
 
+The `callback_timeout` specifies how long to wait for any callbacks to execute
+before continuing to remove the workers (see
+[`add_worker_exiting_callback()`](@ref)).
+
 # Examples
 ```julia-repl
 \$ julia -p 5
@@ -1093,24 +1244,38 @@ julia> workers()
 6
 ```
 """
-function rmprocs(pids...; waitfor=typemax(Int))
+function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
     cluster_mgmt_from_master_check()
 
     pids = vcat(pids...)
     if waitfor == 0
-        t = @async _rmprocs(pids, typemax(Int))
+        t = @async _rmprocs(pids, typemax(Int), callback_timeout)
         yield()
         return t
     else
-        _rmprocs(pids, waitfor)
+        _rmprocs(pids, waitfor, callback_timeout)
         # return a dummy task object that user code can wait on.
         return @async nothing
     end
 end
 
-function _rmprocs(pids, waitfor)
+function _rmprocs(pids, waitfor, callback_timeout)
     lock(worker_lock)
     try
+        # Run the callbacks
+        callback_tasks = Dict{Any, Task}()
+        for pid in pids
+            for (name, callback) in worker_exiting_callbacks
+                callback_tasks[name] = Threads.@spawn callback(pid)
+            end
+        end
+
+        if timedwait(() -> all(istaskdone.(values(callback_tasks))), callback_timeout) === :timed_out
+            timedout_callbacks = ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
+            callbacks_str = join(timedout_callbacks, ", ")
+            @warn "Some worker-exiting callbacks have not yet finished, continuing to remove workers anyway. These are the callbacks still running: $(callbacks_str)"
+        end
+
        rmprocset = Union{LocalProcess, Worker}[]
         for p in pids
             if p == 1
@@ -1256,6 +1421,18 @@ function deregister_worker(pg, pid)
             delete!(pg.refs, id)
         end
     end
+
+    # Call callbacks on the master
+    if myid() == 1
+        for (name, callback) in worker_exited_callbacks
+            try
+                callback(pid)
+            catch ex
+                @error "Error when running worker-exited callback '$(name)'" exception=(ex, catch_backtrace())
+            end
+        end
+    end
+
     return
 end
 
````
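One asymmetry between the two removal-side hooks is worth noting: worker-exiting callbacks can delay removal by at most `callback_timeout` seconds, while worker-exited callbacks run after the fact with exceptions caught and logged. A sketch of draining per-worker state before removal (`flush_worker_buffers` is a hypothetical application-level helper, not part of this commit):

```julia
# Give cleanup up to 30 seconds before rmprocs proceeds anyway.
k = DistributedNext.add_worker_exiting_callback() do pid
    flush_worker_buffers(pid)  # hypothetical application-level cleanup
end

rmprocs(workers(); callback_timeout=30)
DistributedNext.remove_worker_exiting_callback(k)
```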

test/distributed_exec.jl (59 additions, 0 deletions)

````diff
@@ -1,6 +1,7 @@
 # This file is a part of Julia. License is MIT: https://julialang.org/license
 
 using DistributedNext, Random, Serialization, Sockets
+import DistributedNext
 import DistributedNext: launch, manage
 
 
@@ -1934,6 +1935,64 @@ include("splitrange.jl")
     end
 end
 
+@testset "Worker state callbacks" begin
+    rmprocs(other_workers())
+
+    # Adding a callback with an invalid signature should fail
+    @test_throws ArgumentError DistributedNext.add_worker_started_callback(() -> nothing)
+
+    # Smoke test to ensure that all the callbacks are executed
+    starting_managers = []
+    started_workers = Int[]
+    exiting_workers = Int[]
+    exited_workers = Int[]
+    starting_key = DistributedNext.add_worker_starting_callback((manager, kwargs) -> push!(starting_managers, manager))
+    started_key = DistributedNext.add_worker_started_callback(pid -> (push!(started_workers, pid); error("foo")))
+    exiting_key = DistributedNext.add_worker_exiting_callback(pid -> push!(exiting_workers, pid))
+    exited_key = DistributedNext.add_worker_exited_callback(pid -> push!(exited_workers, pid))
+
+    # Test that the worker-started exception bubbles up
+    @test_throws TaskFailedException addprocs(1)
+
+    pid = only(workers())
+    @test only(starting_managers) isa DistributedNext.LocalManager
+    @test started_workers == [pid]
+    rmprocs(workers())
+    @test exiting_workers == [pid]
+    @test exited_workers == [pid]
+
+    # Trying to reset an existing callback should fail
+    @test_throws ArgumentError DistributedNext.add_worker_started_callback(Returns(nothing); key=started_key)
+
+    # Remove the callbacks
+    DistributedNext.remove_worker_starting_callback(starting_key)
+    DistributedNext.remove_worker_started_callback(started_key)
+    DistributedNext.remove_worker_exiting_callback(exiting_key)
+    DistributedNext.remove_worker_exited_callback(exited_key)
+
+    # Test that the worker-exiting `callback_timeout` option works and that we
+    # get warnings about slow worker-started callbacks.
+    event = Base.Event()
+    callback_task = nothing
+    started_key = DistributedNext.add_worker_started_callback(_ -> sleep(0.5))
+    exiting_key = DistributedNext.add_worker_exiting_callback(_ -> (callback_task = current_task(); wait(event)))
+
+    @test_logs (:warn, r"Waiting for these worker-started callbacks.+") match_mode=:any addprocs(1; callback_warning_interval=0.05)
+    DistributedNext.remove_worker_started_callback(started_key)
+
+    @test_logs (:warn, r"Some worker-exiting callbacks have not yet finished.+") rmprocs(workers(); callback_timeout=0.5)
+    DistributedNext.remove_worker_exiting_callback(exiting_key)
+
+    notify(event)
+    wait(callback_task)
+
+    # Test that the initial callbacks were indeed removed
+    @test length(starting_managers) == 1
+    @test length(started_workers) == 1
+    @test length(exiting_workers) == 1
+    @test length(exited_workers) == 1
+end
+
 # Run topology tests last after removing all workers, since a given
 # cluster at any time only supports a single topology.
 if nprocs() > 1
````
