Skip to content

Commit f06c9f1

Browse files
jayasi authored and facebook-github-bot committed
Allow stop to be used as a blocking or non-blocking method (#546)
Summary:
Pull Request resolved: #546
Allow the stop method to also be blocking.
Reviewed By: highker
Differential Revision: D78362041
fbshipit-source-id: 096d069e7fee7cf704255296dacecdfe91b2a36e
1 parent 0588650 commit f06c9f1

File tree

4 files changed

+82
-26
lines changed

4 files changed

+82
-26
lines changed

monarch_hyperactor/src/proc_mesh.rs

Lines changed: 43 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,32 @@ impl PyProcMesh {
227227
.borrow()
228228
.map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))
229229
}
230+
231+
async fn stop_mesh(
232+
inner: SharedCell<TrackedProcMesh>,
233+
proc_events: SharedCell<Mutex<ProcEvents>>,
234+
) -> Result<(), anyhow::Error> {
235+
// "Take" the proc mesh wrapper. Once we do, it should be impossible for new
236+
// actor meshes to be spawned.
237+
let tracked_proc_mesh = inner.take().await.map_err(|e| {
238+
PyRuntimeError::new_err(format!("`ProcMesh` has already been stopped: {}", e))
239+
})?;
240+
let (proc_mesh, children) = tracked_proc_mesh.into_inner();
241+
242+
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
243+
children.discard_all().await?;
244+
245+
// Finally, take ownership of the inner proc mesh, which will allow dropping it.
246+
let _proc_mesh = proc_mesh.take().await?;
247+
248+
// Grab the alloc back from `ProcEvents` and use that to stop the mesh.
249+
let proc_events_taken = proc_events.take().await?;
250+
let mut alloc = proc_events_taken.into_inner().into_alloc();
251+
252+
alloc.stop_and_wait().await?;
253+
254+
anyhow::Ok(())
255+
}
230256
}
231257

232258
#[pymethods]
@@ -350,32 +376,28 @@ impl PyProcMesh {
350376
Ok(self.try_inner()?.shape().clone().into())
351377
}
352378

353-
fn stop<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
354-
let tracked_proc_mesh = self.inner.clone();
379+
fn stop_nonblocking<'py>(&self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
380+
// Clone the necessary fields from self to avoid capturing self in the async block
381+
let inner = self.inner.clone();
355382
let proc_events = self.proc_events.clone();
383+
356384
pyo3_async_runtimes::tokio::future_into_py(py, async move {
357-
async {
358-
// "Take" the proc mesh wrapper. Once we do, it should be impossible for new
359-
// actor meshes to be spawned.
360-
let (proc_mesh, children) = tracked_proc_mesh
361-
.take()
362-
.await
363-
.map_err(|_| PyRuntimeError::new_err("`ProcMesh` has already been stopped"))?
364-
.into_inner();
365-
// Now we discard all in-flight actor meshes. After this, the `ProcMesh` should be "unused".
366-
children.discard_all().await?;
367-
// Finally, take ownership of the inner proc mesh, which will allow dropping it.
368-
let _proc_mesh = proc_mesh.take().await?;
369-
// Grab the alloc back from `ProcEvents` and use that to stop the mesh.
370-
let mut alloc = proc_events.take().await?.into_inner().into_alloc();
371-
alloc.stop_and_wait().await?;
372-
373-
anyhow::Ok(())
374-
}
375-
.await?;
385+
Self::stop_mesh(inner, proc_events).await?;
376386
PyResult::Ok(())
377387
})
378388
}
389+
390+
fn stop_blocking<'py>(&self, py: Python<'py>) -> PyResult<()> {
391+
// Clone the necessary fields from self to avoid capturing self in the async block
392+
let inner = self.inner.clone();
393+
let proc_events = self.proc_events.clone();
394+
395+
signal_safe_block_on(py, async move {
396+
Self::stop_mesh(inner, proc_events)
397+
.await
398+
.map_err(|e| PyRuntimeError::new_err(format!("{}", e)))
399+
})?
400+
}
379401
}
380402

381403
/// A keepalive token that aborts a task only after the last clone

python/monarch/_rust_bindings/monarch_hyperactor/proc_mesh.pyi

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,12 +81,18 @@ class ProcMesh:
8181
"""
8282
...
8383

84-
async def stop(self) -> None:
84+
async def stop_nonblocking(self) -> None:
8585
"""
8686
Stop the proc mesh.
8787
"""
8888
...
8989

90+
def stop_blocking(self) -> None:
91+
"""
92+
Stop the proc mesh. Blocks until the mesh is fully stopped.
93+
"""
94+
...
95+
9096
def __repr__(self) -> str: ...
9197

9298
@final

python/monarch/_src/actor/proc_mesh.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,15 +285,25 @@ def logging_option(self, stream_to_client: bool = False) -> None:
285285
)
286286
self._logging_mesh_client.set_mode(stream_to_client)
287287

288-
async def stop(self) -> None:
289-
await self._proc_mesh.stop()
290-
self._stopped = True
291-
292288
async def __aenter__(self) -> "ProcMesh":
293289
if self._stopped:
294290
raise RuntimeError("`ProcMesh` has already been stopped")
295291
return self
296292

293+
def stop(self) -> Future[None]:
294+
async def _stop_nonblocking() -> None:
295+
await self._proc_mesh.stop_nonblocking()
296+
self._stopped = True
297+
298+
def _stop_blocking() -> None:
299+
self._proc_mesh.stop_blocking()
300+
self._stopped = True
301+
302+
return Future(
303+
lambda: _stop_nonblocking(),
304+
lambda: _stop_blocking(),
305+
)
306+
297307
async def __aexit__(
298308
self, exc_type: object, exc_val: object, exc_tb: object
299309
) -> None:

python/tests/test_allocator.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,24 @@ async def test_allocate_2d_mesh(self) -> None:
204204

205205
self.assert_computed_world_size(values, world_size)
206206

207+
async def test_stop_proc_mesh_blocking(self) -> None:
208+
spec = AllocSpec(AllocConstraints(), host=2, gpu=4)
209+
with remote_process_allocator() as host1, remote_process_allocator() as host2:
210+
allocator = RemoteAllocator(
211+
world_id="test_remote_allocator",
212+
initializer=StaticRemoteAllocInitializer(host1, host2),
213+
heartbeat_interval=_100_MILLISECONDS,
214+
)
215+
alloc = await allocator.allocate(spec)
216+
proc_mesh = await ProcMesh.from_alloc(alloc)
217+
actor = proc_mesh.spawn("test_actor", TestActor).get()
218+
proc_mesh.stop().get()
219+
with self.assertRaises(
220+
RuntimeError, msg="`ProcMesh` has already been stopped"
221+
):
222+
proc_mesh.spawn("test_actor", TestActor).get()
223+
del actor
224+
207225
async def test_stop_proc_mesh(self) -> None:
208226
spec = AllocSpec(AllocConstraints(), host=2, gpu=4)
209227

0 commit comments

Comments
 (0)