@@ -27,10 +27,12 @@ use hyperactor_mesh::proc_mesh::ProcMesh;
27
27
use hyperactor_mesh:: proc_mesh:: SharedSpawnable ;
28
28
use hyperactor_mesh:: shared_cell:: SharedCell ;
29
29
use hyperactor_mesh:: shared_cell:: SharedCellPool ;
30
+ use hyperactor_mesh:: shared_cell:: SharedCellRef ;
30
31
use monarch_types:: PickledPyObject ;
31
32
use ndslice:: Shape ;
32
33
use pyo3:: IntoPyObjectExt ;
33
34
use pyo3:: exceptions:: PyException ;
35
+ use pyo3:: exceptions:: PyRuntimeError ;
34
36
use pyo3:: prelude:: * ;
35
37
use pyo3:: pycell:: PyRef ;
36
38
use pyo3:: types:: PyType ;
@@ -44,7 +46,8 @@ use crate::shape::PyShape;
44
46
45
47
// A wrapper around `ProcMesh` which keeps track of all `RootActorMesh`s that it spawns.
46
48
pub struct TrackedProcMesh {
47
- inner : Arc < ProcMesh > ,
49
+ inner : SharedCellRef < ProcMesh > ,
50
+ cell : SharedCell < ProcMesh > ,
48
51
children : SharedCellPool ,
49
52
}
50
53
@@ -62,8 +65,11 @@ impl Display for TrackedProcMesh {
62
65
63
66
impl From < ProcMesh > for TrackedProcMesh {
64
67
fn from ( mesh : ProcMesh ) -> Self {
68
+ let cell = SharedCell :: from ( mesh) ;
69
+ let inner = cell. borrow ( ) . unwrap ( ) ;
65
70
Self {
66
- inner : Arc :: new ( mesh) ,
71
+ inner,
72
+ cell,
67
73
children : SharedCellPool :: new ( ) ,
68
74
}
69
75
}
@@ -78,7 +84,7 @@ impl TrackedProcMesh {
78
84
where
79
85
A :: Params : RemoteMessage ,
80
86
{
81
- let mesh = self . inner . clone ( ) ;
87
+ let mesh = self . cell . borrow ( ) ? ;
82
88
let actor = mesh. spawn ( actor_name, params) . await ?;
83
89
Ok ( self . children . insert ( actor) )
84
90
}
@@ -94,14 +100,18 @@ impl TrackedProcMesh {
94
100
pub fn client_proc ( & self ) -> & Proc {
95
101
self . inner . client_proc ( )
96
102
}
103
+
104
+ pub fn into_inner ( self ) -> ( SharedCell < ProcMesh > , SharedCellPool ) {
105
+ ( self . cell , self . children )
106
+ }
97
107
}
98
108
99
109
#[ pyclass(
100
110
name = "ProcMesh" ,
101
111
module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
102
112
) ]
103
113
pub struct PyProcMesh {
104
- pub inner : Arc < TrackedProcMesh > ,
114
+ inner : SharedCell < TrackedProcMesh > ,
105
115
keepalive : Keepalive ,
106
116
proc_events : Arc < Mutex < ProcEvents > > ,
107
117
stop_monitor_sender : mpsc:: Sender < bool > ,
@@ -156,7 +166,7 @@ impl PyProcMesh {
156
166
abort_receiver,
157
167
) ) ;
158
168
Self {
159
- inner : Arc :: new ( proc_mesh . into ( ) ) ,
169
+ inner : SharedCell :: from ( TrackedProcMesh :: from ( proc_mesh ) ) ,
160
170
keepalive : Keepalive :: new ( monitor) ,
161
171
proc_events,
162
172
stop_monitor_sender : sender,
@@ -196,6 +206,12 @@ impl PyProcMesh {
196
206
}
197
207
}
198
208
}
209
+
210
+ pub fn try_inner ( & self ) -> PyResult < SharedCellRef < TrackedProcMesh > > {
211
+ self . inner
212
+ . borrow ( )
213
+ . map_err ( |_| PyRuntimeError :: new_err ( "`ProcMesh` has already been stopped" ) )
214
+ }
199
215
}
200
216
201
217
#[ pymethods]
@@ -225,7 +241,7 @@ impl PyProcMesh {
225
241
actor : & Bound < ' py , PyType > ,
226
242
) -> PyResult < Bound < ' py , PyAny > > {
227
243
let pickled_type = PickledPyObject :: pickle ( actor. as_any ( ) ) ?;
228
- let proc_mesh = Arc :: clone ( & self . inner ) ;
244
+ let proc_mesh = self . try_inner ( ) ? ;
229
245
let keepalive = self . keepalive . clone ( ) ;
230
246
pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
231
247
let mailbox = proc_mesh. client ( ) . clone ( ) ;
@@ -246,7 +262,7 @@ impl PyProcMesh {
246
262
actor : & Bound < ' py , PyType > ,
247
263
) -> PyResult < PyObject > {
248
264
let pickled_type = PickledPyObject :: pickle ( actor. as_any ( ) ) ?;
249
- let proc_mesh = Arc :: clone ( & self . inner ) ;
265
+ let proc_mesh = self . try_inner ( ) ? ;
250
266
let keepalive = self . keepalive . clone ( ) ;
251
267
signal_safe_block_on ( py, async move {
252
268
let mailbox = proc_mesh. client ( ) . clone ( ) ;
@@ -287,19 +303,41 @@ impl PyProcMesh {
287
303
}
288
304
289
305
#[ getter]
290
- fn client ( & self ) -> PyMailbox {
291
- PyMailbox {
292
- inner : self . inner . client ( ) . clone ( ) ,
293
- }
306
+ fn client ( & self ) -> PyResult < PyMailbox > {
307
+ Ok ( PyMailbox {
308
+ inner : self . try_inner ( ) ? . client ( ) . clone ( ) ,
309
+ } )
294
310
}
295
311
296
312
fn __repr__ ( & self ) -> PyResult < String > {
297
- Ok ( format ! ( "<ProcMesh {}>" , self . inner ) )
313
+ Ok ( format ! ( "<ProcMesh {}>" , * self . try_inner ( ) ? ) )
298
314
}
299
315
300
316
#[ getter]
301
- fn shape ( & self ) -> PyShape {
302
- self . inner . shape ( ) . clone ( ) . into ( )
317
+ fn shape ( & self ) -> PyResult < PyShape > {
318
+ Ok ( self . try_inner ( ) ?. shape ( ) . clone ( ) . into ( ) )
319
+ }
320
+
321
+ fn stop < ' py > ( & self , py : Python < ' py > ) -> PyResult < Bound < ' py , PyAny > > {
322
+ let tracked_proc_mesh = self . inner . clone ( ) ;
323
+ pyo3_async_runtimes:: tokio:: future_into_py ( py, async move {
324
+ async {
325
+ // "Take" the proc mesh wrapper. Once we do, it should be impossible for new
326
+ // actor meshes to be spawned.
327
+ let ( proc_mesh, children) = tracked_proc_mesh
328
+ . take ( )
329
+ . await
330
+ . map_err ( |_| PyRuntimeError :: new_err ( "`ProcMesh` has already been stopped" ) ) ?
331
+ . into_inner ( ) ;
332
+ // Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
333
+ children. discard_all ( ) . await ?;
334
+ // Finally, take ownership of the inner proc mesh, which will allow dropping it.
335
+ let _proc_mesh = proc_mesh. take ( ) . await ?;
336
+ anyhow:: Ok ( ( ) )
337
+ }
338
+ . await ?;
339
+ PyResult :: Ok ( ( ) )
340
+ } )
303
341
}
304
342
}
305
343
0 commit comments