From 5abbc77c359f2f16db61a94b9ea9ac2347122352 Mon Sep 17 00:00:00 2001 From: Peng Zhang Date: Tue, 15 Jul 2025 07:52:56 -0700 Subject: [PATCH] Add PythonActorMeshRef (#531) Summary: `PythonActorMeshRef` is used to bind Rust `struct ActorMeshRef` to the python side. Reviewed By: shayne-fletcher Differential Revision: D78284705 --- hyperactor_mesh/src/actor_mesh.rs | 1 - hyperactor_mesh/src/reference.rs | 36 +----- monarch_hyperactor/src/actor_mesh.rs | 77 ++++++++++++ monarch_hyperactor/src/shape.rs | 9 ++ .../monarch_hyperactor/actor_mesh.pyi | 29 +++++ .../monarch_hyperactor/bootstrap.pyi | 2 + .../monarch_hyperactor/selection.pyi | 2 + .../monarch_hyperactor/shape.pyi | 9 +- .../monarch_hyperactor/telemetry.pyi | 2 + python/tests/_monarch/test_actor_mesh.py | 110 ++++++++++++++++++ 10 files changed, 242 insertions(+), 35 deletions(-) create mode 100644 python/tests/_monarch/test_actor_mesh.py diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 81ca8bae..da1d8c49 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -145,7 +145,6 @@ pub trait ActorMesh: Mesh { self.name().to_string(), ), self.shape().clone(), - self.proc_mesh().shape().clone(), self.proc_mesh().comm_actor().clone(), ) } diff --git a/hyperactor_mesh/src/reference.rs b/hyperactor_mesh/src/reference.rs index 83c7f795..625b50ce 100644 --- a/hyperactor_mesh/src/reference.rs +++ b/hyperactor_mesh/src/reference.rs @@ -75,8 +75,6 @@ pub struct ActorMeshId(pub ProcMeshId, pub String); pub struct ActorMeshRef { pub(crate) mesh_id: ActorMeshId, shape: Shape, - /// The shape of the underlying Proc Mesh. - proc_mesh_shape: Shape, /// The reference to the comm actor of the underlying Proc Mesh. comm_actor_ref: ActorRef, phantom: PhantomData, @@ -90,13 +88,11 @@ impl ActorMeshRef { pub(crate) fn attest( mesh_id: ActorMeshId, shape: Shape, - proc_mesh_shape: Shape, comm_actor_ref: ActorRef, ) -> Self { Self { mesh_id, shape, - proc_mesh_shape, comm_actor_ref, phantom: PhantomData, } @@ -142,7 +138,6 @@ impl Clone for ActorMeshRef { Self { mesh_id: self.mesh_id.clone(), shape: self.shape.clone(), - proc_mesh_shape: self.proc_mesh_shape.clone(), comm_actor_ref: self.comm_actor_ref.clone(), phantom: PhantomData, } @@ -161,12 +156,11 @@ impl Eq for ActorMeshRef {} mod tests { use async_trait::async_trait; use hyperactor::Actor; + use hyperactor::Bind; use hyperactor::Context; use hyperactor::Handler; use hyperactor::PortRef; - use hyperactor::message::Bind; - use hyperactor::message::Bindings; - use hyperactor::message::Unbind; + use hyperactor::Unbind; use hyperactor_mesh_macros::sel; use ndslice::shape; @@ -183,11 +177,11 @@ mod tests { shape! { replica = 4 } } - #[derive(Debug, Serialize, Deserialize, Named, Clone)] + #[derive(Debug, Serialize, Deserialize, Named, Clone, Bind, Unbind)] struct MeshPingPongMessage( /*ttl:*/ u64, ActorMeshRef, - /*completed port:*/ PortRef, + /*completed port:*/ #[binding(include)] PortRef, ); #[derive(Debug, Clone)] @@ -203,7 +197,6 @@ mod tests { struct MeshPingPongActorParams { mesh_id: ActorMeshId, shape: Shape, - proc_mesh_shape: Shape, comm_actor_ref: ActorRef, } @@ -213,12 +206,7 @@ mod tests { async fn new(params: Self::Params) -> Result { Ok(Self { - mesh_ref: ActorMeshRef::attest( - params.mesh_id, - params.shape, - params.proc_mesh_shape, - params.comm_actor_ref, - ), + mesh_ref: ActorMeshRef::attest(params.mesh_id, params.shape, params.comm_actor_ref), }) } } @@ -240,18 +228,6 @@ mod tests { } } - impl Unbind for MeshPingPongMessage { - fn unbind(&self, bindings: &mut Bindings) -> anyhow::Result<()> { - self.2.unbind(bindings) - } - } - - impl Bind for MeshPingPongMessage { - fn bind(&mut self, bindings: &mut Bindings) -> anyhow::Result<()> { - self.2.bind(bindings) - } - } - #[tokio::test] async fn test_inter_mesh_ping_pong() { let alloc_ping = LocalAllocator @@ -278,7 +254,6 @@ mod tests { "ping".to_string(), ), shape: ping_proc_mesh.shape().clone(), - proc_mesh_shape: ping_proc_mesh.shape().clone(), comm_actor_ref: ping_proc_mesh.comm_actor().clone(), }, ) @@ -296,7 +271,6 @@ mod tests { "pong".to_string(), ), shape: pong_proc_mesh.shape().clone(), - proc_mesh_shape: pong_proc_mesh.shape().clone(), comm_actor_ref: pong_proc_mesh.comm_actor().clone(), }, ) diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index f5d63923..65ceabfa 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -17,13 +17,18 @@ use hyperactor_mesh::Mesh; use hyperactor_mesh::RootActorMesh; use hyperactor_mesh::actor_mesh::ActorMesh; use hyperactor_mesh::actor_mesh::ActorSupervisionEvents; +use hyperactor_mesh::reference::ActorMeshRef; use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::shared_cell::SharedCellRef; use pyo3::exceptions::PyEOFError; use pyo3::exceptions::PyException; +use pyo3::exceptions::PyNotImplementedError; use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::types::PyBytes; +use serde::Deserialize; +use serde::Serialize; use tokio::sync::Mutex; use crate::actor::PythonActor; @@ -101,6 +106,14 @@ impl PythonActorMesh { .borrow() .map_err(|_| PyRuntimeError::new_err("`PythonActorMesh` has already been stopped")) } + + fn pickling_err(&self) -> PyErr { + PyErr::new::( + "PythonActorMesh cannot be pickled. If applicable, use bind() \ + to get a PythonActorMeshRef, and use that instead." + .to_string(), + ) + } } #[pymethods] @@ -123,6 +136,11 @@ impl PythonActorMesh { Ok(()) } + fn bind(&self) -> PyResult { + let mesh = self.try_inner()?; + Ok(PythonActorMeshRef { inner: mesh.bind() }) + } + fn get_supervision_event(&self) -> PyResult> { let unhealthy_event = self .unhealthy_event @@ -169,6 +187,64 @@ impl PythonActorMesh { fn shape(&self) -> PyResult { Ok(PyShape::from(self.try_inner()?.shape().clone())) } + + // Override the pickling methods to provide a meaningful error message. + fn __reduce__(&self) -> PyResult<()> { + Err(self.pickling_err()) + } + + fn __reduce_ex__(&self, _proto: u8) -> PyResult<()> { + Err(self.pickling_err()) + } +} + +#[pyclass( + frozen, + name = "PythonActorMeshRef", + module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" +)] +#[derive(Debug, Serialize, Deserialize)] +pub(super) struct PythonActorMeshRef { + inner: ActorMeshRef, +} + +#[pymethods] +impl PythonActorMeshRef { + fn cast( + &self, + client: &PyMailbox, + selection: &PySelection, + message: &PythonMessage, + ) -> PyResult<()> { + self.inner + .cast(&client.inner, selection.inner().clone(), message.clone()) + .map_err(|err| PyException::new_err(err.to_string()))?; + Ok(()) + } + + #[getter] + fn shape(&self) -> PyShape { + PyShape::from(self.inner.shape().clone()) + } + + #[staticmethod] + fn from_bytes(bytes: &Bound<'_, PyBytes>) -> PyResult { + bincode::deserialize(bytes.as_bytes()) + .map_err(|e| PyErr::new::(e.to_string())) + } + + fn __reduce__<'py>( + slf: &Bound<'py, Self>, + ) -> PyResult<(Bound<'py, PyAny>, (Bound<'py, PyBytes>,))> { + let bytes = bincode::serialize(&*slf.borrow()) + .map_err(|e| PyErr::new::(e.to_string()))?; + let py_bytes = PyBytes::new(slf.py(), &bytes); + Ok((slf.as_any().getattr("from_bytes")?, (py_bytes,))) + } + + fn __repr__(&self) -> String { + format!("{:?}", self) + } } impl Drop for PythonActorMesh { @@ -379,6 +455,7 @@ impl From for PyActorSupervisionEvent { pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; diff --git a/monarch_hyperactor/src/shape.rs b/monarch_hyperactor/src/shape.rs index b01d391c..ad75ea5f 100644 --- a/monarch_hyperactor/src/shape.rs +++ b/monarch_hyperactor/src/shape.rs @@ -21,6 +21,7 @@ use crate::ndslice::PySlice; module = "monarch._rust_bindings.monarch_hyperactor.shape", frozen )] +#[derive(Clone)] pub struct PyShape { pub(super) inner: Shape, } @@ -112,6 +113,14 @@ impl PyShape { self.inner.slice().len() } + fn __eq__(&self, other: &Bound<'_, PyAny>) -> PyResult { + if let Ok(other) = other.extract::() { + Ok(self.inner == other.inner) + } else { + Ok(false) + } + } + #[staticmethod] fn unity() -> PyShape { Shape::unity().into() diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi index 70ad89cd..18f3eb0c 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi @@ -6,6 +6,7 @@ # pyre-strict +from collections.abc import Mapping from typing import AsyncIterator, final from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage @@ -18,8 +19,36 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.selection import Selection from monarch._rust_bindings.monarch_hyperactor.shape import Shape +@final +class PythonActorMeshRef: + """ + A reference to a remote actor mesh over which PythonMessages can be sent. + """ + + def cast( + self, mailbox: Mailbox, selection: Selection, message: PythonMessage + ) -> None: + """Cast a message to the selected actors in the mesh.""" + ... + + @property + def shape(self) -> Shape: + """ + The Shape object that describes how the rank of an actor + retrieved with get corresponds to coordinates in the + mesh. + """ + ... + @final class PythonActorMesh: + def bind(self) -> PythonActorMeshRef: + """ + Bind this actor mesh. The returned mesh ref can be used to reach the + mesh remotely. + """ + ... + def cast(self, selection: Selection, message: PythonMessage) -> None: """ Cast a message to the selected actors in the mesh. diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi index bf0db8cc..6e5e582d 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/bootstrap.pyi @@ -4,4 +4,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + def bootstrap_main() -> None: ... diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi index 5882279e..1c566029 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/selection.pyi @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from typing import final @final diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi index b13de2f6..d221556d 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi @@ -4,8 +4,10 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import collections.abc -from typing import Dict, final, Iterator, List, overload, Sequence +from typing import Any, Dict, final, Iterator, List, overload, Sequence @final class Slice: @@ -71,11 +73,11 @@ class Slice: def __eq__(self, value: object) -> bool: ... def __hash__(self) -> int: ... - def __getnewargs_ex__(self) -> tuple[tuple, dict]: ... + def __getnewargs_ex__(self) -> tuple[tuple[Any], dict[Any, Any]]: ... @overload def __getitem__(self, i: int) -> int: ... @overload - def __getitem__(self, i: slice) -> tuple[int, ...]: ... + def __getitem__(self, i: slice[Any, Any, Any]) -> tuple[int, ...]: ... def __len__(self) -> int: """Returns the complete size of the slice.""" ... @@ -136,6 +138,7 @@ class Shape: """ ... def __len__(self) -> int: ... + def __eq__(self, value: object) -> bool: ... @staticmethod def unity() -> "Shape": ... diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi index a9c6fc87..36ec0a21 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/telemetry.pyi @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + import logging def forward_to_tracing(record: logging.LogRecord) -> None: diff --git a/python/tests/_monarch/test_actor_mesh.py b/python/tests/_monarch/test_actor_mesh.py new file mode 100644 index 00000000..b90c3803 --- /dev/null +++ b/python/tests/_monarch/test_actor_mesh.py @@ -0,0 +1,110 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +import pickle +from typing import List + +import monarch +import pytest + +from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage +from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ( + PythonActorMesh, + PythonActorMeshRef, +) + +from monarch._rust_bindings.monarch_hyperactor.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension + AllocConstraints, + AllocSpec, +) + +from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox, PortReceiver +from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcMesh +from monarch._rust_bindings.monarch_hyperactor.selection import Selection +from monarch._rust_bindings.monarch_hyperactor.shape import Shape + + +async def allocate() -> ProcMesh: + spec = AllocSpec(AllocConstraints(), replica=2, gpus=3, hosts=8) + allocator = monarch.LocalAllocator() + alloc = await allocator.allocate(spec) + proc_mesh = await ProcMesh.allocate_nonblocking(alloc) + return proc_mesh + + +class MyActor: + async def handle( + self, + mailbox: Mailbox, + rank: int, + shape: Shape, + message: PythonMessage, + panic_flag: PanicFlag, + ) -> None: + assert rank is not None + reply_port = message.response_port + assert reply_port is not None + reply_port.send( + mailbox, PythonMessage("pong", pickle.dumps(f"rank: {rank}"), None, rank) + ) + + +async def test_bind_and_pickling() -> None: + proc_mesh = await allocate() + actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + with pytest.raises(NotImplementedError, match="use bind()"): + pickle.dumps(actor_mesh) + + actor_mesh_ref = actor_mesh.bind() + assert actor_mesh_ref.shape == actor_mesh.shape + obj = pickle.dumps(actor_mesh_ref) + unpickled = pickle.loads(obj) + assert repr(actor_mesh_ref) == repr(unpickled) + assert actor_mesh_ref.shape == unpickled.shape + + +async def verify_cast( + actor_mesh: PythonActorMesh | PythonActorMeshRef, + mailbox: Mailbox, + cast_ranks: List[int], +) -> None: + receiver: PortReceiver + handle, receiver = mailbox.open_port() + port_ref = handle.bind() + + message = PythonMessage("echo", pickle.dumps("ping"), port_ref, None) + sel = Selection.from_string("*") + if isinstance(actor_mesh, PythonActorMesh): + actor_mesh.cast(sel, message) + elif isinstance(actor_mesh, PythonActorMeshRef): + actor_mesh.cast(mailbox, sel, message) + + rcv_ranks = [] + for _ in range(len(cast_ranks)): + message = await receiver.recv() + rank = message.rank + assert rank is not None + rcv_ranks.append(rank) + rcv_ranks.sort() + for i in cast_ranks: + assert rcv_ranks[i] == i + + +@pytest.mark.timeout(30) +async def test_cast_handle() -> None: + proc_mesh = await allocate() + actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + await verify_cast(actor_mesh, proc_mesh.client, list(range(2 * 3 * 8))) + + +@pytest.mark.timeout(30) +async def test_cast_ref() -> None: + proc_mesh = await allocate() + actor_mesh = await proc_mesh.spawn_nonblocking("test", MyActor) + actor_mesh_ref = actor_mesh.bind() + await verify_cast(actor_mesh_ref, proc_mesh.client, list(range(2 * 3 * 8)))