Skip to content

Commit 64c6c61

Browse files
: actor_mesh: use commactor in cast_slices (#479)
Summary: - in D77963903, routing was updated to use the `(ActorMeshId, ActorId)` pair as the stream key for sequence number tracking. this change allows different sub-slices of a mesh to safely share a common stream identity as long as they belong to the same logical `ActorMeshId`, avoiding issues like message reordering or duplication due to slice mismatch. as a result, this diff removes the now-unnecessary logic that intersected the user-provided selection with a reified view of the actor mesh's slice. correctness depends on all casts to a given `ActorMeshId` being evaluated consistently against that mesh's slice. - previously, `cast_slices` didn’t perform a true cast; it sent messages point-to-point to each rank in the input slices. now, we replace it with a `Selection` constructed via `Selection::of_ranks` and `dsl::union`, and invoke `cast`. the unused `cast_slices` is removed. Reviewed By: mariusae Differential Revision: D77953855
1 parent 5077852 commit 64c6c61

File tree

3 files changed

+42
-29
lines changed

3 files changed

+42
-29
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,6 @@ use ndslice::Selection;
3737
use ndslice::Shape;
3838
use ndslice::ShapeError;
3939
use ndslice::Slice;
40-
use ndslice::dsl;
41-
use ndslice::selection::ReifyView;
4240
use serde::Deserialize;
4341
use serde::Serialize;
4442
use tokio::sync::mpsc;
@@ -62,7 +60,7 @@ pub(crate) fn actor_mesh_cast<M: Castable + Clone, A>(
6260
caps: &impl cap::CanSend,
6361
actor_mesh_id: ActorMeshId,
6462
actor_mesh_shape: &Shape,
65-
proc_mesh_shape: &Shape,
63+
_proc_mesh_shape: &Shape,
6664
actor_name: &str,
6765
sender: &ActorId,
6866
comm_actor_ref: &ActorRef<CommActor>,
@@ -77,6 +75,7 @@ where
7775
"message_variant" => message.arm().unwrap_or_default(),
7876
));
7977

78+
let slice = actor_mesh_shape.slice().clone();
8079
let message = CastMessageEnvelope::new(
8180
actor_mesh_id,
8281
sender.clone(),
@@ -86,29 +85,10 @@ where
8685
None, // TODO: reducer typehash
8786
)?;
8887

89-
// Sub-set the selection to the selection that represents the mesh's view
90-
// of the root mesh. We need to do this because the comm actor uses the
91-
// slice as the stream key; thus different sub-slices will result in potentially
92-
// out of order delivery.
93-
//
94-
// TODO: We should repair this by introducing an explicit stream key, associated
95-
// with the root mesh.
96-
let selection_of_slice = proc_mesh_shape
97-
.slice()
98-
.reify_view(actor_mesh_shape.slice())
99-
.expect("invalid slice");
100-
let selection = dsl::intersection(selection, selection_of_slice);
101-
10288
comm_actor_ref.send(
10389
caps,
10490
CastMessage {
105-
dest: Uslice {
106-
// TODO: currently this slice is being used as the stream key
107-
// in comm actor. We should change it to an explicit id, maintained
108-
// by the root proc mesh.
109-
slice: proc_mesh_shape.slice().clone(),
110-
selection,
111-
},
91+
dest: Uslice { slice, selection },
11292
message,
11393
},
11494
)?;

monarch_extension/src/mesh_controller.rs

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
use std::collections::BTreeMap;
10+
use std::collections::BTreeSet;
1011
use std::collections::HashMap;
1112
use std::collections::HashSet;
1213
use std::collections::VecDeque;
@@ -32,6 +33,7 @@ use hyperactor::PortRef;
3233
use hyperactor::cap::CanSend;
3334
use hyperactor::mailbox::MailboxSenderError;
3435
use hyperactor_mesh::Mesh;
36+
use hyperactor_mesh::actor_mesh::ActorMesh;
3537
use hyperactor_mesh::actor_mesh::RootActorMesh;
3638
use hyperactor_mesh::shared_cell::SharedCell;
3739
use hyperactor_mesh::shared_cell::SharedCellRef;
@@ -53,7 +55,9 @@ use monarch_messages::worker::WorkerMessage;
5355
use monarch_messages::worker::WorkerParams;
5456
use monarch_tensor_worker::AssignRankMessage;
5557
use monarch_tensor_worker::WorkerActor;
58+
use ndslice::Selection;
5659
use ndslice::Slice;
60+
use ndslice::selection;
5761
use pyo3::exceptions::PyValueError;
5862
use pyo3::prelude::*;
5963
use tokio::sync::Mutex;
@@ -250,6 +254,7 @@ impl Invocation {
250254
}
251255
}
252256

