Skip to content

Commit b36a3e1

Browse files
actor_mesh: override cast for SlicedActorMesh (#488)
Summary: Pull Request resolved: #488 the sequence of diffs, D77963903, D78030552 and D78104511 have led to the default implementation of the actor mesh trait's `cast` being quite wrong for `SlicedActorMesh`. this is a transitional state; we are working towards a revised design. this diff papers over the cracks making `cast` on slice actor meshes, "do the right thing". Reviewed By: mariusae Differential Revision: D78104511 fbshipit-source-id: 5708762e9afd750b6f5fcf6ff1d2f63d9f7f81bd
1 parent 06d65a5 commit b36a3e1

File tree

1 file changed

+40
-3
lines changed

1 file changed

+40
-3
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 40 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#![allow(dead_code)] // until used publically
1010

11+
use std::collections::BTreeSet;
1112
use std::ops::Deref;
1213

1314
use async_trait::async_trait;
@@ -25,7 +26,6 @@ use hyperactor::RemoteMessage;
2526
use hyperactor::Unbind;
2627
use hyperactor::WorldId;
2728
use hyperactor::actor::RemoteActor;
28-
use hyperactor::attrs::Attrs;
2929
use hyperactor::cap;
3030
use hyperactor::mailbox::MailboxSenderError;
3131
use hyperactor::mailbox::PortReceiver;
@@ -36,7 +36,10 @@ use ndslice::Range;
3636
use ndslice::Selection;
3737
use ndslice::Shape;
3838
use ndslice::ShapeError;
39-
use ndslice::Slice;
39+
use ndslice::selection;
40+
use ndslice::selection::EvalOpts;
41+
use ndslice::selection::ReifyView;
42+
use ndslice::selection::normal;
4043
use serde::Deserialize;
4144
use serde::Serialize;
4245
use tokio::sync::mpsc;
@@ -46,7 +49,6 @@ use crate::Mesh;
4649
use crate::comm::multicast::CastMessage;
4750
use crate::comm::multicast::CastMessageEnvelope;
4851
use crate::comm::multicast::Uslice;
49-
use crate::comm::multicast::set_cast_info_on_headers;
5052
use crate::metrics;
5153
use crate::proc_mesh::ProcMesh;
5254
use crate::reference::ActorMeshId;
@@ -327,6 +329,40 @@ impl<A: RemoteActor> ActorMesh for SlicedActorMesh<'_, A> {
327329
fn name(&self) -> &str {
328330
&self.0.name
329331
}
332+
333+
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
334+
fn cast<M>(&self, sel: Selection, message: M) -> Result<(), CastError>
335+
where
336+
Self::Actor: RemoteHandles<IndexedErasedUnbound<M>>,
337+
M: Castable + RemoteMessage,
338+
{
339+
let base_shape = self.0.shape();
340+
let base_slice = base_shape.slice();
341+
342+
// Casting to `*`?
343+
let selection = if selection::normalize(&sel) == normal::NormalizedSelection::True {
344+
// Reify this view into base.
345+
base_slice.reify_view(self.shape().slice()).unwrap()
346+
} else {
347+
// No, fall back on `of_ranks`.
348+
let ranks = sel
349+
.eval(&EvalOpts::strict(), self.shape().slice())
350+
.unwrap()
351+
.collect::<BTreeSet<_>>();
352+
Selection::of_ranks(base_slice, &ranks).unwrap()
353+
};
354+
355+
// Cast.
356+
actor_mesh_cast::<A, M>(
357+
self.proc_mesh().client(), // send capability
358+
self.id(), // actor mesh id (destination mesh)
359+
base_shape, // actor mesh shape
360+
self.proc_mesh().client().actor_id(), // sender
361+
self.proc_mesh().comm_actor(), // comm actor
362+
selection, // the selected actors
363+
message, // the message
364+
)
365+
}
330366
}
331367

332368
/// The type of error of casting operations.
@@ -485,6 +521,7 @@ mod tests {
485521
use $crate::assign::Ranks;
486522
use $crate::sel_from_shape;
487523
use $crate::sel;
524+
use $crate::comm::multicast::set_cast_info_on_headers;
488525
use $crate::proc_mesh::SharedSpawnable;
489526
use std::collections::VecDeque;
490527
use hyperactor::data::Serialized;

0 commit comments

Comments
 (0)