Skip to content

Commit a007830

Browse files
dulinrileyfacebook-github-bot
authored andcommitted
Add async context manager to ProcMesh for stop (#447)
Summary: Pull Request resolved: #447 Now that a ProcMesh has a `stop` function which closes the external processes, add `__aenter__` and `__aexit__` functions to be used as an async context manager. Reviewed By: pablorfb-meta Differential Revision: D77877888 fbshipit-source-id: ab2260725b4800f265160707a0a5885d4d2385f6
1 parent 2622236 commit a007830

File tree

2 files changed

+85
-0
lines changed

2 files changed

+85
-0
lines changed

python/monarch/proc_mesh.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ def __init__(
8888
self._rsync_mesh_client: Optional[RsyncMeshClient] = None
8989
self._auto_reload_actor: Optional[AutoReloadActor] = None
9090
self._maybe_device_mesh: Optional[DeviceMesh] = _device_mesh
91+
self._stopped = False
9192
if _mock_shape is None:
9293
self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager)
9394

@@ -247,6 +248,33 @@ async def sync_workspace(self, auto_reload: bool = False) -> None:
247248

248249
async def stop(self) -> None:
249250
await self._proc_mesh.stop()
251+
self._stopped = True
252+
253+
async def __aenter__(self) -> "ProcMesh":
254+
if self._stopped:
255+
raise RuntimeError("`ProcMesh` has already been stopped")
256+
return self
257+
258+
async def __aexit__(
259+
self, exc_type: object, exc_val: object, exc_tb: object
260+
) -> None:
261+
# In case there are multiple nested "async with" statements, we only
262+
# want it to close once.
263+
if not self._stopped:
264+
await self.stop()
265+
266+
# Finalizer to check if the proc mesh was closed properly.
267+
def __del__(self) -> None:
268+
if not self._stopped:
269+
import warnings
270+
271+
warnings.warn(
272+
f"unstopped ProcMesh {self!r}",
273+
ResourceWarning,
274+
stacklevel=2,
275+
source=self,
276+
)
277+
# Cannot call stop here because it is async.
250278

251279

252280
async def local_proc_mesh_nonblocking(

python/tests/test_allocator.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,63 @@ async def test_stop_proc_mesh(self) -> None:
218218
# now we doing casting without accessing the wrapped type.
219219
del actor
220220

221+
async def test_stop_proc_mesh_context_manager(self) -> None:
222+
spec = AllocSpec(AllocConstraints(), host=2, gpu=4)
223+
224+
# create 2x process-allocators (on their own bind addresses) to simulate 2 hosts
225+
with remote_process_allocator() as host1, remote_process_allocator() as host2:
226+
allocator = RemoteAllocator(
227+
world_id="test_remote_allocator",
228+
initializer=StaticRemoteAllocInitializer(host1, host2),
229+
heartbeat_interval=_100_MILLISECONDS,
230+
)
231+
alloc = await allocator.allocate(spec)
232+
proc_mesh = await ProcMesh.from_alloc(alloc)
233+
with self.assertRaises(ValueError, msg="foo"):
234+
async with proc_mesh:
235+
actor = await proc_mesh.spawn("test_actor", TestActor)
236+
# Ensure that proc mesh is stopped when context manager exits.
237+
raise ValueError("foo")
238+
239+
with self.assertRaises(
240+
RuntimeError, msg="`ProcMesh` has already been stopped"
241+
):
242+
await proc_mesh.spawn("test_actor", TestActor)
243+
244+
# TODO(agallagher): It'd be nice to test that this just fails
245+
# immediately, trying to access the wrapped actor mesh, but right
246+
# now we doing casting without accessing the wrapped type.
247+
del actor
248+
249+
async def test_stop_proc_mesh_context_manager_multiple_times(self) -> None:
250+
spec = AllocSpec(AllocConstraints(), host=2, gpu=4)
251+
252+
# create 2x process-allocators (on their own bind addresses) to simulate 2 hosts
253+
with remote_process_allocator() as host1, remote_process_allocator() as host2:
254+
allocator = RemoteAllocator(
255+
world_id="test_remote_allocator",
256+
initializer=StaticRemoteAllocInitializer(host1, host2),
257+
heartbeat_interval=_100_MILLISECONDS,
258+
)
259+
alloc = await allocator.allocate(spec)
260+
proc_mesh = await ProcMesh.from_alloc(alloc)
261+
# We can nest multiple context managers on the same mesh, the innermost
262+
# one closes the mesh and it cannot be used after that.
263+
async with proc_mesh:
264+
async with proc_mesh:
265+
actor = await proc_mesh.spawn("test_actor", TestActor)
266+
267+
with self.assertRaises(
268+
RuntimeError, msg="`ProcMesh` has already been stopped"
269+
):
270+
await proc_mesh.spawn("test_actor", TestActor)
271+
# Exiting a second time should not raise an error.
272+
273+
# TODO(agallagher): It'd be nice to test that this just fails
274+
# immediately, trying to access the wrapped actor mesh, but right
275+
# now we doing casting without accessing the wrapped type.
276+
del actor
277+
221278
async def test_stacked_1d_meshes(self) -> None:
222279
# create two stacked actor meshes on the same host
223280
# each actor mesh running on separate process-allocators

0 commit comments

Comments
 (0)