Skip to content

Commit 764ceec

Browse files
committed
Add support for worker state callbacks
1 parent d5fd837 commit 764ceec

File tree

4 files changed

+284
-14
lines changed

4 files changed

+284
-14
lines changed

docs/src/_changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ This documents notable changes in DistributedNext.jl. The format is based on
1818
incompatibilities from both libraries being used simultaneously ([#10]).
1919
- [`other_workers()`](@ref) and [`other_procs()`](@ref) were implemented and
2020
exported ([#18]).
21+
- Implemented support for callbacks that run when workers are starting, started, exiting, or have exited ([#17]).
2122

2223
## [v1.0.0] - 2024-12-02
2324

docs/src/index.md

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,19 @@ DistributedNext.cluster_cookie()
5252
DistributedNext.cluster_cookie(::Any)
5353
```
5454

55+
## Callbacks
56+
57+
```@docs
58+
DistributedNext.add_worker_starting_callback
59+
DistributedNext.remove_worker_starting_callback
60+
DistributedNext.add_worker_started_callback
61+
DistributedNext.remove_worker_started_callback
62+
DistributedNext.add_worker_exiting_callback
63+
DistributedNext.remove_worker_exiting_callback
64+
DistributedNext.add_worker_exited_callback
65+
DistributedNext.remove_worker_exited_callback
66+
```
67+
5568
## Cluster Manager Interface
5669

5770
This interface provides a mechanism to launch and manage Julia workers on different cluster environments.

src/cluster.jl

Lines changed: 200 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -472,20 +472,28 @@ end
472472
```
473473
"""
474474
function addprocs(manager::ClusterManager; kwargs...)
475+
params = merge(default_addprocs_params(manager), Dict{Symbol, Any}(kwargs))
476+
475477
init_multi()
476478

477479
cluster_mgmt_from_master_check()
478480

479-
lock(worker_lock)
480-
try
481-
addprocs_locked(manager::ClusterManager; kwargs...)
482-
finally
483-
unlock(worker_lock)
484-
end
481+
# Call worker-starting callbacks
482+
warning_interval = params[:callback_warning_interval]
483+
_run_callbacks_concurrently("worker-starting", worker_starting_callbacks,
484+
warning_interval, [(manager, params)])
485+
486+
# Add new workers
487+
new_workers = @lock worker_lock addprocs_locked(manager::ClusterManager, params)
488+
489+
# Call worker-started callbacks
490+
_run_callbacks_concurrently("worker-started", worker_started_callbacks,
491+
warning_interval, new_workers)
492+
493+
return new_workers
485494
end
486495

487-
function addprocs_locked(manager::ClusterManager; kwargs...)
488-
params = merge(default_addprocs_params(manager), Dict{Symbol,Any}(kwargs))
496+
function addprocs_locked(manager::ClusterManager, params)
489497
topology(Symbol(params[:topology]))
490498

491499
if PGRP.topology !== :all_to_all
@@ -572,7 +580,8 @@ default_addprocs_params() = Dict{Symbol,Any}(
572580
:exeflags => ``,
573581
:env => [],
574582
:enable_threaded_blas => false,
575-
:lazy => true)
583+
:lazy => true,
584+
:callback_warning_interval => 10)
576585

577586

578587
function setup_launched_worker(manager, wconfig, launched_q)
@@ -870,13 +879,160 @@ const HDR_COOKIE_LEN=16
870879
const map_pid_wrkr = Dict{Int, Union{Worker, LocalProcess}}()
871880
const map_sock_wrkr = IdDict()
872881
const map_del_wrkr = Set{Int}()
882+
# Registries of the user callbacks for each worker lifecycle event. Each maps
# the key returned by the matching `add_worker_*_callback()` function (either
# supplied by the caller or generated) to the registered callable.
const worker_starting_callbacks = Dict{Any, Base.Callable}()
883+
const worker_started_callbacks = Dict{Any, Base.Callable}()
884+
const worker_exiting_callbacks = Dict{Any, Base.Callable}()
885+
const worker_exited_callbacks = Dict{Any, Base.Callable}()
873886

874887
# whether process is a master or worker in a distributed setup
875888
# Return the role currently stored in `LPROCROLE` for this process.
function myrole()
    return LPROCROLE[]
end
876889
# Store `proctype` in `LPROCROLE` as the role of this process; evaluates to
# `proctype`.
myrole!(proctype::Symbol) = (LPROCROLE[] = proctype)
879892

893+
# Callbacks
894+
895+
# Run every callback in `callbacks_dict` once for each entry of `arglist`,
# concurrently, and block until all of the invocations have finished.
#
# - `callbacks_name`: human-readable name of the callback group; only used in
#   the warning message (e.g. "worker-starting").
# - `callbacks_dict`: maps callback keys to the callables to invoke.
# - `warning_interval`: timeout (in seconds) passed to `timedwait()`. Each time
#   it elapses while invocations are still running, a warning listing the keys
#   of the unfinished callbacks is logged and the wait continues.
# - `arglist`: iterable of arguments; each element is splatted into the
#   callback, i.e. `callback(args...)`.
#
# If any invocation throws, the failure surfaces from the final `wait()` as a
# `TaskFailedException` and propagates to the caller.
function _run_callbacks_concurrently(callbacks_name, callbacks_dict, warning_interval, arglist)
    # Track one task per (callback, args) combination. Keying the tasks by the
    # callback name alone would let later entries of `arglist` overwrite the
    # tasks spawned for earlier ones, so those tasks would never be waited on
    # and any exception they throw would be silently lost.
    callback_tasks = Pair{Any, Task}[]
    for args in arglist
        for (name, callback) in callbacks_dict
            task = Threads.@spawn callback(args...)
            push!(callback_tasks, name => task)
        end
    end

    # Keys of the callbacks that still have at least one unfinished invocation
    # (deduplicated so that each key is reported once in the warning).
    running_callbacks = () -> unique!(["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)])
    while timedwait(() -> isempty(running_callbacks()), warning_interval) === :timed_out
        callbacks_str = join(running_callbacks(), ", ")
        @warn "Waiting for these $(callbacks_name) callbacks to finish: $(callbacks_str)"
    end

    # Wait on the tasks so that exceptions bubble up
    return wait.(last.(callback_tasks))
end
912+
913+
# Validate the callback `f` and store it in `dict`, returning the key it was
# stored under.
#
# - `f`: the callable to register. It must have a method applicable to
#   `arg_types`, otherwise an `ArgumentError` is thrown.
# - `key`: key to register `f` under. An `ArgumentError` is thrown if `dict`
#   already has this key. If `key` is `nothing`, a unique key is generated from
#   `gensym()` and the name of `f`.
# - `dict`: the callback registry to add `f` to.
# - `arg_types`: tuple type of the arguments the callback will be called with.
function _add_callback(f, key, dict; arg_types=Tuple{Int})
    type_strs = ("::$(t)" for t in arg_types.types)
    desired_signature = string("f(", join(type_strs, ", "), ")")

    hasmethod(f, arg_types) ||
        throw(ArgumentError("Callback function is invalid, it must be able to be called with these argument types: $(desired_signature)"))
    haskey(dict, key) &&
        throw(ArgumentError("A callback function with key '$(key)' already exists"))

    # No key supplied by the caller: derive a unique one from the function name
    stored_key = isnothing(key) ? Symbol(gensym(), nameof(f)) : key
    dict[stored_key] = f

    return stored_key
end
929+
930+
# Delete the entry for `key` from the callback registry `dict` and return
# `dict`.
function _remove_callback(key, dict)
    return delete!(dict, key)
end
931+
"""
    add_worker_starting_callback(f::Base.Callable; key=nothing)

Register `f` to run on the master process right before new workers are started.
It is invoked as `f(manager, params)`, where `manager` is the `ClusterManager`
instance being used and `params` is a dictionary of parameters related to adding
the workers. What `params` contains is specific to the type of `manager`. The
`LocalManager` and `SSHManager` cluster managers in DistributedNext are not
fully documented yet; their definitions can be found in the
[managers.jl](https://github.com/JuliaParallel/DistributedNext.jl/blob/master/src/managers.jl)
file.

Returns the key that `f` was registered under; a unique key is chosen if `key`
is not specified.

!!! warning
    Adding workers can fail, so there is no guarantee that the requested
    workers will exist.

All worker-starting callbacks are executed concurrently. An exception thrown by
one of them is not caught and will bubble up through [`addprocs`](@ref).

The callbacks add to the time taken to launch workers, so either keep them fast
or have the callback spawn a task that does the actual work asynchronously
(beware of race conditions if you do this).
"""
function add_worker_starting_callback(f::Base.Callable; key=nothing)
    return _add_callback(f, key, worker_starting_callbacks;
                         arg_types=Tuple{ClusterManager, Dict})
end
"""
    remove_worker_starting_callback(key)

Unregister the callback that was registered under `key` with
[`add_worker_starting_callback()`](@ref).
"""
function remove_worker_starting_callback(key)
    return _remove_callback(key, worker_starting_callbacks)
end
964+
"""
    add_worker_started_callback(f::Base.Callable; key=nothing)

Register `f` to run on the master process whenever a worker is added. It is
invoked with the ID of the added worker, e.g. `f(w::Int)`. Returns the key that
`f` was registered under; a unique key is chosen if `key` is not specified.

All worker-started callbacks are executed concurrently. An exception thrown by
one of them is not caught and will bubble up through [`addprocs()`](@ref).

The callbacks add to the time taken to launch workers, so either keep them fast
or have the callback spawn a task that does the actual initialization
asynchronously (beware of race conditions if you do this).
"""
function add_worker_started_callback(f::Base.Callable; key=nothing)
    return _add_callback(f, key, worker_started_callbacks)
end
982+
"""
    remove_worker_started_callback(key)

Unregister the callback that was registered under `key` with
[`add_worker_started_callback()`](@ref).
"""
function remove_worker_started_callback(key)
    return _remove_callback(key, worker_started_callbacks)
end
989+
"""
    add_worker_exiting_callback(f::Base.Callable; key=nothing)

Register `f` to run on the master process right before a worker is removed with
[`rmprocs()`](@ref). It is invoked with the ID of the worker, e.g. `f(w::Int)`.
Returns the key that `f` was registered under; a unique key is chosen if `key`
is not specified.

All worker-exiting callbacks are executed concurrently. If they have not all
finished once the `callback_timeout` passed to `rmprocs()` has elapsed, the
process is removed anyway.
"""
function add_worker_exiting_callback(f::Base.Callable; key=nothing)
    return _add_callback(f, key, worker_exiting_callbacks)
end
1003+
"""
    remove_worker_exiting_callback(key)

Unregister the callback that was registered under `key` with
[`add_worker_exiting_callback()`](@ref).
"""
function remove_worker_exiting_callback(key)
    return _remove_callback(key, worker_exiting_callbacks)
end
1010+
"""
    add_worker_exited_callback(f::Base.Callable; key=nothing)

Register a callback to be called on the master process when a worker has exited
for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
segfaulting etc). Chooses and returns a unique key for the callback if `key` is
not specified.

The callback will be called with the worker ID and the final
`DistributedNext.WorkerState` of the worker, e.g. `f(w::Int, state)`. `state` is
an enum, a value of `WorkerState_terminated` means a graceful exit and a value
of `WorkerState_exterminated` means the worker died unexpectedly. Note that `f`
must be callable with DistributedNext's own `WorkerState` type: a method that is
restricted to the `WorkerState` of the `Distributed` standard library will be
rejected with an `ArgumentError`.

If the callback throws an exception it will be caught and printed.
"""
function add_worker_exited_callback(f::Base.Callable; key=nothing)
    return _add_callback(f, key, worker_exited_callbacks;
                         arg_types=Tuple{Int, WorkerState})
end
1028+
"""
    remove_worker_exited_callback(key)

Unregister the callback that was registered under `key` with
[`add_worker_exited_callback()`](@ref).
"""
function remove_worker_exited_callback(key)
    return _remove_callback(key, worker_exited_callbacks)
end
1035+
8801036
# cluster management related API
8811037
"""
8821038
myid()
@@ -1063,7 +1219,7 @@ function cluster_mgmt_from_master_check()
10631219
end
10641220

10651221
"""
1066-
rmprocs(pids...; waitfor=typemax(Int))
1222+
rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
10671223
10681224
Remove the specified workers. Note that only process 1 can add or remove
10691225
workers.
@@ -1077,6 +1233,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10771233
returned. The user should call [`wait`](@ref) on the task before invoking any other
10781234
parallel calls.
10791235
1236+
The `callback_timeout` specifies how long to wait for any callbacks to execute
1237+
before continuing to remove the workers (see
1238+
[`add_worker_exiting_callback()`](@ref)).
1239+
10801240
# Examples
10811241
```julia-repl
10821242
\$ julia -p 5
@@ -1093,24 +1253,38 @@ julia> workers()
10931253
6
10941254
```
10951255
"""
1096-
function rmprocs(pids...; waitfor=typemax(Int))
1256+
function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
10971257
cluster_mgmt_from_master_check()
10981258

10991259
pids = vcat(pids...)
11001260
if waitfor == 0
1101-
t = @async _rmprocs(pids, typemax(Int))
1261+
t = @async _rmprocs(pids, typemax(Int), callback_timeout)
11021262
yield()
11031263
return t
11041264
else
1105-
_rmprocs(pids, waitfor)
1265+
_rmprocs(pids, waitfor, callback_timeout)
11061266
# return a dummy task object that user code can wait on.
11071267
return @async nothing
11081268
end
11091269
end
11101270

1111-
function _rmprocs(pids, waitfor)
1271+
function _rmprocs(pids, waitfor, callback_timeout)
11121272
lock(worker_lock)
11131273
try
1274+
# Run the callbacks
1275+
callback_tasks = Dict{Any, Task}()
1276+
for pid in pids
1277+
for (name, callback) in worker_exiting_callbacks
1278+
callback_tasks[name] = Threads.@spawn callback(pid)
1279+
end
1280+
end
1281+
1282+
if timedwait(() -> all(istaskdone.(values(callback_tasks))), callback_timeout) === :timed_out
1283+
timedout_callbacks = ["'$(key)'" for (key, task) in callback_tasks if !istaskdone(task)]
1284+
callbacks_str = join(timedout_callbacks, ", ")
1285+
@warn "Some worker-exiting callbacks have not yet finished, continuing to remove workers anyway. These are the callbacks still running: $(callbacks_str)"
1286+
end
1287+
11141288
rmprocset = Union{LocalProcess, Worker}[]
11151289
for p in pids
11161290
if p == 1
@@ -1256,6 +1430,18 @@ function deregister_worker(pg, pid)
12561430
delete!(pg.refs, id)
12571431
end
12581432
end
1433+
1434+
# Call callbacks on the master
1435+
if myid() == 1
1436+
for (name, callback) in worker_exited_callbacks
1437+
try
1438+
callback(pid, w.state)
1439+
catch ex
1440+
@error "Error when running worker-exited callback '$(name)'" exception=(ex, catch_backtrace())
1441+
end
1442+
end
1443+
end
1444+
12591445
return
12601446
end
12611447

0 commit comments

Comments
 (0)