actor_mesh: remove redundant params #481

Closed
67 changes: 22 additions & 45 deletions hyperactor_mesh/src/actor_mesh.rs
@@ -37,8 +37,6 @@ use ndslice::Selection;
use ndslice::Shape;
use ndslice::ShapeError;
use ndslice::Slice;
use ndslice::dsl;
use ndslice::selection::ReifyView;
use serde::Deserialize;
use serde::Serialize;
use tokio::sync::mpsc;
@@ -47,7 +45,6 @@ use crate::CommActor;
use crate::Mesh;
use crate::comm::multicast::CastMessage;
use crate::comm::multicast::CastMessageEnvelope;
use crate::comm::multicast::DestinationPort;
use crate::comm::multicast::Uslice;
use crate::comm::multicast::set_cast_info_on_headers;
use crate::metrics;
@@ -56,59 +53,40 @@ use crate::reference::ActorMeshId;
use crate::reference::ActorMeshRef;
use crate::reference::ProcMeshId;

/// Common implementation for ActorMeshes and ActorMeshRefs to cast an [`M`]-typed message
/// Common implementation for `ActorMesh`s and `ActorMeshRef`s to cast
/// an `M`-typed message
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
pub(crate) fn actor_mesh_cast<M: Castable + Clone, A>(
pub(crate) fn actor_mesh_cast<A, M>(
caps: &impl cap::CanSend,
actor_mesh_id: ActorMeshId,
actor_mesh_shape: &Shape,
proc_mesh_shape: &Shape,
actor_name: &str,
sender: &ActorId,
comm_actor_ref: &ActorRef<CommActor>,
selection: Selection,
message: M,
) -> Result<(), CastError>
where
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
"message_type" => M::typename(),
"message_variant" => message.arm().unwrap_or_default(),
));

let message = CastMessageEnvelope::new(
let slice = actor_mesh_shape.slice().clone();
let message = CastMessageEnvelope::new::<A, M>(
actor_mesh_id,
sender.clone(),
DestinationPort::new::<A, M>(actor_name.to_string()),
actor_mesh_shape.clone(),
message,
None, // TODO: reducer typehash
)?;

// Sub-set the selection to the selection that represents the mesh's view
// of the root mesh. We need to do this because the comm actor uses the
// slice as the stream key; thus different sub-slices will result in potentially
// out of order delivery.
//
// TODO: We should repair this by introducing an explicit stream key, associated
// with the root mesh.
let selection_of_slice = proc_mesh_shape
.slice()
.reify_view(actor_mesh_shape.slice())
.expect("invalid slice");
let selection = dsl::intersection(selection, selection_of_slice);

comm_actor_ref.send(
caps,
CastMessage {
dest: Uslice {
// TODO: currently this slice is being used as the stream key
// in comm actor. We should change it to an explicit id, maintained
// by the root proc mesh.
slice: proc_mesh_shape.slice().clone(),
selection,
},
dest: Uslice { slice, selection },
message,
},
)?;
@@ -121,30 +99,29 @@ pub trait ActorMesh: Mesh<Id = ActorMeshId> {
/// The type of actor in the mesh.
type Actor: RemoteActor;

/// Cast an [`M`]-typed message to the ranks selected by `sel`
/// in this ActorMesh.
/// Cast an `M`-typed message to the ranks selected by `sel` in
/// this ActorMesh.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
fn cast<M: Castable + Clone>(&self, selection: Selection, message: M) -> Result<(), CastError>
fn cast<M>(&self, selection: Selection, message: M) -> Result<(), CastError>
where
Self::Actor: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
actor_mesh_cast::<M, Self::Actor>(
self.proc_mesh().client(),
self.id(),
self.shape(),
self.proc_mesh().shape(),
self.name(),
self.proc_mesh().client().actor_id(),
self.proc_mesh().comm_actor(),
selection,
message,
actor_mesh_cast::<Self::Actor, M>(
self.proc_mesh().client(), // send capability
self.id(), // actor mesh id (destination mesh)
self.shape(), // actor mesh shape
self.proc_mesh().client().actor_id(), // sender
self.proc_mesh().comm_actor(), // comm actor
selection, // the selected actors
message, // the message
)
}

/// The ProcMesh on top of which this actor mesh is spawned.
fn proc_mesh(&self) -> &ProcMesh;

/// The name global name of actors in this mesh.
/// The name given to the actors in this mesh.
fn name(&self) -> &str;

fn world_id(&self) -> &WorldId {
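Usage illustration (not part of the diff): a minimal sketch of casting through the trait after this change. `MyActor`, `MyMessage`, and the `CastError` import path are assumptions; the point is that `cast` now takes only a `Selection` and the message, deriving the destination port, shape, and sender from the mesh itself.

// Sketch only: `MyActor` and `MyMessage` are hypothetical types assumed to
// satisfy the new bounds (`MyActor: RemoteHandles<IndexedErasedUnbound<MyMessage>>`,
// `MyMessage: Castable + RemoteMessage`).
use hyperactor_mesh::actor_mesh::{ActorMesh, CastError}; // paths assumed
use ndslice::selection::dsl;

fn broadcast(mesh: &impl ActorMesh<Actor = MyActor>, msg: MyMessage) -> Result<(), CastError> {
    // Select every rank; no proc-mesh shape or actor name is passed anymore.
    mesh.cast(dsl::true_(), msg)
}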
14 changes: 9 additions & 5 deletions hyperactor_mesh/src/comm/multicast.rs
@@ -62,19 +62,23 @@ pub struct CastMessageEnvelope {

impl CastMessageEnvelope {
/// Create a new CastMessageEnvelope.
pub fn new<T: Castable + Serialize + Named>(
pub fn new<A, M>(
actor_mesh_id: ActorMeshId,
sender: ActorId,
dest_port: DestinationPort,
shape: Shape,
message: T,
message: M,
reducer_typehash: Option<u64>,
) -> Result<Self, anyhow::Error> {
) -> Result<Self, anyhow::Error>
where
A: RemoteActor + RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
let data = ErasedUnbound::try_from_message(message)?;
let actor_name = actor_mesh_id.1.to_string();
Ok(Self {
actor_mesh_id,
sender,
dest_port,
dest_port: DestinationPort::new::<A, M>(actor_name),
data,
reducer_typehash,
shape,
10 changes: 5 additions & 5 deletions hyperactor_mesh/src/reference.rs
Expand Up @@ -14,6 +14,7 @@ use std::marker::PhantomData;
use hyperactor::ActorRef;
use hyperactor::Named;
use hyperactor::RemoteHandles;
use hyperactor::RemoteMessage;
use hyperactor::actor::RemoteActor;
use hyperactor::cap;
use hyperactor::message::Castable;
@@ -54,7 +55,7 @@ macro_rules! mesh_id {
)]
pub struct ProcMeshId(pub String);

/// Actor Mesh ID. Tuple of the ProcMesh ID and Actor Mesh ID.
/// Actor Mesh ID. Tuple of the ProcMesh ID and actor name.
#[derive(
Debug,
Serialize,
@@ -123,21 +124,20 @@ impl<A: RemoteActor> ActorMeshRef<A> {
/// Cast an [`M`]-typed message to the ranks selected by `sel`
/// in this ActorMesh.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
pub fn cast<M: Castable + Clone>(
pub fn cast<M>(
&self,
caps: &(impl cap::CanSend + cap::CanOpenPort),
selection: Selection,
message: M,
) -> Result<(), CastError>
where
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
M: Castable + RemoteMessage,
{
actor_mesh_cast::<M, A>(
actor_mesh_cast::<A, M>(
caps,
self.mesh_id.clone(),
self.shape(),
self.proc_mesh_shape(),
self.name(),
caps.mailbox().actor_id(),
&self.comm_actor_ref,
selection,
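Similarly, an illustrative sketch of the updated `ActorMeshRef::cast` call site (the names `mesh_ref`, `caps`, `MyActor`, and `MyMessage` are hypothetical): the caller now supplies only the send/open-port capabilities, a selection, and the message.

// Sketch only: `mesh_ref` is an ActorMeshRef<MyActor> obtained elsewhere;
// `caps` is any value providing CanSend + CanOpenPort (e.g. the client mailbox).
use hyperactor::cap;
use ndslice::selection::dsl;

fn notify_all(
    mesh_ref: &ActorMeshRef<MyActor>,
    caps: &(impl cap::CanSend + cap::CanOpenPort),
    msg: MyMessage,
) -> Result<(), CastError> {
    mesh_ref.cast(caps, dsl::true_(), msg)
}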
33 changes: 28 additions & 5 deletions monarch_extension/src/mesh_controller.rs
@@ -7,6 +7,7 @@
*/

use std::collections::BTreeMap;
use std::collections::BTreeSet;
use std::collections::HashMap;
use std::collections::HashSet;
use std::collections::VecDeque;
@@ -32,6 +33,7 @@ use hyperactor::PortRef;
use hyperactor::cap::CanSend;
use hyperactor::mailbox::MailboxSenderError;
use hyperactor_mesh::Mesh;
use hyperactor_mesh::actor_mesh::ActorMesh;
use hyperactor_mesh::actor_mesh::RootActorMesh;
use hyperactor_mesh::shared_cell::SharedCell;
use hyperactor_mesh::shared_cell::SharedCellRef;
@@ -53,7 +55,9 @@ use monarch_messages::worker::WorkerMessage;
use monarch_messages::worker::WorkerParams;
use monarch_tensor_worker::AssignRankMessage;
use monarch_tensor_worker::WorkerActor;
use ndslice::Selection;
use ndslice::Slice;
use ndslice::selection;
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use tokio::sync::Mutex;
@@ -250,6 +254,7 @@ impl Invocation {
}
}

#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
fn add_user(
&mut self,
sender: &impl CanSend,
@@ -287,6 +292,7 @@ impl Invocation {
}
}

#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
fn complete(&mut self, sender: &impl CanSend) -> Result<(), MailboxSenderError> {
let old_status = std::mem::replace(&mut self.status, Status::Complete {});
match old_status {
@@ -310,6 +316,7 @@
/// Incomplete, it may have users that will also become errored. This function
will return those users so the error can be propagated. It does not automatically
/// propagate the error to avoid deep recursive invocations.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
fn set_exception(
&mut self,
sender: &impl CanSend,
@@ -447,6 +454,7 @@ impl History {
}

/// Add an invocation to the history.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
pub fn add_invocation(
&mut self,
sender: &impl CanSend,
@@ -486,6 +494,7 @@

/// Propagate worker error to the invocation with the given Seq. This will also propagate
/// to all seqs that depend on this seq directly or indirectly.
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
pub fn propagate_exception(
&mut self,
sender: &impl CanSend,
@@ -539,6 +548,7 @@ impl History {

/// Mark the given rank as completed up to but excluding the given Seq. This will also purge history for
/// any Seqs that are no longer relevant (completed on all ranks).
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
pub fn rank_completed(
&mut self,
sender: &impl CanSend,
@@ -623,6 +633,7 @@ impl MeshControllerActor {
fn workers(&self) -> SharedCellRef<RootActorMesh<'static, WorkerActor>> {
self.workers.as_ref().unwrap().borrow().unwrap()
}

fn handle_debug(
&mut self,
this: &Context<Self>,
@@ -741,7 +752,8 @@ impl Actor for MeshControllerActor {
workers
.borrow()
.unwrap()
.cast_slices(vec![slice.clone()], AssignRankMessage::AssignRank())?;
.cast(selection::dsl::true_(), AssignRankMessage::AssignRank())?;

self.workers = Some(workers);
Ok(())
}
@@ -796,7 +808,19 @@ impl Handler<ClientToControllerMessage> for MeshControllerActor {
) -> anyhow::Result<()> {
match message {
ClientToControllerMessage::Send { slices, message } => {
self.workers().cast_slices(slices, message)?;
let selection = slices
.iter()
.map(|slice| {
Selection::of_ranks(
self.workers().shape().slice(),
&slice.iter().collect::<BTreeSet<usize>>(),
)
})
.collect::<Result<Vec<_>, _>>()?
.into_iter()
.reduce(selection::dsl::union)
.unwrap_or_else(selection::dsl::false_);
self.workers().cast(selection, message.clone())?;
}
ClientToControllerMessage::Node {
seq,
@@ -812,9 +836,8 @@
self.history.drop_refs(refs);
}
ClientToControllerMessage::SyncAtExit { port } => {
let all_ranks = vec![self.workers().shape().slice().clone()];
self.workers().cast_slices(
all_ranks,
self.workers().cast(
selection::dsl::true_(),
WorkerMessage::RequestStatus {
seq: self.history.seq_lower_bound,
controller: false,
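The `Send` handler above replaces `cast_slices` by folding the requested slices into a single `Selection` before casting. A standalone sketch of that folding step follows (the function name and the `anyhow` return type are illustrative; `Selection::of_ranks`, `dsl::union`, and `dsl::false_` are the calls used in the diff).

use std::collections::BTreeSet;

use ndslice::Selection;
use ndslice::Slice;
use ndslice::selection;

// Union the per-slice rank selections over the mesh's root slice;
// an empty request yields the empty selection.
fn selection_for_slices(mesh_slice: &Slice, slices: &[Slice]) -> anyhow::Result<Selection> {
    let selection = slices
        .iter()
        .map(|s| Selection::of_ranks(mesh_slice, &s.iter().collect::<BTreeSet<usize>>()))
        .collect::<Result<Vec<_>, _>>()?
        .into_iter()
        .reduce(selection::dsl::union)
        .unwrap_or_else(selection::dsl::false_);
    Ok(selection)
}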
12 changes: 11 additions & 1 deletion monarch_tensor_worker/src/lib.rs
@@ -290,7 +290,17 @@ pub enum AssignRankMessage {
}

#[async_trait]
#[forward(WorkerMessage)]
impl Handler<WorkerMessage> for WorkerActor {
async fn handle(
&mut self,
cx: &hyperactor::Context<Self>,
message: WorkerMessage,
) -> anyhow::Result<()> {
<Self as WorkerMessageHandler>::handle(self, cx, message).await
}
}

#[async_trait]
impl WorkerMessageHandler for WorkerActor {
async fn backend_network_init(
&mut self,