diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index c36057c3..69bb7ca1 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -45,7 +45,6 @@ use crate::CommActor; use crate::Mesh; use crate::comm::multicast::CastMessage; use crate::comm::multicast::CastMessageEnvelope; -use crate::comm::multicast::DestinationPort; use crate::comm::multicast::Uslice; use crate::comm::multicast::set_cast_info_on_headers; use crate::metrics; @@ -54,21 +53,21 @@ use crate::reference::ActorMeshId; use crate::reference::ActorMeshRef; use crate::reference::ProcMeshId; -/// Common implementation for ActorMeshes and ActorMeshRefs to cast an [`M`]-typed message +/// Common implementation for `ActorMesh`s and `ActorMeshRef`s to cast +/// an `M`-typed message #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`. -pub(crate) fn actor_mesh_cast( +pub(crate) fn actor_mesh_cast( caps: &impl cap::CanSend, actor_mesh_id: ActorMeshId, actor_mesh_shape: &Shape, - _proc_mesh_shape: &Shape, - actor_name: &str, sender: &ActorId, comm_actor_ref: &ActorRef, selection: Selection, message: M, ) -> Result<(), CastError> where - A: RemoteHandles + RemoteHandles>, + A: RemoteActor + RemoteHandles>, + M: Castable + RemoteMessage, { let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!( "message_type" => M::typename(), @@ -76,10 +75,9 @@ where )); let slice = actor_mesh_shape.slice().clone(); - let message = CastMessageEnvelope::new( + let message = CastMessageEnvelope::new::( actor_mesh_id, sender.clone(), - DestinationPort::new::(actor_name.to_string()), actor_mesh_shape.clone(), message, None, // TODO: reducer typehash @@ -101,30 +99,29 @@ pub trait ActorMesh: Mesh { /// The type of actor in the mesh. type Actor: RemoteActor; - /// Cast an [`M`]-typed message to the ranks selected by `sel` - /// in this ActorMesh. + /// Cast an `M`-typed message to the ranks selected by `sel` in + /// this ActorMesh. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`. - fn cast(&self, selection: Selection, message: M) -> Result<(), CastError> + fn cast(&self, selection: Selection, message: M) -> Result<(), CastError> where - Self::Actor: RemoteHandles + RemoteHandles>, + Self::Actor: RemoteHandles>, + M: Castable + RemoteMessage, { - actor_mesh_cast::( - self.proc_mesh().client(), - self.id(), - self.shape(), - self.proc_mesh().shape(), - self.name(), - self.proc_mesh().client().actor_id(), - self.proc_mesh().comm_actor(), - selection, - message, + actor_mesh_cast::( + self.proc_mesh().client(), // send capability + self.id(), // actor mesh id (destination mesh) + self.shape(), // actor mesh shape + self.proc_mesh().client().actor_id(), // sender + self.proc_mesh().comm_actor(), // comm actor + selection, // the selected actors + message, // the message ) } /// The ProcMesh on top of which this actor mesh is spawned. fn proc_mesh(&self) -> &ProcMesh; - /// The name global name of actors in this mesh. + /// The name given to the actors in this mesh. fn name(&self) -> &str; fn world_id(&self) -> &WorldId { diff --git a/hyperactor_mesh/src/comm/multicast.rs b/hyperactor_mesh/src/comm/multicast.rs index 56e0267e..371f8e39 100644 --- a/hyperactor_mesh/src/comm/multicast.rs +++ b/hyperactor_mesh/src/comm/multicast.rs @@ -62,19 +62,23 @@ pub struct CastMessageEnvelope { impl CastMessageEnvelope { /// Create a new CastMessageEnvelope. - pub fn new( + pub fn new( actor_mesh_id: ActorMeshId, sender: ActorId, - dest_port: DestinationPort, shape: Shape, - message: T, + message: M, reducer_typehash: Option, - ) -> Result { + ) -> Result + where + A: RemoteActor + RemoteHandles>, + M: Castable + RemoteMessage, + { let data = ErasedUnbound::try_from_message(message)?; + let actor_name = actor_mesh_id.1.to_string(); Ok(Self { actor_mesh_id, sender, - dest_port, + dest_port: DestinationPort::new::(actor_name), data, reducer_typehash, shape, diff --git a/hyperactor_mesh/src/reference.rs b/hyperactor_mesh/src/reference.rs index 2777e2d2..83c7f795 100644 --- a/hyperactor_mesh/src/reference.rs +++ b/hyperactor_mesh/src/reference.rs @@ -14,6 +14,7 @@ use std::marker::PhantomData; use hyperactor::ActorRef; use hyperactor::Named; use hyperactor::RemoteHandles; +use hyperactor::RemoteMessage; use hyperactor::actor::RemoteActor; use hyperactor::cap; use hyperactor::message::Castable; @@ -54,7 +55,7 @@ macro_rules! mesh_id { )] pub struct ProcMeshId(pub String); -/// Actor Mesh ID. Tuple of the ProcMesh ID and Actor Mesh ID. +/// Actor Mesh ID. Tuple of the ProcMesh ID and actor name. #[derive( Debug, Serialize, @@ -111,19 +112,10 @@ impl ActorMeshRef { &self.shape } - /// Shape of the underlying Proc Mesh. - fn proc_mesh_shape(&self) -> &Shape { - &self.proc_mesh_shape - } - - fn name(&self) -> &str { - &self.mesh_id.1 - } - /// Cast an [`M`]-typed message to the ranks selected by `sel` /// in this ActorMesh. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`. - pub fn cast( + pub fn cast( &self, caps: &(impl cap::CanSend + cap::CanOpenPort), selection: Selection, @@ -131,13 +123,12 @@ impl ActorMeshRef { ) -> Result<(), CastError> where A: RemoteHandles + RemoteHandles>, + M: Castable + RemoteMessage, { - actor_mesh_cast::( + actor_mesh_cast::( caps, self.mesh_id.clone(), self.shape(), - self.proc_mesh_shape(), - self.name(), caps.mailbox().actor_id(), &self.comm_actor_ref, selection, diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index a440032b..336148af 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -7,7 +7,6 @@ */ use std::collections::BTreeMap; -use std::collections::BTreeSet; use std::collections::HashMap; use std::collections::HashSet; use std::collections::VecDeque; @@ -55,9 +54,9 @@ use monarch_messages::worker::WorkerMessage; use monarch_messages::worker::WorkerParams; use monarch_tensor_worker::AssignRankMessage; use monarch_tensor_worker::WorkerActor; -use ndslice::Selection; use ndslice::Slice; use ndslice::selection; +use ndslice::selection::ReifyView; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; use tokio::sync::Mutex; @@ -808,19 +807,8 @@ impl Handler for MeshControllerActor { ) -> anyhow::Result<()> { match message { ClientToControllerMessage::Send { slices, message } => { - let selection = slices - .iter() - .map(|slice| { - Selection::of_ranks( - self.workers().shape().slice(), - &slice.iter().collect::>(), - ) - }) - .collect::, _>>()? - .into_iter() - .reduce(selection::dsl::union) - .unwrap_or_else(selection::dsl::false_); - self.workers().cast(selection, message.clone())?; + let sel = self.workers().shape().slice().reify_views(slices)?; + self.workers().cast(sel, message)?; } ClientToControllerMessage::Node { seq, diff --git a/ndslice/src/selection.rs b/ndslice/src/selection.rs index c2bbe2ab..5495f8df 100644 --- a/ndslice/src/selection.rs +++ b/ndslice/src/selection.rs @@ -1013,7 +1013,7 @@ pub trait ReifyView: sealed::Sealed { /// Reify multiple views as a union of selections in the /// coordinate system of `self`. - fn reify_views(&self, views: &[&Slice]) -> Result; + fn reify_views>(&self, views: V) -> Result; } impl ReifyView for Slice { @@ -1030,7 +1030,6 @@ impl ReifyView for Slice { /// # Errors /// /// Returns an error if: - /// - The number of dimensions in the view does not match the base /// - The view lies outside the bounds of the base slice /// /// # Example @@ -1044,16 +1043,14 @@ impl ReifyView for Slice { /// let selection = base.reify_view(view).unwrap(); /// ``` fn reify_view(&self, view: &Slice) -> Result { - if view.num_dim() != self.num_dim() { - return Err(SliceError::InvalidDims { - expected: self.num_dim(), - got: view.num_dim(), - }); - } if view.is_empty() { return Ok(dsl::false_()); } + if view.num_dim() != self.num_dim() { + return Selection::of_ranks(self, &view.iter().collect::>()); + } + let origin = self.coordinates(view.offset())?; let mut acc = dsl::true_(); for (&start, &len) in origin.iter().zip(view.sizes()).rev() { @@ -1076,7 +1073,6 @@ impl ReifyView for Slice { /// # Errors /// /// Returns an error if any view: - /// - Has a different number of dimensions than the base slice /// - Refers to coordinates not contained within the base /// /// # Example @@ -1087,15 +1083,22 @@ impl ReifyView for Slice { /// let shape = ndslice::shape!(x = 4, y = 4); /// let base = shape.slice(); /// - /// let a = ndslice::select!(shape, x = 0..2, y = 0..2).unwrap(); - /// let b = ndslice::select!(shape, x = 2..4, y = 2..4).unwrap(); + /// let a = ndslice::select!(shape, x = 0..2, y = 0..2) + /// .unwrap() + /// .slice() + /// .clone(); + /// let b = ndslice::select!(shape, x = 2..4, y = 2..4) + /// .unwrap() + /// .slice() + /// .clone(); /// - /// let sel = base.reify_views(&[a.slice(), b.slice()]).unwrap(); + /// let sel = base.reify_views(&[a, b]).unwrap(); /// ``` - fn reify_views(&self, views: &[&Slice]) -> Result { + fn reify_views>(&self, views: V) -> Result { + let views = views.as_ref(); let mut selections = Vec::with_capacity(views.len()); - for &view in views { + for view in views { if view.is_empty() { continue; } @@ -2157,6 +2160,28 @@ mod tests { ); } + #[test] + fn test_reify_view_dimension_mismatch() { + let shape = shape!(host = 2, gpu = 4); + let base = shape.slice(); + + // Select the 3rd GPU (index 2) across both hosts i.e. flat + // indices [2, 6] + let indices = vec![ + base.location(&[0, 2]).unwrap(), + base.location(&[1, 2]).unwrap(), + ]; + + let view = Slice::new(indices[0], vec![indices.len()], vec![4]).unwrap(); + let selection = base.reify_view(&view).unwrap(); + + let expected = Selection::of_ranks(base, &indices.iter().cloned().collect()).unwrap(); + assert_structurally_eq!(&selection, expected); + + let actual: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect(); + assert_eq!(actual, indices); + } + #[test] fn test_union_of_slices_empty() { let base = Slice::new_row_major([2]); @@ -2175,7 +2200,7 @@ mod tests { let shape = shape!(x = 3); let base = shape.slice(); let selected = select!(shape, x = 1).unwrap(); - let view = selected.slice(); + let view = selected.slice().clone(); let selection = base.reify_views(&[view]).unwrap(); let expected = range(1..=1, true_()); @@ -2197,11 +2222,11 @@ mod tests { // View A: (0, *) let a = select!(shape, x = 0).unwrap(); - let view_a = a.slice(); + let view_a = a.slice().clone(); // View B: (1, *) let b = select!(shape, x = 1).unwrap(); - let view_b = b.slice(); + let view_b = b.slice().clone(); let selection = base.reify_views(&[view_a, view_b]).unwrap(); let expected = union( @@ -2224,10 +2249,10 @@ mod tests { let base = shape.slice(); let selected1 = select!(shape, y = 0..2).unwrap(); - let view1 = selected1.slice(); + let view1 = selected1.slice().clone(); let selected2 = select!(shape, y = 1..4).unwrap(); - let view2 = selected2.slice(); + let view2 = selected2.slice().clone(); let selection = base.reify_views(&[view1, view2]).unwrap(); let expected = union(