Skip to content

: selection: handle dim mismatch in reify_view #482

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 20 additions & 23 deletions hyperactor_mesh/src/actor_mesh.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ use crate::CommActor;
use crate::Mesh;
use crate::comm::multicast::CastMessage;
use crate::comm::multicast::CastMessageEnvelope;
use crate::comm::multicast::DestinationPort;
use crate::comm::multicast::Uslice;
use crate::comm::multicast::set_cast_info_on_headers;
use crate::metrics;
Expand All @@ -54,32 +53,31 @@ use crate::reference::ActorMeshId;
use crate::reference::ActorMeshRef;
use crate::reference::ProcMeshId;

/// Common implementation for ActorMeshes and ActorMeshRefs to cast an [`M`]-typed message
/// Common implementation for `ActorMesh`s and `ActorMeshRef`s to cast
/// an `M`-typed message
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
pub(crate) fn actor_mesh_cast<M: Castable + Clone, A>(
pub(crate) fn actor_mesh_cast<A, M>(
caps: &impl cap::CanSend,
actor_mesh_id: ActorMeshId,
actor_mesh_shape: &Shape,
_proc_mesh_shape: &Shape,
actor_name: &str,
sender: &ActorId,
comm_actor_ref: &ActorRef<CommActor>,
selection: Selection,
message: M,
) -> Result<(), CastError>
where
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
"message_type" => M::typename(),
"message_variant" => message.arm().unwrap_or_default(),
));

