8
8
9
9
use std:: time:: Duration ;
10
10
11
+ use futures:: StreamExt ;
11
12
use hyperactor:: ActorRef ;
12
13
use hyperactor:: Named ;
13
14
use hyperactor:: ProcId ;
@@ -21,6 +22,7 @@ use hyperactor::clock::RealClock;
21
22
use hyperactor:: mailbox:: MailboxServer ;
22
23
use serde:: Deserialize ;
23
24
use serde:: Serialize ;
25
+ use signal_hook:: consts:: signal:: SIGTERM ;
24
26
25
27
use crate :: proc_mesh:: mesh_agent:: MeshAgent ;
26
28
@@ -119,6 +121,8 @@ async fn exit_if_missed_heartbeat(bootstrap_index: usize, bootstrap_addr: Channe
119
121
/// Use [`bootstrap_or_die`] to implement this behavior directly.
120
122
pub async fn bootstrap ( ) -> anyhow:: Error {
121
123
pub async fn go ( ) -> Result < ( ) , anyhow:: Error > {
124
+ let mut signals = signal_hook_tokio:: Signals :: new ( [ SIGTERM ] ) ?;
125
+
122
126
let bootstrap_addr: ChannelAddr = std:: env:: var ( BOOTSTRAP_ADDR_ENV )
123
127
. map_err ( |err| anyhow:: anyhow!( "read `{}`: {}" , BOOTSTRAP_ADDR_ENV , err) ) ?
124
128
. parse ( ) ?;
@@ -141,45 +145,76 @@ pub async fn bootstrap() -> anyhow::Error {
141
145
142
146
loop {
143
147
let _ = hyperactor:: tracing:: info_span!( "wait_for_next_message_from_mesh_agent" ) ;
144
- match rx. recv ( ) . await ? {
145
- Allocator2Process :: StartProc ( proc_id, listen_transport) => {
146
- let ( proc, mesh_agent) = MeshAgent :: bootstrap ( proc_id. clone ( ) ) . await ?;
147
- let ( proc_addr, proc_rx) =
148
- channel:: serve ( ChannelAddr :: any ( listen_transport) ) . await ?;
149
- // Undeliverable messages get forwarded to the mesh agent.
150
- let handle = proc. clone ( ) . serve ( proc_rx, mesh_agent. port ( ) ) ;
151
- drop ( handle) ; // linter appeasement; it is safe to drop this future
152
- tx. send ( Process2Allocator (
153
- bootstrap_index,
154
- Process2AllocatorMessage :: StartedProc (
155
- proc_id. clone ( ) ,
156
- mesh_agent. bind ( ) ,
157
- proc_addr,
158
- ) ,
159
- ) )
160
- . await ?;
161
- procs. push ( proc) ;
162
- }
163
- Allocator2Process :: StopAndExit ( code) => {
164
- tracing:: info!( "stopping procs with code {code}" ) ;
165
- for mut proc_to_stop in procs {
166
- if let Err ( err) = proc_to_stop
167
- . destroy_and_wait ( Duration :: from_millis ( 10 ) , None )
168
- . await
169
- {
170
- tracing:: error!(
171
- "error while stopping proc {}: {}" ,
172
- proc_to_stop. proc_id( ) ,
173
- err
174
- ) ;
148
+ tokio:: select! {
149
+ msg = rx. recv( ) => {
150
+ match msg? {
151
+ Allocator2Process :: StartProc ( proc_id, listen_transport) => {
152
+ let ( proc, mesh_agent) = MeshAgent :: bootstrap( proc_id. clone( ) ) . await ?;
153
+ let ( proc_addr, proc_rx) =
154
+ channel:: serve( ChannelAddr :: any( listen_transport) ) . await ?;
155
+ // Undeliverable messages get forwarded to the mesh agent.
156
+ let handle = proc. clone( ) . serve( proc_rx, mesh_agent. port( ) ) ;
157
+ drop( handle) ; // linter appeasement; it is safe to drop this future
158
+ tx. send( Process2Allocator (
159
+ bootstrap_index,
160
+ Process2AllocatorMessage :: StartedProc (
161
+ proc_id. clone( ) ,
162
+ mesh_agent. bind( ) ,
163
+ proc_addr,
164
+ ) ,
165
+ ) )
166
+ . await ?;
167
+ procs. push( proc) ;
168
+ }
169
+ Allocator2Process :: StopAndExit ( code) => {
170
+ tracing:: info!( "stopping procs with code {code}" ) ;
171
+ for mut proc_to_stop in procs {
172
+ if let Err ( err) = proc_to_stop
173
+ . destroy_and_wait( Duration :: from_millis( 10 ) , None )
174
+ . await
175
+ {
176
+ tracing:: error!(
177
+ "error while stopping proc {}: {}" ,
178
+ proc_to_stop. proc_id( ) ,
179
+ err
180
+ ) ;
181
+ }
182
+ }
183
+ tracing:: info!( "exiting with {code}" ) ;
184
+ std:: process:: exit( code) ;
185
+ }
186
+ Allocator2Process :: Exit ( code) => {
187
+ tracing:: info!( "exiting with {code}" ) ;
188
+ std:: process:: exit( code) ;
175
189
}
176
190
}
177
- tracing:: info!( "exiting with {code}" ) ;
178
- std:: process:: exit ( code) ;
179
191
}
180
- Allocator2Process :: Exit ( code) => {
181
- tracing:: info!( "exiting with {code}" ) ;
182
- std:: process:: exit ( code) ;
192
+ signal = signals. next( ) => {
193
+ if signal. is_some_and( |sig| sig == SIGTERM ) {
194
+ tracing:: info!( "received SIGTERM, stopping procs" ) ;
195
+ for mut proc_to_stop in procs {
196
+ if let Err ( err) = proc_to_stop
197
+ . destroy_and_wait( Duration :: from_millis( 10 ) , None )
198
+ . await
199
+ {
200
+ tracing:: error!(
201
+ "error while stopping proc {}: {}" ,
202
+ proc_to_stop. proc_id( ) ,
203
+ err
204
+ ) ;
205
+ }
206
+ }
207
+ // SAFETY: We're setting the handler to SigDfl (default system behaviour)
208
+ if let Err ( err) = unsafe {
209
+ nix:: sys:: signal:: signal( nix:: sys:: signal:: SIGTERM , nix:: sys:: signal:: SigHandler :: SigDfl )
210
+ } {
211
+ tracing:: error!( "failed to signal SIGTERM: {}" , err) ;
212
+ }
213
+ if let Err ( err) = nix:: sys:: signal:: raise( nix:: sys:: signal:: SIGTERM ) {
214
+ tracing:: error!( "failed to raise SIGTERM: {}" , err) ;
215
+ }
216
+ std:: process:: exit( 128 + SIGTERM ) ;
217
+ }
183
218
}
184
219
}
185
220
}
0 commit comments