Skip to content

Commit 8bfb112

Browse files
thomasywangfacebook-github-bot
authored andcommitted
Send SIGTERM to remote process alloc children (#399)
Summary: Pull Request resolved: #399 - Kill child processes with SIGTERM - Setup signal listener for SIGTERM and exit gracefully on SIGTERM Reviewed By: dulinriley Differential Revision: D77614649 fbshipit-source-id: 3657c131187b09c2017ba43d9c796d0079d36f41
1 parent e52e3e4 commit 8bfb112

File tree

3 files changed

+132
-47
lines changed

3 files changed

+132
-47
lines changed

hyperactor_mesh/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@ rand = { version = "0.8", features = ["small_rng"] }
4747
serde = { version = "1.0.185", features = ["derive", "rc"] }
4848
serde_bytes = "0.11"
4949
serde_json = { version = "1.0.140", features = ["alloc", "float_roundtrip", "unbounded_depth"] }
50+
signal-hook = "0.3"
51+
signal-hook-tokio = { version = "0.3", features = ["futures-v0_3"] }
5052
tempfile = "3.15"
5153
thiserror = "2.0.12"
5254
tokio = { version = "1.45.0", features = ["full", "test-util", "tracing"] }

hyperactor_mesh/src/alloc/process.rs

Lines changed: 59 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ use hyperactor::sync::flag;
3434
use hyperactor::sync::monitor;
3535
use hyperactor_state::state_actor::StateActor;
3636
use ndslice::Shape;
37+
use nix::sys::signal;
38+
use nix::unistd::Pid;
3739
use tokio::io;
3840
use tokio::process::Command;
3941
use tokio::sync::Mutex;
@@ -204,17 +206,15 @@ impl Child {
204206
let monitor = async move {
205207
let reason = tokio::select! {
206208
_ = handle => {
207-
match process.kill().await {
208-
Err(e) => {
209-
tracing::error!("error killing process: {}", e);
210-
// In this cased, we're left with little choice but to
211-
// orphan the process.
212-
ProcStopReason::Unknown
213-
},
214-
Ok(_) => {
215-
Self::exit_status_to_reason(process.wait().await)
216-
}
217-
}
209+
let Some(id) = process.id() else {
210+
tracing::error!("could not get child process id");
211+
return ProcStopReason::Unknown;
212+
};
213+
if let Err(e) = signal::kill(Pid::from_raw(id as i32), signal::SIGTERM) {
214+
tracing::error!("failed to kill child process: {}", e);
215+
return ProcStopReason::Unknown;
216+
};
217+
Self::exit_status_to_reason(process.wait().await)
218218
}
219219
result = process.wait() => Self::exit_status_to_reason(result),
220220
};
@@ -301,6 +301,11 @@ impl Child {
301301
self.stop(ProcStopReason::Watchdog);
302302
}
303303
}
304+
305+
#[cfg(test)]
306+
fn fail_group(&self) {
307+
self.group.fail();
308+
}
304309
}
305310

306311
impl ProcessAlloc {
@@ -516,4 +521,47 @@ mod tests {
516521
crate::alloc_test_suite!(ProcessAllocator::new(Command::new(
517522
buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap()
518523
)));
524+
525+
#[tokio::test]
526+
async fn test_sigterm_on_group_fail() {
527+
let bootstrap_binary = buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap();
528+
let mut allocator = ProcessAllocator::new(Command::new(bootstrap_binary));
529+
530+
let mut alloc = allocator
531+
.allocate(AllocSpec {
532+
shape: ndslice::shape! { replica = 1 },
533+
constraints: Default::default(),
534+
})
535+
.await
536+
.unwrap();
537+
538+
let proc_id = {
539+
loop {
540+
match alloc.next().await {
541+
Some(ProcState::Running { proc_id, .. }) => {
542+
break proc_id;
543+
}
544+
Some(ProcState::Failed { description, .. }) => {
545+
panic!("Process allocation failed: {}", description);
546+
}
547+
Some(_other) => {}
548+
None => {
549+
panic!("Allocation ended unexpectedly");
550+
}
551+
}
552+
}
553+
};
554+
555+
if let Some(child) = alloc.active.get(&proc_id.rank()) {
556+
child.fail_group();
557+
}
558+
559+
assert!(matches!(
560+
alloc.next().await,
561+
Some(ProcState::Stopped {
562+
reason: ProcStopReason::Killed(15, false),
563+
..
564+
})
565+
));
566+
}
519567
}

hyperactor_mesh/src/bootstrap.rs

Lines changed: 71 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
use std::time::Duration;
1010

11+
use futures::StreamExt;
1112
use hyperactor::ActorRef;
1213
use hyperactor::Named;
1314
use hyperactor::ProcId;
@@ -21,6 +22,7 @@ use hyperactor::clock::RealClock;
2122
use hyperactor::mailbox::MailboxServer;
2223
use serde::Deserialize;
2324
use serde::Serialize;
25+
use signal_hook::consts::signal::SIGTERM;
2426

2527
use crate::proc_mesh::mesh_agent::MeshAgent;
2628

@@ -119,6 +121,8 @@ async fn exit_if_missed_heartbeat(bootstrap_index: usize, bootstrap_addr: Channe
119121
/// Use [`bootstrap_or_die`] to implement this behavior directly.
120122
pub async fn bootstrap() -> anyhow::Error {
121123
pub async fn go() -> Result<(), anyhow::Error> {
124+
let mut signals = signal_hook_tokio::Signals::new([SIGTERM])?;
125+
122126
let bootstrap_addr: ChannelAddr = std::env::var(BOOTSTRAP_ADDR_ENV)
123127
.map_err(|err| anyhow::anyhow!("read `{}`: {}", BOOTSTRAP_ADDR_ENV, err))?
124128
.parse()?;
@@ -141,45 +145,76 @@ pub async fn bootstrap() -> anyhow::Error {
141145

142146
loop {
143147
let _ = hyperactor::tracing::info_span!("wait_for_next_message_from_mesh_agent");
144-
match rx.recv().await? {
145-
Allocator2Process::StartProc(proc_id, listen_transport) => {
146-
let (proc, mesh_agent) = MeshAgent::bootstrap(proc_id.clone()).await?;
147-
let (proc_addr, proc_rx) =
148-
channel::serve(ChannelAddr::any(listen_transport)).await?;
149-
// Undeliverable messages get forwarded to the mesh agent.
150-
let handle = proc.clone().serve(proc_rx, mesh_agent.port());
151-
drop(handle); // linter appeasement; it is safe to drop this future
152-
tx.send(Process2Allocator(
153-
bootstrap_index,
154-
Process2AllocatorMessage::StartedProc(
155-
proc_id.clone(),
156-
mesh_agent.bind(),
157-
proc_addr,
158-
),
159-
))
160-
.await?;
161-
procs.push(proc);
162-
}
163-
Allocator2Process::StopAndExit(code) => {
164-
tracing::info!("stopping procs with code {code}");
165-
for mut proc_to_stop in procs {
166-
if let Err(err) = proc_to_stop
167-
.destroy_and_wait(Duration::from_millis(10), None)
168-
.await
169-
{
170-
tracing::error!(
171-
"error while stopping proc {}: {}",
172-
proc_to_stop.proc_id(),
173-
err
174-
);
148+
tokio::select! {
149+
msg = rx.recv() => {
150+
match msg? {
151+
Allocator2Process::StartProc(proc_id, listen_transport) => {
152+
let (proc, mesh_agent) = MeshAgent::bootstrap(proc_id.clone()).await?;
153+
let (proc_addr, proc_rx) =
154+
channel::serve(ChannelAddr::any(listen_transport)).await?;
155+
// Undeliverable messages get forwarded to the mesh agent.
156+
let handle = proc.clone().serve(proc_rx, mesh_agent.port());
157+
drop(handle); // linter appeasement; it is safe to drop this future
158+
tx.send(Process2Allocator(
159+
bootstrap_index,
160+
Process2AllocatorMessage::StartedProc(
161+
proc_id.clone(),
162+
mesh_agent.bind(),
163+
proc_addr,
164+
),
165+
))
166+
.await?;
167+
procs.push(proc);
168+
}
169+
Allocator2Process::StopAndExit(code) => {
170+
tracing::info!("stopping procs with code {code}");
171+
for mut proc_to_stop in procs {
172+
if let Err(err) = proc_to_stop
173+
.destroy_and_wait(Duration::from_millis(10), None)
174+
.await
175+
{
176+
tracing::error!(
177+
"error while stopping proc {}: {}",
178+
proc_to_stop.proc_id(),
179+
err
180+
);
181+
}
182+
}
183+
tracing::info!("exiting with {code}");
184+
std::process::exit(code);
185+
}
186+
Allocator2Process::Exit(code) => {
187+
tracing::info!("exiting with {code}");
188+
std::process::exit(code);
175189
}
176190
}
177-
tracing::info!("exiting with {code}");
178-
std::process::exit(code);
179191
}
180-
Allocator2Process::Exit(code) => {
181-
tracing::info!("exiting with {code}");
182-
std::process::exit(code);
192+
signal = signals.next() => {
193+
if signal.is_some_and(|sig| sig == SIGTERM) {
194+
tracing::info!("received SIGTERM, stopping procs");
195+
for mut proc_to_stop in procs {
196+
if let Err(err) = proc_to_stop
197+
.destroy_and_wait(Duration::from_millis(10), None)
198+
.await
199+
{
200+
tracing::error!(
201+
"error while stopping proc {}: {}",
202+
proc_to_stop.proc_id(),
203+
err
204+
);
205+
}
206+
}
207+
// SAFETY: We're setting the handle to SigDfl (defautl system behaviour)
208+
if let Err(err) = unsafe {
209+
nix::sys::signal::signal(nix::sys::signal::SIGTERM, nix::sys::signal::SigHandler::SigDfl)
210+
} {
211+
tracing::error!("failed to signal SIGTERM: {}", err);
212+
}
213+
if let Err(err) = nix::sys::signal::raise(nix::sys::signal::SIGTERM) {
214+
tracing::error!("failed to raise SIGTERM: {}", err);
215+
}
216+
std::process::exit(128 + SIGTERM);
217+
}
183218
}
184219
}
185220
}

0 commit comments

Comments
 (0)