let slice = actor_mesh_shape.slice().clone();
let message = CastMessageEnvelope::new(
let message = CastMessageEnvelope::new::<A, M>(
actor_mesh_id,
sender.clone(),
DestinationPort::new::<A, M>(actor_name.to_string()),
actor_mesh_shape.clone(),
message,
None, // TODO: reducer typehash
Expand All @@ -101,30 +99,29 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
/// The type of actor in the mesh.
type Actor: RemoteActor;

/// Cast an [`M`]-typed message to the ranks selected by `sel`
/// in this ActorMesh.
/// Cast an `M`-typed message to the ranks selected by `sel` in
/// this ActorMesh.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
fn cast<M: Castable + Clone>(&self, selection: Selection, message: M) -> Result<(), CastError>
fn cast<M>(&self, selection: Selection, message: M) -> Result<(), CastError>
where
Self::Actor: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
actor_mesh_cast::<M, Self::Actor>(
self.proc_mesh().client(),
self.id(),
self.shape(),
self.proc_mesh().shape(),
self.name(),
self.proc_mesh().client().actor_id(),
self.proc_mesh().comm_actor(),
selection,
message,
actor_mesh_cast::<Self::Actor, M>(
self.proc_mesh().client(), // send capability
self.id(), // actor mesh id (destination mesh)
self.shape(), // actor mesh shape
self.proc_mesh().client().actor_id(), // sender
self.proc_mesh().comm_actor(), // comm actor
selection, // the selected actors
message, // the message
)
}

/// The ProcMesh on top of which this actor mesh is spawned.
fn proc_mesh(&self) -> &ProcMesh;

/// The name global name of actors in this mesh.
/// The name given to the actors in this mesh.
fn name(&self) -> &str;

fn world_id(&self) -> &WorldId {
Expand Down
14 changes: 9 additions & 5 deletions hyperactor_mesh/src/comm/multicast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,19 +62,23 @@ pub struct CastMessageEnvelope {

impl CastMessageEnvelope {
/// Create a new CastMessageEnvelope.
pub fn new<T: Castable + Serialize + Named>(
pub fn new<A, M>(
actor_mesh_id: ActorMeshId,
sender: ActorId,
dest_port: DestinationPort,
shape: Shape,
message: T,
message: M,
reducer_typehash: Option<u64>,
) -> Result<Self, anyhow::Error> {
) -> Result<Self, anyhow::Error>
where
A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
let data = ErasedUnbound::try_from_message(message)?;
let actor_name = actor_mesh_id.1.to_string();
Ok(Self {
actor_mesh_id,
sender,
dest_port,
dest_port: DestinationPort::new::<A, M>(actor_name),
data,
reducer_typehash,
shape,
Expand Down
19 changes: 5 additions & 14 deletions hyperactor_mesh/src/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use std::marker::PhantomData;
use hyperactor::ActorRef;
use hyperactor::Named;
use hyperactor::RemoteHandles;
use hyperactor::RemoteMessage;
use hyperactor::actor::RemoteActor;
use hyperactor::cap;
use hyperactor::message::Castable;
Expand Down Expand Up @@ -54,7 +55,7 @@ macro_rules! mesh_id {
)]
pub struct ProcMeshId(pub String);

/// Actor Mesh ID. Tuple of the ProcMesh ID and Actor Mesh ID.
/// Actor Mesh ID. Tuple of the ProcMesh ID and actor name.
#[derive(
Debug,
Serialize,
Expand Down Expand Up @@ -111,33 +112,23 @@ impl<A: RemoteActor> ActorMeshRef<A> {
&self.shape
}

/// Shape of the underlying Proc Mesh.
fn proc_mesh_shape(&self) -> &Shape {
&self.proc_mesh_shape
}

fn name(&self) -> &str {
&self.mesh_id.1
}

/// Cast an [`M`]-typed message to the ranks selected by `sel`
/// in this ActorMesh.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
pub fn cast<M: Castable + Clone>(
pub fn cast<M>(
&self,
caps: &(impl cap::CanSend + cap::CanOpenPort),
selection: Selection,
message: M,
) -> Result<(), CastError>
where
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
actor_mesh_cast::<M, A>(
actor_mesh_cast::<A, M>(
caps,
self.mesh_id.clone(),
self.shape(),
self.proc_mesh_shape(),
self.name(),
caps.mailbox().actor_id(),
&self.comm_actor_ref,
selection,
Expand Down
18 changes: 3 additions & 15 deletions monarch_extension/src/mesh_controller.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
*/

use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
Expand Down Expand Up @@ -55,9 +54,9 @@ use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerParams;
use monarch_tensor_worker::AssignRankMessage;
use monarch_tensor_worker::WorkerActor;
use ndslice::Selection;
use ndslice::Slice;
use ndslice::selection;
use ndslice::selection::ReifyView;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use tokio::sync::Mutex;
Expand Down Expand Up @@ -808,19 +807,8 @@ impl Handler<ClientToControllerMessage> for MeshControllerActor {
) -> anyhow::Result<()> {
match message {
ClientToControllerMessage::Send { slices, message } => {
let selection = slices
.iter()
.map(|slice| {
Selection::of_ranks(
self.workers().shape().slice(),
&slice.iter().collect::<BTreeSet<usize>>(),
)
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.reduce(selection::dsl::union)
.unwrap_or_else(selection::dsl::false_);
self.workers().cast(selection, message.clone())?;
let sel = self.workers().shape().slice().reify_views(slices)?;
self.workers().cast(sel, message)?;
}
ClientToControllerMessage::Node {
seq,
Expand Down
63 changes: 44 additions & 19 deletions ndslice/src/selection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1013,7 +1013,7 @@ pub trait ReifyView: sealed::Sealed {

/// Reify multiple views as a union of selections in the
/// coordinate system of `self`.
fn reify_views(&self, views: &[&Slice]) -> Result<Selection, SliceError>;
fn reify_views<V: AsRef<[Slice]>>(&self, views: V) -> Result<Selection, SliceError>;
}

impl ReifyView for Slice {
Expand All @@ -1030,7 +1030,6 @@ impl ReifyView for Slice {
/// # Errors
///
/// Returns an error if:
/// - The number of dimensions in the view does not match the base
/// - The view lies outside the bounds of the base slice
///
/// # Example
Expand All @@ -1044,16 +1043,14 @@ impl ReifyView for Slice {
/// let selection = base.reify_view(view).unwrap();
/// ```
fn reify_view(&self, view: &Slice) -> Result<Selection, SliceError> {
if view.num_dim() != self.num_dim() {
return Err(SliceError::InvalidDims {
expected: self.num_dim(),
got: view.num_dim(),
});
}
if view.is_empty() {
return Ok(dsl::false_());
}

if view.num_dim() != self.num_dim() {
return Selection::of_ranks(self, &view.iter().collect::<BTreeSet<usize>>());
}

let origin = self.coordinates(view.offset())?;
let mut acc = dsl::true_();
for (&start, &len) in origin.iter().zip(view.sizes()).rev() {
Expand All @@ -1076,7 +1073,6 @@ impl ReifyView for Slice {
/// # Errors
///
/// Returns an error if any view:
/// - Has a different number of dimensions than the base slice
/// - Refers to coordinates not contained within the base
///
/// # Example
Expand All @@ -1087,15 +1083,22 @@ impl ReifyView for Slice {
/// let shape = ndslice::shape!(x = 4, y = 4);
/// let base = shape.slice();
///
/// let a = ndslice::select!(shape, x = 0..2, y = 0..2).unwrap();
/// let b = ndslice::select!(shape, x = 2..4, y = 2..4).unwrap();
/// let a = ndslice::select!(shape, x = 0..2, y = 0..2)
/// .unwrap()
/// .slice()
/// .clone();
/// let b = ndslice::select!(shape, x = 2..4, y = 2..4)
/// .unwrap()
/// .slice()
/// .clone();
///
/// let sel = base.reify_views(&[a.slice(), b.slice()]).unwrap();
/// let sel = base.reify_views(&[a, b]).unwrap();
/// ```
fn reify_views(&self, views: &[&Slice]) -> Result<Selection, SliceError> {
fn reify_views<V: AsRef<[Slice]>>(&self, views: V) -> Result<Selection, SliceError> {
let views = views.as_ref();
let mut selections = Vec::with_capacity(views.len());

for &view in views {
for view in views {
if view.is_empty() {
continue;
}
Expand Down Expand Up @@ -2157,6 +2160,28 @@ mod tests {
);
}

#[test]
fn test_reify_view_dimension_mismatch() {
let shape = shape!(host = 2, gpu = 4);
let base = shape.slice();

// Select the 3rd GPU (index 2) across both hosts i.e. flat
// indices [2, 6]
let indices = vec![
base.location(&[0, 2]).unwrap(),
base.location(&[1, 2]).unwrap(),
];

let view = Slice::new(indices[0], vec![indices.len()], vec![4]).unwrap();
let selection = base.reify_view(&view).unwrap();

let expected = Selection::of_ranks(base, &indices.iter().cloned().collect()).unwrap();
assert_structurally_eq!(&selection, expected);

let actual: Vec<_> = selection.eval(&EvalOpts::strict(), base).unwrap().collect();
assert_eq!(actual, indices);
}

#[test]
fn test_union_of_slices_empty() {
let base = Slice::new_row_major([2]);
Expand All @@ -2175,7 +2200,7 @@ mod tests {
let shape = shape!(x = 3);
let base = shape.slice();
let selected = select!(shape, x = 1).unwrap();
let view = selected.slice();
let view = selected.slice().clone();

let selection = base.reify_views(&[view]).unwrap();
let expected = range(1..=1, true_());
Expand All @@ -2197,11 +2222,11 @@ mod tests {

// View A: (0, *)
let a = select!(shape, x = 0).unwrap();
let view_a = a.slice();
let view_a = a.slice().clone();

// View B: (1, *)
let b = select!(shape, x = 1).unwrap();
let view_b = b.slice();
let view_b = b.slice().clone();

let selection = base.reify_views(&[view_a, view_b]).unwrap();
let expected = union(
Expand All @@ -2224,10 +2249,10 @@ mod tests {
let base = shape.slice();

let selected1 = select!(shape, y = 0..2).unwrap();
let view1 = selected1.slice();
let view1 = selected1.slice().clone();

let selected2 = select!(shape, y = 1..4).unwrap();
let view2 = selected2.slice();
let view2 = selected2.slice().clone();

let selection = base.reify_views(&[view1, view2]).unwrap();
let expected = union(
Expand Down