diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 49a645de..166692ab 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -451,7 +451,7 @@ pub(crate) mod test_util { cx: &Context, GetRank(ok, reply): GetRank, ) -> Result<(), anyhow::Error> { - let (rank, _) = cx.cast_info()?; + let (rank, _) = cx.cast_info(); reply.send(cx, rank)?; anyhow::ensure!(ok, "intentional error!"); // If `!ok` exit with `Err()`. Ok(()) diff --git a/hyperactor_mesh/src/comm/multicast.rs b/hyperactor_mesh/src/comm/multicast.rs index 8cc6068d..09f19165 100644 --- a/hyperactor_mesh/src/comm/multicast.rs +++ b/hyperactor_mesh/src/comm/multicast.rs @@ -196,32 +196,20 @@ pub fn set_cast_info_on_headers(headers: &mut Attrs, rank: usize, shape: Shape, } pub trait CastInfo { - /// Get the cast rank and cast shape, returning an error - /// if the relevant info isn't available. - fn cast_info(&self) -> anyhow::Result<(usize, Shape)>; - - /// Get the cast rank and cast shape, returning None - /// if the relevant info isn't available. - fn maybe_cast_info(&self) -> Option<(usize, Shape)>; + /// Get the cast rank and cast shape. + /// If something wasn't explicitly sent via a cast, then + /// we represent it as the only member of a 0-dimensonal cast shape, + /// which is the same as a singleton. + fn cast_info(&self) -> (usize, Shape); } impl CastInfo for Context<'_, A> { - fn cast_info(&self) -> anyhow::Result<(usize, Shape)> { + fn cast_info(&self) -> (usize, Shape) { let headers = self.headers(); - let rank = headers - .get(CAST_RANK) - .ok_or_else(|| anyhow::anyhow!("{} not found in headers", CAST_RANK.name()))?; - let shape = headers - .get(CAST_SHAPE) - .ok_or_else(|| anyhow::anyhow!("{} not found in headers", CAST_SHAPE.name()))? - .clone(); - Ok((*rank, shape)) - } - - fn maybe_cast_info(&self) -> Option<(usize, Shape)> { - let headers = self.headers(); - headers - .get(CAST_RANK) - .map(|rank| headers.get(CAST_SHAPE).map(|shape| (*rank, shape.clone())))? + match (headers.get(CAST_RANK), headers.get(CAST_SHAPE)) { + (Some(rank), Some(shape)) => (*rank, shape.clone()), + (None, None) => (0, Shape::unity()), + _ => panic!("Expected either both rank and shape or neither"), + } } } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 2a7bdcb9..3a50f95a 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -28,6 +28,7 @@ use hyperactor::message::Unbind; use hyperactor_mesh::comm::multicast::CastInfo; use monarch_types::PickledPyObject; use monarch_types::SerializablePyErr; +use ndslice::Shape; use pyo3::conversion::IntoPyObjectExt; use pyo3::exceptions::PyBaseException; use pyo3::exceptions::PyRuntimeError; @@ -408,34 +409,21 @@ impl Handler for PythonActor { let (sender, receiver) = oneshot::channel(); let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> { - let awaitable = match cx.maybe_cast_info() { - Some((rank, shape)) => self.actor.call_method( - py, - "handle_cast", - ( - mailbox, - rank, - PyShape::from(shape), - message, - PanicFlag { - sender: Some(sender), - }, - ), - None, - )?, - None => self.actor.call_method( - py, - "handle", - ( - mailbox, - message, - PanicFlag { - sender: Some(sender), - }, - ), - None, - )?, - }; + let (rank, shape) = cx.cast_info(); + let awaitable = self.actor.call_method( + py, + "handle", + ( + mailbox, + rank, + PyShape::from(shape), + message, + PanicFlag { + sender: Some(sender), + }, + ), + None, + )?; pyo3_async_runtimes::into_future_with_locals( self.get_task_locals(py), diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index f1438067..13dfa090 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -266,7 +266,7 @@ impl Handler for WorkerActor { cx: &hyperactor::Context, _: AssignRankMessage, ) -> anyhow::Result<()> { - let (rank, shape) = cx.cast_info()?; + let (rank, shape) = cx.cast_info(); self.rank = rank; self.respond_with_python_message = true; Python::with_gil(|py| { diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi index 2732494d..12f84386 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi @@ -179,9 +179,6 @@ class PanicFlag: class Actor(Protocol): async def handle( - self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag - ) -> None: ... - async def handle_cast( self, mailbox: Mailbox, rank: int, diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 22dfbe42..a1bb3c64 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -571,11 +571,6 @@ def __init__(self) -> None: self.instance: object | None = None async def handle( - self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag - ) -> None: - return await self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag) - - async def handle_cast( self, mailbox: Mailbox, rank: int, diff --git a/python/tests/_monarch/test_hyperactor.py b/python/tests/_monarch/test_hyperactor.py index 602b1867..23605075 100644 --- a/python/tests/_monarch/test_hyperactor.py +++ b/python/tests/_monarch/test_hyperactor.py @@ -27,11 +27,6 @@ class MyActor: async def handle( - self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag - ) -> None: - raise NotImplementedError() - - async def handle_cast( self, mailbox: Mailbox, rank: int, diff --git a/python/tests/_monarch/test_mailbox.py b/python/tests/_monarch/test_mailbox.py index 8027f9e9..6a7fb48a 100644 --- a/python/tests/_monarch/test_mailbox.py +++ b/python/tests/_monarch/test_mailbox.py @@ -121,11 +121,6 @@ async def recv_message() -> str: class MyActor: async def handle( - self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag - ) -> None: - return None - - async def handle_cast( self, mailbox: Mailbox, rank: int,