Skip to content

Commit 7a8fb6a

Browse files
actor_mesh: use commactor in cast_slices (#461)
Summary: Pull Request resolved: #461 refactor `cast_slices`. the old approach iterated over the input slices and for each rank in each slice, sent a message. now, call `reify_views` on the input slices to produce a single `Selection` (union of ranges). then, cast the message to the selected ranks via the comm actor. Differential Revision: D77935610
1 parent bbcbcdf commit 7a8fb6a

File tree

1 file changed

+14
-21
lines changed

1 file changed

+14
-21
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use hyperactor::RemoteMessage;
2525
use hyperactor::Unbind;
2626
use hyperactor::WorldId;
2727
use hyperactor::actor::RemoteActor;
28-
use hyperactor::attrs::Attrs;
2928
use hyperactor::cap;
3029
use hyperactor::mailbox::MailboxSenderError;
3130
use hyperactor::mailbox::PortReceiver;
@@ -49,7 +48,6 @@ use crate::comm::multicast::CastMessage;
4948
use crate::comm::multicast::CastMessageEnvelope;
5049
use crate::comm::multicast::DestinationPort;
5150
use crate::comm::multicast::Uslice;
52-
use crate::comm::multicast::set_cast_info_on_headers;
5351
use crate::metrics;
5452
use crate::proc_mesh::ProcMesh;
5553
use crate::reference::ActorMeshId;
@@ -237,32 +235,26 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
237235
/// Until the selection logic is more powerful, we need a way to
238236
/// replicate the send patterns that the worker actor mesh actually does.
239237
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
240-
pub fn cast_slices<M: RemoteMessage + Clone>(
241-
&self,
242-
sel: Vec<Slice>,
243-
message: M,
244-
) -> Result<(), CastError>
238+
pub fn cast_slices<M>(&self, sel: Vec<Slice>, message: M) -> Result<(), CastError>
245239
where
240+
M: Castable + Clone,
246241
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
247242
{
248243
let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
249244
"message_type" => M::typename(),
250245
"message_variant" => message.arm().unwrap_or_default(),
251246
));
252-
for ref slice in sel {
253-
for rank in slice.iter() {
254-
let mut headers = Attrs::new();
255-
set_cast_info_on_headers(
256-
&mut headers,
257-
rank,
258-
self.shape().clone(),
259-
self.proc_mesh.client().actor_id().clone(),
260-
);
261-
self.ranks[rank]
262-
.send_with_headers(self.proc_mesh.client(), headers, message.clone())
263-
.map_err(|err| CastError::MailboxSenderError(rank, err))?;
264-
}
265-
}
247+
248+
let slices: &[&Slice] = &sel.iter().collect::<Vec<_>>();
249+
let selection = self
250+
.proc_mesh
251+
.shape()
252+
.slice()
253+
.reify_views(slices)
254+
.expect("invalid slices");
255+
256+
self.cast(selection, message)?;
257+
266258
Ok(())
267259
}
268260

@@ -528,6 +520,7 @@ mod tests {
528520
use $crate::sel_from_shape;
529521
use $crate::sel;
530522
use $crate::proc_mesh::SharedSpawnable;
523+
use $crate::comm::multicast::set_cast_info_on_headers;
531524
use std::collections::VecDeque;
532525
use hyperactor::data::Serialized;
533526

0 commit comments

Comments
 (0)