
Commit 7e71b3e

suo authored and facebook-github-bot committed
heartbeat in child processes to ensure cleanup (#435)
Summary:

Pull Request resolved: #435

In D77392241, I removed the signal-based cleanup mechanism. The assumption there was that the `rx.recv()` in `bootstrap.rs` would return an error if the other side hung up, and that the error would bubble up and eventually abort the child process.

This assumption is wrong: `rx.recv()` has server-like semantics, not channel-like (which makes sense; there could be many senders, so any individual one disappearing should not abort the receiver!). As a result, we were not cleaning up properly: if the parent process exited, child processes would just hang around forever.

The test added in D77348271 *happened* to pass, because we sent `SIGKILL` to the parent process, triggering an unclean shutdown, which *will* cause an error on the receiver side. However, a *graceful* shutdown (e.g. from an uncaught Python exception) will not. A quick solution is to add a simple heartbeat task and kill the process if it fails.

ghstack-source-id: 294670138
exported-using-ghexport

Reviewed By: ahmadsharif1

Differential Revision: D77802426

fbshipit-source-id: 0b7f7b76aa528a22b3921d897c9b04d20914cc69
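For orientation, below is a minimal, self-contained sketch of the heartbeat-and-exit pattern the summary describes, using plain tokio rather than hyperactor's channel API. The `ChildToParent` enum, the mpsc channel standing in for the child-to-allocator connection, and the 1-second period are illustrative assumptions, not part of this codebase; the actual implementation is `exit_if_missed_heartbeat` in the `bootstrap.rs` diff further down.

    // Sketch only; requires tokio = { version = "1", features = ["full"] }.
    use std::time::Duration;
    use tokio::sync::mpsc;

    #[derive(Debug)]
    enum ChildToParent {
        Heartbeat,
    }

    // Runs on the child side: periodically send a heartbeat; if the send fails
    // (the parent side is gone), terminate instead of lingering forever.
    async fn heartbeat_loop(tx: mpsc::Sender<ChildToParent>) {
        loop {
            // The commit uses a 5-second period; 1 second keeps this demo short.
            tokio::time::sleep(Duration::from_secs(1)).await;
            if let Err(err) = tx.send(ChildToParent::Heartbeat).await {
                eprintln!("heartbeat failed ({err}); peer is gone, exiting");
                std::process::exit(1);
            }
        }
    }

    #[tokio::main]
    async fn main() {
        let (tx, mut rx) = mpsc::channel(8);
        tokio::spawn(heartbeat_loop(tx));

        // Stand-in for the allocator side: accept a couple of heartbeats, then
        // drop the receiver to simulate the parent going away gracefully.
        for _ in 0..2 {
            if let Some(msg) = rx.recv().await {
                println!("got {msg:?}");
            }
        }
        drop(rx); // the next send in heartbeat_loop fails and the process exits(1)
        tokio::time::sleep(Duration::from_secs(3)).await;
        unreachable!("the heartbeat task should have exited the process by now");
    }

The point of the pattern is that liveness is detected on the *sender* side: the child does not rely on the receiver erroring out, which, as the summary notes, does not happen on a graceful parent shutdown.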
1 parent 7ee3afb · commit 7e71b3e

File tree

4 files changed (+188, −25 lines)


hyperactor_mesh/src/alloc/process.rs

Lines changed: 3 additions & 6 deletions
@@ -331,12 +331,6 @@ impl ProcessAlloc {
         cmd.stdout(Stdio::piped());
         cmd.stderr(Stdio::piped());

-        // Opt-in to signal handling (`PR_SET_PDEATHSIG`) so that the
-        // spawned subprocess will automatically exit when the parent
-        // process dies.
-        // TODO: Use hyperactor::config::global::MANAGED_SUBPROCESS_ENV once it's defined
-        cmd.env("HYPERACTOR_MANAGED_SUBPROCESS", "1");
-
         let proc_id = ProcId(WorldId(self.name.to_string()), index);
         tracing::debug!("Spawning process {:?}", cmd);
         match cmd.spawn() {
@@ -430,6 +424,9 @@ impl Alloc for ProcessAlloc {
                             addr,
                         });
                     }
+                    Process2AllocatorMessage::Heartbeat => {
+                        tracing::debug!("recv heartbeat from {index}");
+                    }
                 }
             },

hyperactor_mesh/src/bootstrap.rs

Lines changed: 49 additions & 8 deletions
@@ -16,6 +16,8 @@ use hyperactor::channel::ChannelAddr;
 use hyperactor::channel::ChannelTransport;
 use hyperactor::channel::Rx;
 use hyperactor::channel::Tx;
+use hyperactor::clock::Clock;
+use hyperactor::clock::RealClock;
 use hyperactor::mailbox::MailboxServer;
 use serde::Deserialize;
 use serde::Serialize;
@@ -44,6 +46,8 @@ pub(crate) enum Process2AllocatorMessage {
     /// after instruction by the allocator through the corresponding
     /// [`Allocator2Process`] message.
     StartedProc(ProcId, ActorRef<MeshAgent>, ChannelAddr),
+
+    Heartbeat,
 }

 /// Messages sent from the allocator to a process.
@@ -62,6 +66,43 @@ pub(crate) enum Allocator2Process {
     Exit(i32),
 }

+async fn exit_if_missed_heartbeat(bootstrap_index: usize, bootstrap_addr: ChannelAddr) {
+    let tx = match channel::dial(bootstrap_addr.clone()) {
+        Ok(tx) => tx,
+
+        Err(err) => {
+            tracing::error!(
+                "Failed to establish heartbeat connection to allocator, exiting! (addr: {:?}): {}",
+                bootstrap_addr,
+                err
+            );
+            std::process::exit(1);
+        }
+    };
+    tracing::info!(
+        "Heartbeat connection established to allocator (idx: {bootstrap_index}, addr: {bootstrap_addr:?})",
+    );
+    loop {
+        RealClock.sleep(Duration::from_secs(5)).await;
+
+        let result = tx
+            .send(Process2Allocator(
+                bootstrap_index,
+                Process2AllocatorMessage::Heartbeat,
+            ))
+            .await;
+
+        if let Err(err) = result {
+            tracing::error!(
+                "Heartbeat failed to allocator, exiting! (addr: {:?}): {}",
+                bootstrap_addr,
+                err
+            );
+            std::process::exit(1);
+        }
+    }
+}
+
 /// Entry point to processes managed by hyperactor_mesh. This advertises the process
 /// to a bootstrap server, and receives instructions to manage the lifecycle(s) of
 /// procs within this process.
@@ -86,15 +127,15 @@ pub async fn bootstrap() -> anyhow::Error {
         .parse()?;
     let listen_addr = ChannelAddr::any(bootstrap_addr.transport());
     let (serve_addr, mut rx) = channel::serve(listen_addr).await?;
-    let tx = channel::dial(bootstrap_addr)?;
+    let tx = channel::dial(bootstrap_addr.clone())?;

-    {
-        tx.send(Process2Allocator(
-            bootstrap_index,
-            Process2AllocatorMessage::Hello(serve_addr),
-        ))
-        .await?;
-    }
+    tx.send(Process2Allocator(
+        bootstrap_index,
+        Process2AllocatorMessage::Hello(serve_addr),
+    ))
+    .await?;
+
+    tokio::spawn(exit_if_missed_heartbeat(bootstrap_index, bootstrap_addr));

     let mut procs = Vec::new();

python/tests/error_test_binary.py

Lines changed: 47 additions & 8 deletions
@@ -48,6 +48,13 @@ async def await_then_error(self) -> None:
         await asyncio.sleep(0.1)
         raise RuntimeError("oh noez")

+    @endpoint
+    async def get_pid(self) -> int:
+        """Endpoint that returns the process PID."""
+        import os
+
+        return os.getpid()
+

 class ErrorActorSync(Actor):
     """An actor that has endpoints cause segfaults."""
@@ -79,8 +86,7 @@ def _run_error_test_sync(num_procs, sync_endpoint, endpoint_name):
     error_actor = proc.spawn("error_actor", actor_class).get()

     # This output is checked in the test to make sure that the process actually got here
-    print("I actually ran")
-    sys.stdout.flush()
+    print("Started function error_test", flush=True)

     if endpoint_name == "cause_segfault":
         endpoint = error_actor.cause_segfault
@@ -110,8 +116,7 @@ async def run_test():
         error_actor = await proc.spawn("error_actor", actor_class)

         # This output is checked in the test to make sure that the process actually got here
-        print("I actually ran")
-        sys.stdout.flush()
+        print("Started function error_test", flush=True)

         if endpoint_name == "cause_segfault":
             endpoint = error_actor.cause_segfault
@@ -153,15 +158,13 @@ def error_endpoint(num_procs, sync_test_impl, sync_endpoint, endpoint_name):

 @main.command("error-bootstrap")
 def error_bootstrap():
-    print("I actually ran")
-    sys.stdout.flush()
+    print("Started function error_bootstrap", flush=True)

     proc_mesh(gpus=4, env={"MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING": "1"}).get()


 async def _error_unmonitored():
-    print("I actually ran")
-    sys.stdout.flush()
+    print("Started function _error_unmonitored", flush=True)

     proc = await proc_mesh(gpus=1)
     actor = await proc.spawn("error_actor", ErrorActor)
@@ -204,5 +207,41 @@ def error_unmonitored():
     asyncio.run(_error_unmonitored())


+async def _error_cleanup():
+    """Test function that spawns an 8 process procmesh and calls an endpoint that returns a normal exception."""
+    print("Started function _error_cleanup() for parent process", flush=True)
+
+    # Spawn an 8 process procmesh
+    proc = await proc_mesh(gpus=8)
+    error_actor = await proc.spawn("error_actor", ErrorActor)
+
+    print("Procmesh spawned, collecting child PIDs from actors", flush=True)
+
+    # Get PIDs from all actor processes
+    try:
+        # Call get_pid endpoint on all actors to collect their PIDs
+        pids = await error_actor.get_pid.call()
+        child_pids = [str(pid) for _, pid in pids]
+        print(f"CHILD_PIDS: {','.join(child_pids)}", flush=True)
+    except Exception as e:
+        print(f"Error getting child PIDs from actors: {e}", flush=True)
+
+    print("About to call endpoint that raises exception", flush=True)
+
+    # Call an endpoint that raises a normal exception
+    try:
+        await error_actor.await_then_error.call()
+    except Exception as e:
+        print(f"Expected exception caught: {e}", flush=True)
+        # Re-raise to cause the process to exit with non-zero code
+        raise
+
+
+@main.command("error-cleanup")
+def error_cleanup():
+    """Command that spawns an 8 process procmesh and calls an endpoint that returns a normal exception."""
+    asyncio.run(_error_cleanup())
+
+
 if __name__ == "__main__":
     main()

python/tests/test_actor_error.py

Lines changed: 89 additions & 3 deletions
@@ -140,7 +140,7 @@ def test_actor_supervision(num_procs, sync_endpoint, sync_test_impl, endpoint_na
         raise

     # Assert that the subprocess exited with a non-zero code
-    assert "I actually ran" in process.stdout.decode()
+    assert "Started function error_test" in process.stdout.decode()
     assert (
         process.returncode != 0
     ), f"Expected non-zero exit code, got {process.returncode}"
@@ -170,7 +170,7 @@ def test_proc_mesh_bootstrap_error():
         raise

     # Assert that the subprocess exited with a non-zero code
-    assert "I actually ran" in process.stdout.decode()
+    assert "Started function error_bootstrap" in process.stdout.decode()
     assert (
         process.returncode != 0
     ), f"Expected non-zero exit code, got {process.returncode}"
@@ -234,12 +234,98 @@ async def test_exception_after_wait_unmonitored():
         raise

     # Assert that the subprocess exited with a non-zero code
-    assert "I actually ran" in process.stdout.decode()
+    assert "Started function _error_unmonitored" in process.stdout.decode()
     assert (
         process.returncode != 0
     ), f"Expected non-zero exit code, got {process.returncode}"


+# oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited
+@pytest.mark.oss_skip
+def test_python_actor_process_cleanup():
+    """
+    Test that PythonActor processes are cleaned up when the parent process dies.
+
+    This test spawns an 8 process procmesh and calls an endpoint that returns a normal exception,
+    then verifies that all spawned processes have been cleaned up after the spawned binary dies.
+    """
+    import os
+    import signal
+    import time
+
+    # Run the error-cleanup test in a subprocess
+    test_bin = importlib.resources.files("monarch.python.tests").joinpath("test_bin")
+    cmd = [
+        str(test_bin),
+        "error-cleanup",
+    ]
+
+    try:
+        print("running cmd", " ".join(cmd))
+        process = subprocess.run(cmd, capture_output=True, timeout=180, text=True)
+    except subprocess.TimeoutExpired as e:
+        print("timeout expired")
+        if e.stdout is not None:
+            print(e.stdout.decode())
+        if e.stderr is not None:
+            print(e.stderr.decode())
+        raise
+
+    # Read stdout line by line to get child PIDs
+    assert "Started function _error_cleanup() for parent process" in process.stdout
+
+    child_pids = set()
+    for line in process.stdout.splitlines():
+        if line.startswith("CHILD_PIDS: "):
+            pids_str = line[len("CHILD_PIDS: ") :]  # noqa
+            child_pids = {
+                int(pid.strip()) for pid in pids_str.split(",") if pid.strip()
+            }
+            print(f"Extracted child PIDs: {child_pids}")
+            break
+
+    if not child_pids:
+        raise AssertionError("No child PIDs found in output")
+
+    assert child_pids, "No child PIDs were collected from subprocess output"
+
+    # Wait for child processes to be cleaned up
+    print("Waiting for child processes to be cleaned up...")
+    cleanup_timeout = 120
+    start_time = time.time()
+
+    def is_process_running(pid):
+        """Check if a process with the given PID is still running."""
+        try:
+            os.kill(pid, 0)  # Signal 0 doesn't kill, just checks if process exists
+            return True
+        except OSError:
+            return False
+
+    still_running = set(child_pids)
+
+    while time.time() - start_time < cleanup_timeout:
+        if not still_running:
+            print("All child processes have been cleaned up!")
+            return
+
+        still_running = {pid for pid in still_running if is_process_running(pid)}

+        print(f"Still running child PIDs: {still_running}")
+        time.sleep(2)
+
+    # If we get here, some processes are still running
+    # Try to clean up remaining processes
+    for pid in still_running:
+        try:
+            os.kill(pid, signal.SIGKILL)
+        except OSError:
+            pass
+    raise AssertionError(
+        f"Child processes not cleaned up after {cleanup_timeout}s: {still_running}"
+    )
+
+
 class ErrorActor(Actor):
     def __init__(self, message):
         raise RuntimeError("fail on init")
