
Commit 48a0714

andrewjcg authored and facebook-github-bot committed
Support explicitly killing procs (#408)
Summary:
Pull Request resolved: #408

Not super graceful yet, but this adds some initial bookkeeping to support explicitly killing a proc mesh by reclaiming the alloc we put in `ProcEvents` and package in the `Keepalive` token. In particular, this adds a `PyProcMesh.stop` method which consumes the keepalive token -- leaving the `PyProcMesh` in an effectively unusable state -- and uses it to stop the alloc.

Reviewed By: vidhyav, mariusae

Differential Revision: D77250211

fbshipit-source-id: 16b30c55cd558c9c52a188aa3364af9379bb6323
1 parent 4e196f4 commit 48a0714
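
For orientation, here is a minimal usage sketch of the new API from the Python side, distilled from the `test_stop_proc_mesh` test added in this commit. The function name `stop_example` and the actor class `ExampleActor` are illustrative placeholders, and the helpers (`remote_process_allocator`, `StaticRemoteAllocInitializer`, `_100_MILLISECONDS`) are the fixtures from the commit's test file, assumed to be imported as they are there:

async def stop_example() -> None:
    # Allocate a 2x4 proc mesh across two local process-allocators,
    # mirroring the setup in python/tests/test_allocator.py below.
    spec = AllocSpec(AllocConstraints(), host=2, gpu=4)
    with remote_process_allocator() as host1, remote_process_allocator() as host2:
        allocator = RemoteAllocator(
            world_id="stop_example",
            initializer=StaticRemoteAllocInitializer(host1, host2),
            heartbeat_interval=_100_MILLISECONDS,
        )
        alloc = await allocator.allocate(spec)
        proc_mesh = await ProcMesh.from_alloc(alloc)
        await proc_mesh.spawn("example_actor", ExampleActor)

        # Explicitly stop the mesh: under the hood this reclaims the alloc
        # held by `ProcEvents` and stops it, leaving the mesh unusable.
        await proc_mesh.stop()

        # Further use now fails; spawning again raises, e.g.:
        #   RuntimeError: `ProcMesh` has already been stopped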

File tree

4 files changed: +65 -12 lines

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 4 additions & 0 deletions
@@ -526,6 +526,10 @@ impl ProcEvents {
             }
         }
     }
+
+    pub fn into_alloc(self) -> Box<dyn Alloc + Send + Sync> {
+        self.event_state.alloc
+    }
 }
 
 /// Spawns from shared ([`Arc`]) proc meshes, providing [`ActorMesh`]es with

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 32 additions & 12 deletions
@@ -113,7 +113,7 @@ impl TrackedProcMesh {
 pub struct PyProcMesh {
     inner: SharedCell<TrackedProcMesh>,
     keepalive: Keepalive,
-    proc_events: Arc<Mutex<ProcEvents>>,
+    proc_events: SharedCell<Mutex<ProcEvents>>,
     stop_monitor_sender: mpsc::Sender<bool>,
     user_monitor_registered: AtomicBool,
 }
@@ -159,9 +159,11 @@ impl PyProcMesh {
     /// process on any proc failure.
     fn monitored(mut proc_mesh: ProcMesh, world_id: WorldId) -> Self {
         let (sender, abort_receiver) = mpsc::channel::<bool>(1);
-        let proc_events = Arc::new(Mutex::new(proc_mesh.events().unwrap()));
+        let proc_events = SharedCell::from(Mutex::new(proc_mesh.events().unwrap()));
         let monitor = tokio::spawn(Self::default_proc_mesh_monitor(
-            proc_events.clone(),
+            proc_events
+                .borrow()
+                .expect("borrowing immediately after creation"),
             world_id,
             abort_receiver,
         ));
@@ -177,7 +179,7 @@ impl PyProcMesh {
     /// The default monitor of the proc mesh for crashes. If a proc crashes, we print the reason
     /// to stderr and exit with code 1.
     async fn default_proc_mesh_monitor(
-        events: Arc<Mutex<ProcEvents>>,
+        events: SharedCellRef<Mutex<ProcEvents>>,
         world_id: WorldId,
         mut abort_receiver: mpsc::Receiver<bool>,
     ) {
@@ -197,7 +199,12 @@ impl PyProcMesh {
                     }
                 }
             }
-            _ = abort_receiver.recv() => {
+            _ = async {
+                tokio::select! {
+                    _ = events.preempted() => (),
+                    _ = abort_receiver.recv() => (),
+                }
+            } => {
                 // The default monitor is aborted, this happens when user takes over
                 // the monitoring responsibility.
                 eprintln!("stop default supervision monitor for ProcMesh {}", world_id);
@@ -320,6 +327,7 @@ impl PyProcMesh {
 
     fn stop<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
         let tracked_proc_mesh = self.inner.clone();
+        let proc_events = self.proc_events.clone();
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
             async {
                 // "Take" the proc mesh wrapper. Once we do, it should be impossible for new
@@ -333,6 +341,9 @@ impl PyProcMesh {
                 children.discard_all().await?;
                 // Finally, take ownership of the inner proc mesh, which will allowing dropping it.
                 let _proc_mesh = proc_mesh.take().await?;
+                // Grab the alloc back from `ProcEvents` and use that to stop the mesh.
+                let mut alloc = proc_events.take().await?.into_inner().into_alloc();
+                alloc.stop_and_wait().await?;
                 anyhow::Ok(())
             }
             .await?;
@@ -372,7 +383,7 @@ impl Drop for KeepaliveState {
     module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh"
 )]
 pub struct PyProcMeshMonitor {
-    proc_events: Arc<Mutex<ProcEvents>>,
+    proc_events: SharedCell<Mutex<ProcEvents>>,
 }
 
 #[pymethods]
@@ -384,13 +395,22 @@ impl PyProcMeshMonitor {
     fn __anext__(&self, py: Python<'_>) -> PyResult<PyObject> {
         let events = self.proc_events.clone();
         Ok(pyo3_async_runtimes::tokio::future_into_py(py, async move {
+            let events = events
+                .borrow()
+                .map_err(|_| PyRuntimeError::new_err("`ProcEvents` is shutdown"))?;
             let mut proc_events = events.lock().await;
-            let event: Option<_> = proc_events.next().await;
-            match event {
-                Some(event) => Ok(PyProcEvent::from(event)),
-                None => Err(::pyo3::exceptions::PyStopAsyncIteration::new_err(
-                    "stop iteration",
-                )),
+            tokio::select! {
+                () = events.preempted() => {
+                    Err(PyRuntimeError::new_err("shutting down `ProcEvents`"))
+                },
+                event = proc_events.next() => {
+                    match event {
+                        Some(event) => Ok(PyProcEvent::from(event)),
+                        None => Err(::pyo3::exceptions::PyStopAsyncIteration::new_err(
+                            "stop iteration",
+                        )),
+                    }
+                }
             }
         })?
         .into())

python/monarch/proc_mesh.py

Lines changed: 3 additions & 0 deletions
@@ -245,6 +245,9 @@ async def sync_workspace(self, auto_reload: bool = False) -> None:
         assert self._auto_reload_actor is not None
         await self._auto_reload_actor.reload.call()
 
+    async def stop(self) -> None:
+        await self._proc_mesh.stop()
+
 
 async def local_proc_mesh_nonblocking(
     *, gpus: Optional[int] = None, hosts: int = 1

python/tests/test_allocator.py

Lines changed: 26 additions & 0 deletions
@@ -192,6 +192,32 @@ async def test_allocate_2d_mesh(self) -> None:
 
         self.assert_computed_world_size(values, world_size)
 
+    async def test_stop_proc_mesh(self) -> None:
+        spec = AllocSpec(AllocConstraints(), host=2, gpu=4)
+
+        # create 2x process-allocators (on their own bind addresses) to simulate 2 hosts
+        with remote_process_allocator() as host1, remote_process_allocator() as host2:
+            allocator = RemoteAllocator(
+                world_id="test_remote_allocator",
+                initializer=StaticRemoteAllocInitializer(host1, host2),
+                heartbeat_interval=_100_MILLISECONDS,
+            )
+            alloc = await allocator.allocate(spec)
+            proc_mesh = await ProcMesh.from_alloc(alloc)
+            actor = await proc_mesh.spawn("test_actor", TestActor)
+
+            await proc_mesh.stop()
+
+            with self.assertRaises(
+                RuntimeError, msg="`ProcMesh` has already been stopped"
+            ):
+                await proc_mesh.spawn("test_actor", TestActor)
+
+            # TODO(agallagher): It'd be nice to test that this just fails
+            # immediately, trying to access the wrapped actor mesh, but right
+            # now we doing casting without accessing the wrapped type.
+            del actor
+
     async def test_stacked_1d_meshes(self) -> None:
         # create two stacked actor meshes on the same host
         # each actor mesh running on separate process-allocators