257+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
253258
fn add_user(
254259
&mut self,
255260
sender: &impl CanSend,
@@ -287,6 +292,7 @@ impl Invocation {
287292
}
288293
}
289294

295+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
290296
fn complete(&mut self, sender: &impl CanSend) -> Result<(), MailboxSenderError> {
291297
let old_status = std::mem::replace(&mut self.status, Status::Complete {});
292298
match old_status {
@@ -310,6 +316,7 @@ impl Invocation {
310316
/// Incomplete, it may have users that will also become errored. This function
311317
/// will return those users so the error can be propagated. It does not autmoatically
312318
/// propagate the error to avoid deep recursive invocations.
319+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
313320
fn set_exception(
314321
&mut self,
315322
sender: &impl CanSend,
@@ -447,6 +454,7 @@ impl History {
447454
}
448455

449456
/// Add an invocation to the history.
457+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
450458
pub fn add_invocation(
451459
&mut self,
452460
sender: &impl CanSend,
@@ -486,6 +494,7 @@ impl History {
486494

487495
/// Propagate worker error to the invocation with the given Seq. This will also propagate
488496
/// to all seqs that depend on this seq directly or indirectly.
497+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
489498
pub fn propagate_exception(
490499
&mut self,
491500
sender: &impl CanSend,
@@ -539,6 +548,7 @@ impl History {
539548

540549
/// Mark the given rank as completed up to but excluding the given Seq. This will also purge history for
541550
/// any Seqs that are no longer relevant (completed on all ranks).
551+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `MailboxSenderError`.
542552
pub fn rank_completed(
543553
&mut self,
544554
sender: &impl CanSend,
@@ -623,6 +633,7 @@ impl MeshControllerActor {
623633
fn workers(&self) -> SharedCellRef<RootActorMesh<'static, WorkerActor>> {
624634
self.workers.as_ref().unwrap().borrow().unwrap()
625635
}
636+
626637
fn handle_debug(
627638
&mut self,
628639
this: &Context<Self>,
@@ -741,7 +752,8 @@ impl Actor for MeshControllerActor {
741752
workers
742753
.borrow()
743754
.unwrap()
744-
.cast_slices(vec![slice.clone()], AssignRankMessage::AssignRank())?;
755+
.cast(selection::dsl::true_(), AssignRankMessage::AssignRank())?;
756+
745757
self.workers = Some(workers);
746758
Ok(())
747759
}
@@ -796,7 +808,19 @@ impl Handler<ClientToControllerMessage> for MeshControllerActor {
796808
) -> anyhow::Result<()> {
797809
match message {
798810
ClientToControllerMessage::Send { slices, message } => {
799-
self.workers().cast_slices(slices, message)?;
811+
let selection = slices
812+
.iter()
813+
.map(|slice| {
814+
Selection::of_ranks(
815+
self.workers().shape().slice(),
816+
&slice.iter().collect::<BTreeSet<usize>>(),
817+
)
818+
})
819+
.collect::<Result<Vec<_>, _>>()?
820+
.into_iter()
821+
.reduce(selection::dsl::union)
822+
.unwrap_or_else(selection::dsl::false_);
823+
self.workers().cast(selection, message.clone())?;
800824
}
801825
ClientToControllerMessage::Node {
802826
seq,
@@ -812,9 +836,8 @@ impl Handler<ClientToControllerMessage> for MeshControllerActor {
812836
self.history.drop_refs(refs);
813837
}
814838
ClientToControllerMessage::SyncAtExit { port } => {
815-
let all_ranks = vec![self.workers().shape().slice().clone()];
816-
self.workers().cast_slices(
817-
all_ranks,
839+
self.workers().cast(
840+
selection::dsl::true_(),
818841
WorkerMessage::RequestStatus {
819842
seq: self.history.seq_lower_bound,
820843
controller: false,

monarch_tensor_worker/src/lib.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,17 @@ pub enum AssignRankMessage {
290290
}
291291

292292
#[async_trait]
293-
#[forward(WorkerMessage)]
293+
impl Handler<WorkerMessage> for WorkerActor {
294+
async fn handle(
295+
&mut self,
296+
cx: &hyperactor::Context<Self>,
297+
message: WorkerMessage,
298+
) -> anyhow::Result<()> {
299+
<Self as WorkerMessageHandler>::handle(self, cx, message).await
300+
}
301+
}
302+
303+
#[async_trait]
294304
impl WorkerMessageHandler for WorkerActor {
295305
async fn backend_network_init(
296306
&mut self,

0 commit comments

Comments
 (0)