Skip to content

Commit fa3be3f

Browse files
authored
use ssh multiplexing to create SSH workers with tunneling (JuliaLang/julia#34295)
* use ssh multiplexing to create SSH workers with tunneling
* rename funcname
* add test to check existence of ssh multiplexing master socket
* add doc for ssh multiplexing in SSHWorker
* Fix typo
* add :multiplex option to make SSH multiplexing configurable
1 parent 850cc89 commit fa3be3f

File tree

3 files changed

+79
-20
lines changed

3 files changed

+79
-20
lines changed

src/cluster.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ The `userdata` field is used to store information for each worker by external ma
2828
2929
Some fields are used by `SSHManager` and similar managers:
3030
* `tunnel` -- `true` (use tunneling), `false` (do not use tunneling), or [`nothing`](@ref) (use default for the manager)
31+
* `multiplex` -- `true` (use SSH multiplexing for tunneling) or `false`
32+
* `forward` -- the forwarding specification passed to the `-L` option of ssh
3133
* `bind_addr` -- the address on the remote host to bind to
3234
* `sshflags` -- flags to use in establishing the SSH connection
3335
* `max_parallel` -- the maximum number of workers to connect to in parallel on the host
@@ -58,6 +60,8 @@ mutable struct WorkerConfig
5860

5961
# SSHManager / SSH tunnel connections to workers
6062
tunnel::Union{Bool, Nothing}
63+
multiplex::Union{Bool, Nothing}
64+
forward::Union{AbstractString, Nothing}
6165
bind_addr::Union{AbstractString, Nothing}
6266
sshflags::Union{Cmd, Nothing}
6367
max_parallel::Union{Int, Nothing}
@@ -548,7 +552,7 @@ function launch_n_additional_processes(manager, frompid, fromconfig, cnt, launch
548552
(bind_addr, port) = address
549553

550554
wconfig = WorkerConfig()
551-
for x in [:host, :tunnel, :sshflags, :exeflags, :exename, :enable_threaded_blas]
555+
for x in [:host, :tunnel, :multiplex, :sshflags, :exeflags, :exename, :enable_threaded_blas]
552556
Base.setproperty!(wconfig, x, Base.getproperty(fromconfig, x))
553557
end
554558
wconfig.bind_addr = bind_addr

src/managers.jl

Lines changed: 66 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ Keyword arguments:
7171
* `tunnel`: if `true` then SSH tunneling will be used to connect to the worker from the
7272
master process. Default is `false`.
7373
74+
* `multiplex`: if `true` then SSH multiplexing is used for SSH tunneling. Default is `false`.
75+
7476
* `sshflags`: specifies additional ssh options, e.g. ```sshflags=\`-i /home/foo/bar.pem\````
7577
7678
* `max_parallel`: specifies the maximum number of workers connected to in parallel at a
@@ -113,9 +115,9 @@ This timeout can be controlled via environment variable `JULIA_WORKER_TIMEOUT`.
113115
The value of `JULIA_WORKER_TIMEOUT` on the master process specifies the number of seconds a
114116
newly launched worker waits for connection establishment.
115117
"""
116-
function addprocs(machines::AbstractVector; tunnel=false, sshflags=``, max_parallel=10, kwargs...)
118+
function addprocs(machines::AbstractVector; tunnel=false, multiplex=false, sshflags=``, max_parallel=10, kwargs...)
# `check_addprocs_args(kwargs)` is called on the extra keyword arguments first; the call is
# then delegated to the `addprocs` method taking an `SSHManager` built from `machines`, with
# `tunnel`, `multiplex`, `sshflags` and `max_parallel` passed through explicitly.
117119
check_addprocs_args(kwargs)
118-
addprocs(SSHManager(machines); tunnel=tunnel, sshflags=sshflags, max_parallel=max_parallel, kwargs...)
120+
addprocs(SSHManager(machines); tunnel=tunnel, multiplex=multiplex, sshflags=sshflags, max_parallel=max_parallel, kwargs...)
119121
end
120122

121123

@@ -149,6 +151,8 @@ function launch_on_machine(manager::SSHManager, machine, cnt, params, launched,
149151
dir = params[:dir]
150152
exename = params[:exename]
151153
exeflags = params[:exeflags]
154+
tunnel = params[:tunnel]
155+
multiplex = params[:multiplex]
152156

153157
# machine could be of the format [user@]host[:port] bind_addr[:bind_port]
154158
# machine format string is split on whitespace
@@ -178,6 +182,20 @@ function launch_on_machine(manager::SSHManager, machine, cnt, params, launched,
178182
end
179183
sshflags = `$(params[:sshflags]) $portopt`
180184

185+
if tunnel
186+
# First check whether ssh multiplexing is already enabled and the master process is running.
187+
# If it's already running, later ssh sessions also use the same ssh multiplexing session even if
188+
# `multiplex` is not explicitly specified; otherwise the tunneling session launched later won't
189+
# go to the background and will hang instead. This is due to the OpenSSH implementation.
190+
if success(`ssh $sshflags -O check $host`)
191+
multiplex = true
192+
elseif multiplex
193+
# automatically create an SSH multiplexing session at the next SSH connection
194+
controlpath = "~/.ssh/julia-%r@%h:%p"
195+
sshflags = `$sshflags -o ControlMaster=auto -o ControlPath=$controlpath -o ControlPersist=no`
196+
end
197+
end
198+
181199
# Build up the ssh command
182200

183201
# the default worker timeout
@@ -211,7 +229,8 @@ function launch_on_machine(manager::SSHManager, machine, cnt, params, launched,
211229
wconfig = WorkerConfig()
212230
wconfig.io = io.out
213231
wconfig.host = host
214-
wconfig.tunnel = params[:tunnel]
232+
wconfig.tunnel = tunnel
233+
wconfig.multiplex = multiplex
215234
wconfig.sshflags = sshflags
216235
wconfig.exeflags = exeflags
217236
wconfig.exename = exename
@@ -256,25 +275,32 @@ end
256275

257276

258277
"""
259-
ssh_tunnel(user, host, bind_addr, port, sshflags) -> localport
278+
ssh_tunnel(user, host, bind_addr, port, sshflags, multiplex) -> localport
260279
261280
Establish an SSH tunnel to a remote worker.
262281
Return a port number `localport` such that `localhost:localport` connects to `host:port`.
263282
"""
264-
function ssh_tunnel(user, host, bind_addr, port, sshflags)
283+
function ssh_tunnel(user, host, bind_addr, port, sshflags, multiplex)
265284
port = Int(port)
266285
cnt = ntries = 100
267-
# if we cannot do port forwarding, bail immediately
286+
268287
# the connection is forwarded to `port` on the remote server over the local port `localport`
269-
# the -f option backgrounds the ssh session
270-
# `sleep 60` command specifies that an alloted time of 60 seconds is allowed to start the
271-
# remote julia process and establish the network connections specified by the process topology.
272-
# If no connections are made within 60 seconds, ssh will exit and an error will be printed on the
273-
# process that launched the remote process.
274-
ssh = `ssh -T -a -x -o ExitOnForwardFailure=yes`
275288
while cnt > 0
276289
localport = next_tunnel_port()
277-
if success(detach(`$ssh -f $sshflags $user@$host -L $localport:$bind_addr:$port sleep 60`))
290+
if multiplex
291+
# This assumes that an ssh multiplexing session has already been started by the remote worker.
292+
cmd = `ssh $sshflags -O forward -L $localport:$bind_addr:$port $user@$host`
293+
else
294+
# if we cannot do port forwarding, fail immediately
295+
# the -f option backgrounds the ssh session
296+
# `sleep 60` command specifies that an allotted time of 60 seconds is allowed to start the
297+
# remote julia process and establish the network connections specified by the process topology.
298+
# If no connections are made within 60 seconds, ssh will exit and an error will be printed on the
299+
# process that launched the remote process.
300+
ssh = `ssh -T -a -x -o ExitOnForwardFailure=yes`
301+
cmd = detach(`$ssh -f $sshflags $user@$host -L $localport:$bind_addr:$port sleep 60`)
302+
end
303+
if success(cmd)
278304
return localport
279305
end
280306
cnt -= 1
@@ -427,9 +453,11 @@ function connect(manager::ClusterManager, pid::Int, config::WorkerConfig)
427453
sem = tunnel_hosts_map[pubhost]
428454

429455
sshflags = notnothing(config.sshflags)
456+
multiplex = something(config.multiplex, false)
430457
acquire(sem)
431458
try
432-
(s, bind_addr) = connect_to_worker(pubhost, bind_addr, port, user, sshflags)
459+
(s, bind_addr, forward) = connect_to_worker_with_tunnel(pubhost, bind_addr, port, user, sshflags, multiplex)
460+
config.forward = forward
433461
finally
434462
release(sem)
435463
end
@@ -515,9 +543,23 @@ function connect_to_worker(host::AbstractString, port::Integer)
515543
end
516544

517545

518-
function connect_to_worker(host::AbstractString, bind_addr::AbstractString, port::Integer, tunnel_user::AbstractString, sshflags)
519-
s = connect("localhost", ssh_tunnel(tunnel_user, host, bind_addr, UInt16(port), sshflags))
520-
(s, bind_addr)
546+
# Connect to a worker through an SSH tunnel.
#
# A local port is obtained from `ssh_tunnel` (to which `tunnel_user`, `host`, `bind_addr`,
# the port converted to `UInt16`, `sshflags` and `multiplex` are handed over), and a
# connection is then opened to that port on "localhost".
#
# Returns a 3-tuple: the connected socket, the unchanged `bind_addr`, and the forwarding
# specification string "localport:bind_addr:port".
function connect_to_worker_with_tunnel(host::AbstractString, bind_addr::AbstractString, port::Integer, tunnel_user::AbstractString, sshflags, multiplex)
    tunnel_port = ssh_tunnel(tunnel_user, host, bind_addr, UInt16(port), sshflags, multiplex)
    sock = connect("localhost", tunnel_port)
    return (sock, bind_addr, string(tunnel_port, ':', bind_addr, ':', port))
end
552+
553+
554+
# Cancel the SSH port forwarding recorded in `config.forward`.
#
# `config.host` and `config.sshflags` are unwrapped with `notnothing` unconditionally.
# Unset `config.tunnel` / `config.multiplex` values are treated as `false`. Only when both
# are `true` is `ssh <sshflags> -O cancel -L <forward> <host>` run (and its result
# returned); otherwise the function returns `nothing` without running anything.
function cancel_ssh_tunnel(config::WorkerConfig)
    target = notnothing(config.host)
    flags = notnothing(config.sshflags)
    uses_tunnel = something(config.tunnel, false)
    uses_multiplex = something(config.multiplex, false)

    # the forwarding is only cancelled when tunneling went through SSH multiplexing
    (uses_tunnel && uses_multiplex) || return nothing

    spec = notnothing(config.forward)
    return run(`ssh $flags -O cancel -L $spec $target`)
end
522564

523565

@@ -531,7 +573,12 @@ It should cause the remote worker specified by `pid` to exit.
531573
on `pid`.
532574
"""
533575
function kill(manager::ClusterManager, pid::Int, config::WorkerConfig)
534-
remote_do(exit, pid) # For TCP based transports this will result in a close of the socket
535-
# at our end, which will result in a cleanup of the worker.
576+
# Generic method for any `ClusterManager`: request that `exit` be run on worker `pid`.
remote_do(exit, pid)
577+
# `manager` and `config` are not used by this method; it always returns `nothing`.
nothing
578+
end
579+
580+
# `SSHManager` method of `kill`: request that `exit` be run on worker `pid`, then hand
# `config` to `cancel_ssh_tunnel` so the worker's SSH port forwarding can be cancelled.
# Always returns `nothing`.
function kill(manager::SSHManager, pid::Int, config::WorkerConfig)
    remote_do(exit, pid)
    cancel_ssh_tunnel(config)
    return nothing
end

test/distributed_exec.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -720,6 +720,14 @@ if Sys.isunix() # aka have ssh
720720
@test length(new_pids) == num_workers
721721
test_n_remove_pids(new_pids)
722722

723+
print("\nssh addprocs with tunnel (SSH multiplexing)\n")
724+
new_pids = addprocs_with_testenv([("localhost", num_workers)]; tunnel=true, multiplex=true, sshflags=sshflags)
725+
@test length(new_pids) == num_workers
726+
controlpath = joinpath(homedir(), ".ssh", "julia-$(ENV["USER"])@localhost:22")
727+
@test issocket(controlpath)
728+
test_n_remove_pids(new_pids)
729+
@test :ok == timedwait(()->!issocket(controlpath), 10.0; pollint=0.5)
730+
723731
print("\nAll supported formats for hostname\n")
724732
h1 = "localhost"
725733
user = ENV["USER"]

0 commit comments

Comments
 (0)