Skip to content

Commit 4e196f4

Browse files
andrewjcgfacebook-github-bot
authored andcommitted
Support dropping inner ProcMesh/RootActorMesh from Python (#394)
Summary: Pull Request resolved: #394 This guards the `ProcMesh` we wrap inside `PyProcMesh` with a `SharedCell` and consumes it in `stop`, so that subsequent uses of the `PyProcMesh` object fail (on the `SharedCell::borrow()` call). Reviewed By: vidhyav, mariusae Differential Revision: D77250346 fbshipit-source-id: 10aa68e06890f880bc3cb4c4017cf773ba7e1b0c
1 parent 3bd98c8 commit 4e196f4

File tree

4 files changed

+65
-23
lines changed

4 files changed

+65
-23
lines changed

monarch_extension/src/code_sync.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#![allow(unsafe_op_in_unsafe_fn)]
1010

1111
use std::path::PathBuf;
12-
use std::sync::Arc;
1312

1413
use hyperactor_mesh::RootActorMesh;
1514
use hyperactor_mesh::SlicedActorMesh;
@@ -95,7 +94,7 @@ impl RsyncMeshClient {
9594
local_workspace: PathBuf,
9695
remote_workspace: PyWorkspaceLocation,
9796
) -> PyResult<Self> {
98-
let proc_mesh = Arc::clone(&proc_mesh.inner);
97+
let proc_mesh = proc_mesh.try_inner()?;
9998
let shape = shape.get_inner().clone();
10099
signal_safe_block_on(py, async move {
101100
let actor_mesh = proc_mesh

monarch_extension/src/mesh_controller.rs

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ static NEXT_ID: AtomicUsize = AtomicUsize::new(0);
155155
impl _Controller {
156156
#[new]
157157
fn new(py: Python, py_proc_mesh: &PyProcMesh) -> PyResult<Self> {
158-
let proc_mesh = py_proc_mesh.inner.as_ref();
158+
let proc_mesh = py_proc_mesh.try_inner()?;
159159
let id = NEXT_ID.fetch_add(1, atomic::Ordering::Relaxed);
160160
let controller_instance: InstanceWrapper<ControllerMessage> = InstanceWrapper::new(
161161
&PyProc::new_from_proc(proc_mesh.client_proc().clone()),
@@ -180,18 +180,17 @@ impl _Controller {
180180
controller_actor: controller_actor_ref,
181181
};
182182

183-
let py_proc_mesh = Arc::clone(&py_proc_mesh.inner);
183+
let py_proc_mesh = py_proc_mesh.try_inner()?;
184+
let shape = py_proc_mesh.shape().clone();
184185
let workers: anyhow::Result<SharedCell<RootActorMesh<'_, WorkerActor>>> =
185186
signal_safe_block_on(py, async move {
186187
let workers = py_proc_mesh
187-
.clone()
188188
.spawn(&format!("tensor_engine_workers_{}", id), &param)
189189
.await?;
190190
//workers.cast(ndslice::Selection::True, )?;
191-
workers.borrow()?.cast_slices(
192-
vec![py_proc_mesh.shape().slice().clone()],
193-
AssignRankMessage::AssignRank(),
194-
)?;
191+
workers
192+
.borrow()?
193+
.cast_slices(vec![shape.slice().clone()], AssignRankMessage::AssignRank())?;
195194
Ok(workers)
196195
})?;
197196
Ok(Self {

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 52 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,12 @@ use hyperactor_mesh::proc_mesh::ProcMesh;
2727
use hyperactor_mesh::proc_mesh::SharedSpawnable;
2828
use hyperactor_mesh::shared_cell::SharedCell;
2929
use hyperactor_mesh::shared_cell::SharedCellPool;
30+
use hyperactor_mesh::shared_cell::SharedCellRef;
3031
use monarch_types::PickledPyObject;
3132
use ndslice::Shape;
3233
use pyo3::IntoPyObjectExt;
3334
use pyo3::exceptions::PyException;
35+
use pyo3::exceptions::PyRuntimeError;
3436
use pyo3::prelude::*;
3537
use pyo3::pycell::PyRef;
3638
use pyo3::types::PyType;
@@ -44,7 +46,8 @@ use crate::shape::PyShape;
4446

4547
// A wrapper around `ProcMesh` which keeps track of all `RootActorMesh`s that it spawns.
4648
pub struct TrackedProcMesh {
47-
inner: Arc<ProcMesh>,
49+
inner: SharedCellRef<ProcMesh>,
50+
cell: SharedCell<ProcMesh>,
4851
children: SharedCellPool,
4952
}
5053

@@ -62,8 +65,11 @@ impl Display for TrackedProcMesh {
6265

6366
impl From<ProcMesh> for TrackedProcMesh {
6467
fn from(mesh: ProcMesh) -> Self {
68+
let cell = SharedCell::from(mesh);
69+
let inner = cell.borrow().unwrap();
6570
Self {
66-
inner: Arc::new(mesh),
71+
inner,
72+
cell,
6773
children: SharedCellPool::new(),
6874
}
6975
}
@@ -78,7 +84,7 @@ impl TrackedProcMesh {
7884
where
7985
A::Params: RemoteMessage,
8086
{
81-
let mesh = self.inner.clone();
87+
let mesh = self.cell.borrow()?;
8288
let actor = mesh.spawn(actor_name, params).await?;
8389
Ok(self.children.insert(actor))
8490
}
@@ -94,14 +100,18 @@ impl TrackedProcMesh {
94100
pub fn client_proc(&self) -> &Proc {
95101
self.inner.client_proc()
96102
}
103+
104+
pub fn into_inner(self) -> (SharedCell<ProcMesh>, SharedCellPool) {
105+
(self.cell, self.children)
106+
}
97107
}
98108

99109
#[pyclass(
100110
name = "ProcMesh",
101111
module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
102112
)]
103113
pub struct PyProcMesh {
104-
pub inner: Arc<TrackedProcMesh>,
114+
inner: SharedCell<TrackedProcMesh>,
105115
keepalive: Keepalive,
106116
proc_events: Arc<Mutex<ProcEvents>>,
107117
stop_monitor_sender: mpsc::Sender<bool>,
@@ -156,7 +166,7 @@ impl PyProcMesh {
156166
abort_receiver,
157167
));
158168
Self {
159-
inner: Arc::new(proc_mesh.into()),
169+
inner: SharedCell::from(TrackedProcMesh::from(proc_mesh)),
160170
keepalive: Keepalive::new(monitor),
161171
proc_events,
162172
stop_monitor_sender: sender,
@@ -196,6 +206,12 @@ impl PyProcMesh {
196206
}
197207
}
198208
}
209+
210+
pub fn try_inner(&self) -> PyResult<SharedCellRef<TrackedProcMesh>> {
211+
self.inner
212+
.borrow()
213+
.map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))
214+
}
199215
}
200216

201217
#[pymethods]
@@ -225,7 +241,7 @@ impl PyProcMesh {
225241
actor: &Bound<'py, PyType>,
226242
) -> PyResult<Bound<'py, PyAny>> {
227243
let pickled_type = PickledPyObject::pickle(actor.as_any())?;
228-
let proc_mesh = Arc::clone(&self.inner);
244+
let proc_mesh = self.try_inner()?;
229245
let keepalive = self.keepalive.clone();
230246
pyo3_async_runtimes::tokio::future_into_py(py, async move {
231247
let mailbox = proc_mesh.client().clone();
@@ -246,7 +262,7 @@ impl PyProcMesh {
246262
actor: &Bound<'py, PyType>,
247263
) -> PyResult<PyObject> {
248264
let pickled_type = PickledPyObject::pickle(actor.as_any())?;
249-
let proc_mesh = Arc::clone(&self.inner);
265+
let proc_mesh = self.try_inner()?;
250266
let keepalive = self.keepalive.clone();
251267
signal_safe_block_on(py, async move {
252268
let mailbox = proc_mesh.client().clone();
@@ -287,19 +303,41 @@ impl PyProcMesh {
287303
}
288304

289305
#[getter]
290-
fn client(&self) -> PyMailbox {
291-
PyMailbox {
292-
inner: self.inner.client().clone(),
293-
}
306+
fn client(&self) -> PyResult<PyMailbox> {
307+
Ok(PyMailbox {
308+
inner: self.try_inner()?.client().clone(),
309+
})
294310
}
295311

296312
fn __repr__(&self) -> PyResult<String> {
297-
Ok(format!("<ProcMesh {}>", self.inner))
313+
Ok(format!("<ProcMesh {}>", *self.try_inner()?))
298314
}
299315

300316
#[getter]
301-
fn shape(&self) -> PyShape {
302-
self.inner.shape().clone().into()
317+
fn shape(&self) -> PyResult<PyShape> {
318+
Ok(self.try_inner()?.shape().clone().into())
319+
}
320+
321+
fn stop<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
322+
let tracked_proc_mesh = self.inner.clone();
323+
pyo3_async_runtimes::tokio::future_into_py(py, async move {
324+
async {
325+
// "Take" the proc mesh wrapper. Once we do, it should be impossible for new
326+
// actor meshes to be spawned.
327+
let (proc_mesh, children) = tracked_proc_mesh
328+
.take()
329+
.await
330+
.map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))?
331+
.into_inner();
332+
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
333+
children.discard_all().await?;
334+
// Finally, take ownership of the inner proc mesh, which will allowing dropping it.
335+
let _proc_mesh = proc_mesh.take().await?;
336+
anyhow::Ok(())
337+
}
338+
.await?;
339+
PyResult::Ok(())
340+
})
303341
}
304342
}
305343

python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ class ProcMesh:
8080
"""
8181
...
8282

83+
async def stop(self) -> None:
84+
"""
85+
Stop the proc mesh.
86+
"""
87+
...
88+
8389
def __repr__(self) -> str: ...
8490

8591
@final

0 commit comments

Comments
 (0)