Skip to content

Commit c2d1088

Browse files
andrewjcgfacebook-github-bot
authored andcommitted
Track all RootActorMeshs spawned by PyProcMesh (#392)
Summary: Pull Request resolved: #392 This diff adds a small wrapper around the `ProcMesh` wrapped by `PyProcMesh`, to bookkeep all outstanding references to `RootActorMesh`s that get spawned (the bookkeeping is automatically GC'd when outstanding references are dropped). The spawned `RootActorMesh`s are also now wrapped by `SharedCell` to allow the `ProcMesh` to dispose of them on close, and to require callers to perform a fallible "borrow" on access, which is used by subsequent diffs to allow for a graceful cleanup/drop of a `ProcMesh` while still allowing outstanding-but-invalid references in Python code. Reviewed By: mariusae Differential Revision: D77385189 fbshipit-source-id: 2dd2900c74c1d49351b3e6bf54ae64fa3303c38a
1 parent 8ccf12e commit c2d1088

File tree

4 files changed

+105
-22
lines changed

4 files changed

+105
-22
lines changed

monarch_extension/src/code_sync.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@ use hyperactor_mesh::RootActorMesh;
1515
use hyperactor_mesh::SlicedActorMesh;
1616
use hyperactor_mesh::code_sync::WorkspaceLocation;
1717
use hyperactor_mesh::code_sync::rsync;
18-
use hyperactor_mesh::proc_mesh::SharedSpawnable;
1918
use hyperactor_mesh::shape::Shape;
19+
use hyperactor_mesh::shared_cell::SharedCell;
2020
use monarch_hyperactor::proc_mesh::PyProcMesh;
2121
use monarch_hyperactor::runtime::signal_safe_block_on;
2222
use monarch_hyperactor::shape::PyShape;
@@ -79,7 +79,7 @@ impl PyWorkspaceLocation {
7979
module = "monarch._rust_bindings.monarch_extension.code_sync"
8080
)]
8181
pub struct RsyncMeshClient {
82-
actor_mesh: Arc<RootActorMesh<'static, rsync::RsyncActor>>,
82+
actor_mesh: SharedCell<RootActorMesh<'static, rsync::RsyncActor>>,
8383
shape: Shape,
8484
workspace: PathBuf,
8585
}
@@ -107,7 +107,7 @@ impl RsyncMeshClient {
107107
)
108108
.await?;
109109
Ok(Self {
110-
actor_mesh: Arc::new(actor_mesh),
110+
actor_mesh,
111111
shape,
112112
workspace: local_workspace,
113113
})
@@ -116,7 +116,7 @@ impl RsyncMeshClient {
116116

117117
fn sync_workspace<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
118118
let workspace = self.workspace.clone();
119-
let inner_mesh = self.actor_mesh.clone();
119+
let inner_mesh = self.actor_mesh.borrow().map_err(anyhow::Error::msg)?;
120120
let shape = self.shape.clone();
121121
pyo3_async_runtimes::tokio::future_into_py(py, async move {
122122
let mesh = SlicedActorMesh::new(&inner_mesh, shape);

monarch_extension/src/mesh_controller.rs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use hyperactor::ActorRef;
2020
use hyperactor::data::Serialized;
2121
use hyperactor_mesh::actor_mesh::ActorMesh;
2222
use hyperactor_mesh::actor_mesh::RootActorMesh;
23-
use hyperactor_mesh::proc_mesh::SharedSpawnable;
23+
use hyperactor_mesh::shared_cell::SharedCell;
2424
use monarch_hyperactor::ndslice::PySlice;
2525
use monarch_hyperactor::proc::InstanceWrapper;
2626
use monarch_hyperactor::proc::PyActorId;
@@ -53,7 +53,7 @@ use crate::convert::convert;
5353
)]
5454
struct _Controller {
5555
controller_instance: Arc<Mutex<InstanceWrapper<ControllerMessage>>>,
56-
workers: RootActorMesh<'static, WorkerActor>,
56+
workers: SharedCell<RootActorMesh<'static, WorkerActor>>,
5757
pending_messages: VecDeque<PyObject>,
5858
history: History,
5959
}
@@ -132,6 +132,8 @@ impl _Controller {
132132
}
133133
fn send_slice(&mut self, slice: Slice, message: WorkerMessage) -> PyResult<()> {
134134
self.workers
135+
.borrow()
136+
.map_err(anyhow::Error::msg)?
135137
.cast_slices(vec![slice], message)
136138
.map_err(|err| PyErr::new::<PyValueError, _>(err.to_string()))
137139
// let shape = Shape::new(
@@ -179,13 +181,13 @@ impl _Controller {
179181
};
180182

181183
let py_proc_mesh = Arc::clone(&py_proc_mesh.inner);
182-
let workers: anyhow::Result<RootActorMesh<'_, WorkerActor>> =
184+
let workers: anyhow::Result<SharedCell<RootActorMesh<'_, WorkerActor>>> =
183185
signal_safe_block_on(py, async move {
184186
let workers = py_proc_mesh
185187
.spawn(&format!("tensor_engine_workers_{}", id), &param)
186188
.await?;
187189
//workers.cast(ndslice::Selection::True, )?;
188-
workers.cast_slices(
190+
workers.borrow()?.cast_slices(
189191
vec![py_proc_mesh.shape().slice().clone()],
190192
AssignRankMessage::AssignRank(),
191193
)?;
@@ -274,7 +276,13 @@ impl _Controller {
274276
}
275277
fn _drain_and_stop(&mut self, py: Python<'_>) -> PyResult<()> {
276278
self.send_slice(
277-
self.workers.proc_mesh().shape().slice().clone(),
279+
self.workers
280+
.borrow()
281+
.map_err(anyhow::Error::msg)?
282+
.proc_mesh()
283+
.shape()
284+
.slice()
285+
.clone(),
278286
WorkerMessage::Exit { error: None },
279287
)?;
280288
let instance = self.controller_instance.clone();

monarch_hyperactor/src/actor_mesh.rs

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,14 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
use std::sync::Arc;
10-
119
use hyperactor::ActorRef;
1210
use hyperactor_mesh::Mesh;
1311
use hyperactor_mesh::RootActorMesh;
1412
use hyperactor_mesh::actor_mesh::ActorMesh;
13+
use hyperactor_mesh::shared_cell::SharedCell;
14+
use hyperactor_mesh::shared_cell::SharedCellRef;
1515
use pyo3::exceptions::PyException;
16+
use pyo3::exceptions::PyRuntimeError;
1617
use pyo3::prelude::*;
1718

1819
use crate::actor::PythonActor;
@@ -27,28 +28,37 @@ use crate::shape::PyShape;
2728
module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh"
2829
)]
2930
pub struct PythonActorMesh {
30-
pub(super) inner: Arc<RootActorMesh<'static, PythonActor>>,
31+
pub(super) inner: SharedCell<RootActorMesh<'static, PythonActor>>,
3132
pub(super) client: PyMailbox,
3233
pub(super) _keepalive: Keepalive,
3334
}
3435

36+
impl PythonActorMesh {
37+
fn try_inner(&self) -> PyResult<SharedCellRef<RootActorMesh<'static, PythonActor>>> {
38+
self.inner
39+
.borrow()
40+
.map_err(|_| PyRuntimeError::new_err("`PythonActorMesh` has already been stopped"))
41+
}
42+
}
43+
3544
#[pymethods]
3645
impl PythonActorMesh {
3746
fn cast(&self, message: &PythonMessage) -> PyResult<()> {
3847
use ndslice::selection::dsl::*;
39-
self.inner
48+
self.try_inner()?
4049
.cast(all(true_()), message.clone())
4150
.map_err(|err| PyException::new_err(err.to_string()))?;
4251
Ok(())
4352
}
4453

4554
// Consider defining a "PythonActorRef", which carries specifically
4655
// a reference to python message actors.
47-
fn get(&self, rank: usize) -> Option<PyActorId> {
48-
self.inner
56+
fn get(&self, rank: usize) -> PyResult<Option<PyActorId>> {
57+
Ok(self
58+
.try_inner()?
4959
.get(rank)
5060
.map(ActorRef::into_actor_id)
51-
.map(PyActorId::from)
61+
.map(PyActorId::from))
5262
}
5363

5464
#[getter]
@@ -57,8 +67,8 @@ impl PythonActorMesh {
5767
}
5868

5969
#[getter]
60-
fn shape(&self) -> PyShape {
61-
PyShape::from(self.inner.shape().clone())
70+
fn shape(&self) -> PyResult<PyShape> {
71+
Ok(PyShape::from(self.try_inner()?.shape().clone()))
6272
}
6373
}
6474
pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> {

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 69 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,18 +6,29 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9+
use std::fmt::Debug;
10+
use std::fmt::Display;
911
use std::sync::Arc;
1012
use std::sync::atomic::AtomicBool;
1113

14+
use hyperactor::Actor;
15+
use hyperactor::Mailbox;
16+
use hyperactor::RemoteMessage;
1217
use hyperactor::WorldId;
18+
use hyperactor::actor::RemoteActor;
19+
use hyperactor::proc::Proc;
1320
use hyperactor_extension::alloc::PyAlloc;
21+
use hyperactor_mesh::RootActorMesh;
1422
use hyperactor_mesh::alloc::Alloc;
1523
use hyperactor_mesh::alloc::ProcStopReason;
1624
use hyperactor_mesh::proc_mesh::ProcEvent;
1725
use hyperactor_mesh::proc_mesh::ProcEvents;
1826
use hyperactor_mesh::proc_mesh::ProcMesh;
1927
use hyperactor_mesh::proc_mesh::SharedSpawnable;
28+
use hyperactor_mesh::shared_cell::SharedCell;
29+
use hyperactor_mesh::shared_cell::SharedCellPool;
2030
use monarch_types::PickledPyObject;
31+
use ndslice::Shape;
2132
use pyo3::IntoPyObjectExt;
2233
use pyo3::exceptions::PyException;
2334
use pyo3::prelude::*;
@@ -31,12 +42,66 @@ use crate::mailbox::PyMailbox;
3142
use crate::runtime::signal_safe_block_on;
3243
use crate::shape::PyShape;
3344

45+
// A wrapper around `ProcMesh` which keeps track of all `RootActorMesh`s that it spawns.
46+
pub struct TrackedProcMesh {
47+
inner: Arc<ProcMesh>,
48+
children: SharedCellPool,
49+
}
50+
51+
impl Debug for TrackedProcMesh {
52+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53+
Debug::fmt(&*self.inner, f)
54+
}
55+
}
56+
57+
impl Display for TrackedProcMesh {
58+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59+
Display::fmt(&*self.inner, f)
60+
}
61+
}
62+
63+
impl From<ProcMesh> for TrackedProcMesh {
64+
fn from(mesh: ProcMesh) -> Self {
65+
Self {
66+
inner: Arc::new(mesh),
67+
children: SharedCellPool::new(),
68+
}
69+
}
70+
}
71+
72+
impl TrackedProcMesh {
73+
pub async fn spawn<A: Actor + RemoteActor>(
74+
&self,
75+
actor_name: &str,
76+
params: &A::Params,
77+
) -> Result<SharedCell<RootActorMesh<'static, A>>, anyhow::Error>
78+
where
79+
A::Params: RemoteMessage,
80+
{
81+
let mesh = self.inner.clone();
82+
let actor = mesh.spawn(actor_name, params).await?;
83+
Ok(self.children.insert(actor))
84+
}
85+
86+
pub fn client(&self) -> &Mailbox {
87+
self.inner.client()
88+
}
89+
90+
pub fn shape(&self) -> &Shape {
91+
self.inner.shape()
92+
}
93+
94+
pub fn client_proc(&self) -> &Proc {
95+
self.inner.client_proc()
96+
}
97+
}
98+
3499
#[pyclass(
35100
name = "ProcMesh",
36101
module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
37102
)]
38103
pub struct PyProcMesh {
39-
pub inner: Arc<ProcMesh>,
104+
pub inner: Arc<TrackedProcMesh>,
40105
keepalive: Keepalive,
41106
proc_events: Arc<Mutex<ProcEvents>>,
42107
stop_monitor_sender: mpsc::Sender<bool>,
@@ -91,7 +156,7 @@ impl PyProcMesh {
91156
abort_receiver,
92157
));
93158
Self {
94-
inner: Arc::new(proc_mesh),
159+
inner: Arc::new(proc_mesh.into()),
95160
keepalive: Keepalive::new(monitor),
96161
proc_events,
97162
stop_monitor_sender: sender,
@@ -165,7 +230,7 @@ impl PyProcMesh {
165230
pyo3_async_runtimes::tokio::future_into_py(py, async move {
166231
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
167232
let python_actor_mesh = PythonActorMesh {
168-
inner: Arc::new(actor_mesh),
233+
inner: actor_mesh,
169234
client: PyMailbox {
170235
inner: proc_mesh.client().clone(),
171236
},
@@ -187,7 +252,7 @@ impl PyProcMesh {
187252
signal_safe_block_on(py, async move {
188253
let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?;
189254
let python_actor_mesh = PythonActorMesh {
190-
inner: Arc::new(actor_mesh),
255+
inner: actor_mesh,
191256
client: PyMailbox {
192257
inner: proc_mesh.client().clone(),
193258
},

0 commit comments

Comments
 (0)