From 7a8fb6ab0ae9d80ce39b649f4030ab6abd7b8ca9 Mon Sep 17 00:00:00 2001 From: Shayne Fletcher Date: Tue, 8 Jul 2025 08:18:53 -0700 Subject: [PATCH] actor_mesh: use commactor in cast_slices (#461) Summary: Pull Request resolved: https://github.com/pytorch-labs/monarch/pull/461 refactor `cast_slices`. the old approach iterated over the input slices and for each rank in each slice, sent a message. now, call `reify_views` on the input slices to produce a single `Selection` (union of ranges). then, cast the message to the selected ranks via the comm actor. Differential Revision: D77935610 --- hyperactor_mesh/src/actor_mesh.rs | 35 +++++++++++++------------------ 1 file changed, 14 insertions(+), 21 deletions(-) diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 49a645de..0258521f 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -25,7 +25,6 @@ use hyperactor::RemoteMessage; use hyperactor::Unbind; use hyperactor::WorldId; use hyperactor::actor::RemoteActor; -use hyperactor::attrs::Attrs; use hyperactor::cap; use hyperactor::mailbox::MailboxSenderError; use hyperactor::mailbox::PortReceiver; @@ -49,7 +48,6 @@ use crate::comm::multicast::CastMessage; use crate::comm::multicast::CastMessageEnvelope; use crate::comm::multicast::DestinationPort; use crate::comm::multicast::Uslice; -use crate::comm::multicast::set_cast_info_on_headers; use crate::metrics; use crate::proc_mesh::ProcMesh; use crate::reference::ActorMeshId; @@ -237,32 +235,26 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> { /// Until the selection logic is more powerful, we need a way to /// replicate the send patterns that the worker actor mesh actually does. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`. - pub fn cast_slices( - &self, - sel: Vec, - message: M, - ) -> Result<(), CastError> + pub fn cast_slices(&self, sel: Vec, message: M) -> Result<(), CastError> where + M: Castable + Clone, A: RemoteHandles + RemoteHandles>, { let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!( "message_type" => M::typename(), "message_variant" => message.arm().unwrap_or_default(), )); - for ref slice in sel { - for rank in slice.iter() { - let mut headers = Attrs::new(); - set_cast_info_on_headers( - &mut headers, - rank, - self.shape().clone(), - self.proc_mesh.client().actor_id().clone(), - ); - self.ranks[rank] - .send_with_headers(self.proc_mesh.client(), headers, message.clone()) - .map_err(|err| CastError::MailboxSenderError(rank, err))?; - } - } + + let slices: &[&Slice] = &sel.iter().collect::>(); + let selection = self + .proc_mesh + .shape() + .slice() + .reify_views(slices) + .expect("invalid slices"); + + self.cast(selection, message)?; + Ok(()) } @@ -528,6 +520,7 @@ mod tests { use $crate::sel_from_shape; use $crate::sel; use $crate::proc_mesh::SharedSpawnable; + use $crate::comm::multicast::set_cast_info_on_headers; use std::collections::VecDeque; use hyperactor::data::Serialized;