Skip to content

Commit 58978af

Browse files
: actor_mesh: use commactor in cast_slices
Summary: refactor `cast_slices`. the old approach iterated over the input slices and for each rank in each slice, sent a message. now, call `reify_views` on the input slices to produce a single `Selection` (union of ranges). then, cast the message to the selected ranks via the comm actor. Differential Revision: D77935610
1 parent 92618a7 commit 58978af

File tree

1 file changed

+23
-21
lines changed

1 file changed

+23
-21
lines changed

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use hyperactor::RemoteMessage;
2525
use hyperactor::Unbind;
2626
use hyperactor::WorldId;
2727
use hyperactor::actor::RemoteActor;
28-
use hyperactor::attrs::Attrs;
2928
use hyperactor::cap;
3029
use hyperactor::mailbox::MailboxSenderError;
3130
use hyperactor::mailbox::PortReceiver;
@@ -49,7 +48,6 @@ use crate::comm::multicast::CastMessage;
4948
use crate::comm::multicast::CastMessageEnvelope;
5049
use crate::comm::multicast::DestinationPort;
5150
use crate::comm::multicast::Uslice;
52-
use crate::comm::multicast::set_cast_info_on_headers;
5351
use crate::metrics;
5452
use crate::proc_mesh::ProcMesh;
5553
use crate::reference::ActorMeshId;
@@ -237,32 +235,35 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
237235
/// Until the selection logic is more powerful, we need a way to
238236
/// replicate the send patterns that the worker actor mesh actually does.
239237
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
240-
pub fn cast_slices<M: RemoteMessage + Clone>(
241-
&self,
242-
sel: Vec<Slice>,
243-
message: M,
244-
) -> Result<(), CastError>
238+
pub fn cast_slices<M>(&self, sel: Vec<Slice>, message: M) -> Result<(), CastError>
245239
where
240+
M: Castable + Clone,
246241
A: RemoteHandles<M> + RemoteHandles<IndexedErasedUnbound<M>>,
247242
{
248243
let _ = metrics::ACTOR_MESH_CAST_DURATION.start(hyperactor::kv_pairs!(
249244
"message_type" => M::typename(),
250245
"message_variant" => message.arm().unwrap_or_default(),
251246
));
252-
for ref slice in sel {
253-
for rank in slice.iter() {
254-
let mut headers = Attrs::new();
255-
set_cast_info_on_headers(
256-
&mut headers,
257-
rank,
258-
self.shape().clone(),
259-
self.proc_mesh.client().actor_id().clone(),
260-
);
261-
self.ranks[rank]
262-
.send_with_headers(self.proc_mesh.client(), headers, message.clone())
263-
.map_err(|err| CastError::MailboxSenderError(rank, err))?;
264-
}
265-
}
247+
248+
let client = self.proc_mesh.client();
249+
let snd = client.actor_id().clone();
250+
let dst = DestinationPort::new::<A, M>(self.name().to_string());
251+
let shape = self.proc_mesh.shape().clone();
252+
let base_slice = self.proc_mesh.shape().slice().clone();
253+
let slices: &[&Slice] = &sel.iter().collect::<Vec<_>>();
254+
let message = CastMessageEnvelope::new(snd, dst, shape, message, None)?;
255+
256+
self.proc_mesh.comm_actor().send(
257+
client,
258+
CastMessage {
259+
dest: Uslice {
260+
slice: base_slice.clone(),
261+
selection: base_slice.reify_views(slices).expect("invalid slices"),
262+
},
263+
message,
264+
},
265+
)?;
266+
266267
Ok(())
267268
}
268269

@@ -528,6 +529,7 @@ mod tests {
528529
use $crate::sel_from_shape;
529530
use $crate::sel;
530531
use $crate::proc_mesh::SharedSpawnable;
532+
use $crate::comm::multicast::set_cast_info_on_headers;
531533
use std::collections::VecDeque;
532534
use hyperactor::data::Serialized;
533535

0 commit comments

Comments
 (0)