@@ -113,7 +113,7 @@ impl TrackedProcMesh {
113
113
pub struct PyProcMesh {
114
114
inner : SharedCell < TrackedProcMesh > ,
115
115
keepalive : Keepalive ,
116
- proc_events : Arc < Mutex < ProcEvents > > ,
116
+ proc_events : SharedCell < Mutex < ProcEvents > > ,
117
117
stop_monitor_sender : mpsc:: Sender < bool > ,
118
118
user_monitor_registered : AtomicBool ,
119
119
}
@@ -159,9 +159,11 @@ impl PyProcMesh {
159
159
/// process on any proc failure.
160
160
fn monitored ( mut proc_mesh : ProcMesh , world_id : WorldId ) -> Self {
161
161
let ( sender, abort_receiver) = mpsc:: channel :: < bool > ( 1 ) ;
162
- let proc_events = Arc :: new ( Mutex :: new ( proc_mesh. events ( ) . unwrap ( ) ) ) ;
162
+ let proc_events = SharedCell :: from ( Mutex :: new ( proc_mesh. events ( ) . unwrap ( ) ) ) ;
163
163
let monitor = tokio:: spawn ( Self :: default_proc_mesh_monitor (
164
- proc_events. clone ( ) ,
164
+ proc_events
165
+ . borrow ( )
166
+ . expect ( "borrowing immediately after creation" ) ,
165
167
world_id,
166
168
abort_receiver,
167
169
) ) ;
@@ -177,7 +179,7 @@ impl PyProcMesh {
177
179
/// The default monitor of the proc mesh for crashes. If a proc crashes, we print the reason
178
180
/// to stderr and exit with code 1.
179
181
async fn default_proc_mesh_monitor (
180
- events : Arc < Mutex < ProcEvents > > ,
182
+ events : SharedCellRef < Mutex < ProcEvents > > ,
181
183
world_id : WorldId ,
182
184
mut abort_receiver : mpsc:: Receiver < bool > ,
183
185
) {
@@ -197,7 +199,12 @@ impl PyProcMesh {
197
199
}
198
200
}
199
201
}
200
- _ = abort_receiver. recv( ) => {
202
+ _ = async {
203
+ tokio:: select! {
204
+ _ = events. preempted( ) => ( ) ,
205
+ _ = abort_receiver. recv( ) => ( ) ,
206
+ }
207
+ } => {
201
208
// The default monitor is aborted; this happens when the user takes over
202
209
// the monitoring responsibility.
203
210
eprintln!( "stop default supervision monitor for ProcMesh {}" , world_id) ;
@@ -320,6 +327,7 @@ impl PyProcMesh {
320
327
321
328
fn stop < ' py > ( & self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
322
329
let tracked_proc_mesh = self . inner . clone ( ) ;
330
+ let proc_events = self . proc_events . clone ( ) ;
323
331
pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
324
332
async {
325
333
// "Take" the proc mesh wrapper. Once we do, it should be impossible for new
@@ -333,6 +341,9 @@ impl PyProcMesh {
333
341
children. discard_all ( ) . await ?;
334
342
// Finally, take ownership of the inner proc mesh, which will allow dropping it.
335
343
let _proc_mesh = proc_mesh. take ( ) . await ?;
344
+ // Grab the alloc back from `ProcEvents` and use that to stop the mesh.
345
+ let mut alloc = proc_events. take ( ) . await ?. into_inner ( ) . into_alloc ( ) ;
346
+ alloc. stop_and_wait ( ) . await ?;
336
347
anyhow:: Ok ( ( ) )
337
348
}
338
349
. await ?;
@@ -372,7 +383,7 @@ impl Drop for KeepaliveState {
372
383
module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
373
384
) ]
374
385
pub struct PyProcMeshMonitor {
375
- proc_events : Arc < Mutex < ProcEvents > > ,
386
+ proc_events : SharedCell < Mutex < ProcEvents > > ,
376
387
}
377
388
378
389
#[ pymethods]
@@ -384,13 +395,22 @@ impl PyProcMeshMonitor {
384
395
fn __anext__ ( & self , py : Python < ' _ > ) -> PyResult < PyObject > {
385
396
let events = self . proc_events . clone ( ) ;
386
397
Ok ( pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
398
+ let events = events
399
+ . borrow ( )
400
+ . map_err ( |_| PyRuntimeError :: new_err ( "`ProcEvents` is shutdown" ) ) ?;
387
401
let mut proc_events = events. lock ( ) . await ;
388
- let event: Option < _ > = proc_events. next ( ) . await ;
389
- match event {
390
- Some ( event) => Ok ( PyProcEvent :: from ( event) ) ,
391
- None => Err ( :: pyo3:: exceptions:: PyStopAsyncIteration :: new_err (
392
- "stop iteration" ,
393
- ) ) ,
402
+ tokio:: select! {
403
+ ( ) = events. preempted( ) => {
404
+ Err ( PyRuntimeError :: new_err( "shutting down `ProcEvents`" ) )
405
+ } ,
406
+ event = proc_events. next( ) => {
407
+ match event {
408
+ Some ( event) => Ok ( PyProcEvent :: from( event) ) ,
409
+ None => Err ( :: pyo3:: exceptions:: PyStopAsyncIteration :: new_err(
410
+ "stop iteration" ,
411
+ ) ) ,
412
+ }
413
+ }
394
414
}
395
415
} ) ?
396
416
. into ( ) )
0 commit comments