Skip to content

Commit 022e2cc

Browse files
mariusaefacebook-github-bot
authored andcommitted
mesh: streams are keyed on (ActorMeshId, Sender) (#471)
Summary: Pull Request resolved: #471 This change aligns all of the stream keys along the comm actor delivery path. We formalize streams to be per-(ActorMeshId, Sender), and enforce this everywhere. After this change, we can begin to simplify the comm actor delivery code. with shayne-fletcher Reviewed By: shayne-fletcher Differential Revision: D77963903 fbshipit-source-id: d906e5384b3cace8d900c0b1acf7c9509f249491
1 parent ffb9824 commit 022e2cc

File tree

7 files changed

+61
-6
lines changed

7 files changed

+61
-6
lines changed

controller/src/lib.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ use hyperactor_mesh::comm::multicast::CastMessage;
3939
use hyperactor_mesh::comm::multicast::CastMessageEnvelope;
4040
use hyperactor_mesh::comm::multicast::DestinationPort;
4141
use hyperactor_mesh::comm::multicast::Uslice;
42+
use hyperactor_mesh::reference::ActorMeshId;
43+
use hyperactor_mesh::reference::ProcMeshId;
4244
use hyperactor_multiprocess::proc_actor::ProcActor;
4345
use hyperactor_multiprocess::proc_actor::spawn;
4446
use hyperactor_multiprocess::supervision::WorldSupervisionMessageClient;
@@ -422,7 +424,12 @@ impl ControllerMessageHandler for ControllerActor {
422424
dsl::union(sel, slice_to_selection(slice))
423425
}),
424426
};
427+
425428
let message = CastMessageEnvelope::from_serialized(
429+
ActorMeshId(
430+
ProcMeshId(self.worker_gang_ref.gang_id().world_id().to_string()),
431+
self.worker_gang_ref.gang_id().name().to_string(),
432+
),
426433
cx.self_id().clone(),
427434
DestinationPort::new::<WorkerActor, WorkerMessage>(
428435
// This is awkward, but goes away entirely with meshes.

hyperactor_mesh/src/actor_mesh.rs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ use crate::reference::ProcMeshId;
6060
#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `CastError`.
6161
pub(crate) fn actor_mesh_cast<M: Castable + Clone, A>(
6262
caps: &impl cap::CanSend,
63+
actor_mesh_id: ActorMeshId,
6364
actor_mesh_shape: &Shape,
6465
proc_mesh_shape: &Shape,
6566
actor_name: &str,
@@ -77,6 +78,7 @@ where
7778
));
7879

7980
let message = CastMessageEnvelope::new(
81+
actor_mesh_id,
8082
sender.clone(),
8183
DestinationPort::new::<A, M>(actor_name.to_string()),
8284
actor_mesh_shape.clone(),
@@ -115,7 +117,7 @@ where
115117
}
116118

117119
/// A mesh of actors, all of which reside on the same [`ProcMesh`].
118-
pub trait ActorMesh: Mesh {
120+
pub trait ActorMesh: Mesh<Id = ActorMeshId> {
119121
/// The type of actor in the mesh.
120122
type Actor: RemoteActor;
121123

@@ -128,6 +130,7 @@ pub trait ActorMesh: Mesh {
128130
{
129131
actor_mesh_cast::<M, Self::Actor>(
130132
self.proc_mesh().client(),
133+
self.id(),
131134
self.shape(),
132135
self.proc_mesh().shape(),
133136
self.name(),
@@ -288,6 +291,7 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> {
288291
#[async_trait]
289292
impl<'a, A: RemoteActor> Mesh for RootActorMesh<'a, A> {
290293
type Node = ActorRef<A>;
294+
type Id = ActorMeshId;
291295
type Sliced<'b>
292296
= SlicedActorMesh<'b, A>
293297
where
@@ -308,6 +312,10 @@ impl<'a, A: RemoteActor> Mesh for RootActorMesh<'a, A> {
308312
fn get(&self, rank: usize) -> Option<ActorRef<A>> {
309313
self.ranks.get(rank).cloned()
310314
}
315+
316+
fn id(&self) -> Self::Id {
317+
ActorMeshId(self.proc_mesh.id(), self.name.clone())
318+
}
311319
}
312320

313321
impl<A: RemoteActor> ActorMesh for RootActorMesh<'_, A> {
@@ -337,6 +345,7 @@ impl<'a, A: RemoteActor> SlicedActorMesh<'a, A> {
337345
#[async_trait]
338346
impl<A: RemoteActor> Mesh for SlicedActorMesh<'_, A> {
339347
type Node = ActorRef<A>;
348+
type Id = ActorMeshId;
340349
type Sliced<'b>
341350
= SlicedActorMesh<'b, A>
342351
where
@@ -357,6 +366,10 @@ impl<A: RemoteActor> Mesh for SlicedActorMesh<'_, A> {
357366
fn get(&self, _index: usize) -> Option<ActorRef<A>> {
358367
unimplemented!()
359368
}
369+
370+
fn id(&self) -> Self::Id {
371+
self.0.id()
372+
}
360373
}
361374

362375
impl<A: RemoteActor> ActorMesh for SlicedActorMesh<'_, A> {

hyperactor_mesh/src/comm.rs

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
use crate::comm::multicast::CAST_ORIGINATING_SENDER;
10+
use crate::reference::ActorMeshId;
1011
pub mod multicast;
1112

1213
use std::cmp::Ordering;
@@ -84,10 +85,10 @@ struct ReceiveState {
8485
],
8586
)]
8687
pub struct CommActor {
87-
/// Each world will use its own seq num from this caster.
88-
send_seq: HashMap<Slice, usize>,
88+
/// Sequence numbers are maintained for each (actor mesh id, sender).
89+
send_seq: HashMap<(ActorMeshId, ActorId), usize>,
8990
/// Each sender is a unique stream.
90-
recv_state: HashMap<ActorId, ReceiveState>,
91+
recv_state: HashMap<(ActorMeshId, ActorId), ReceiveState>,
9192

9293
/// The comm actor's mode.
9394
mode: CommActorMode,
@@ -313,7 +314,7 @@ impl Handler<CastMessage> for CommActor {
313314
let rank = frame.slice.location(&frame.here)?;
314315
let seq = self
315316
.send_seq
316-
.entry(frame.slice.as_ref().clone())
317+
.entry(cast_message.message.stream_key())
317318
.or_default();
318319
let last_seq = *seq;
319320
*seq += 1;
@@ -351,7 +352,7 @@ impl Handler<ForwardMessage> for CommActor {
351352
panic!("Choice encountered in CommActor routing")
352353
})?;
353354

354-
let recv_state = self.recv_state.entry(sender.clone()).or_default();
355+
let recv_state = self.recv_state.entry(message.stream_key()).or_default();
355356
match recv_state.seq.cmp(&last_seq) {
356357
// We got the expected next message to deliver to this host.
357358
Ordering::Equal => {

hyperactor_mesh/src/comm/multicast.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ use ndslice::selection::routing::RoutingFrame;
2828
use serde::Deserialize;
2929
use serde::Serialize;
3030

31+
use crate::reference::ActorMeshId;
32+
3133
/// A union of slices that can be used to represent arbitrary subset of
3234
/// ranks in a gang. It is represented by a Slice together with a Selection.
3335
/// This is used to define the destination of a cast message or the source of
@@ -43,6 +45,8 @@ pub struct Uslice {
4345
/// An envelope that carries a message destined to a group of actors.
4446
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq, Named)]
4547
pub struct CastMessageEnvelope {
48+
/// The destination actor mesh id.
49+
actor_mesh_id: ActorMeshId,
4650
/// The sender of this message.
4751
sender: ActorId,
4852
/// The destination port of the message. It could match multiple actors with
@@ -59,6 +63,7 @@ pub struct CastMessageEnvelope {
5963
impl CastMessageEnvelope {
6064
/// Create a new CastMessageEnvelope.
6165
pub fn new<T: Castable + Serialize + Named>(
66+
actor_mesh_id: ActorMeshId,
6267
sender: ActorId,
6368
dest_port: DestinationPort,
6469
shape: Shape,
@@ -67,6 +72,7 @@ impl CastMessageEnvelope {
6772
) -> Result<Self, anyhow::Error> {
6873
let data = ErasedUnbound::try_from_message(message)?;
6974
Ok(Self {
75+
actor_mesh_id,
7076
sender,
7177
dest_port,
7278
data,
@@ -79,12 +85,14 @@ impl CastMessageEnvelope {
7985
/// when the message do not contain reply ports. Or it does but you are okay
8086
/// with the destination actors reply to the client actor directly.
8187
pub fn from_serialized(
88+
actor_mesh_id: ActorMeshId,
8289
sender: ActorId,
8390
dest_port: DestinationPort,
8491
shape: Shape,
8592
data: Serialized,
8693
) -> Self {
8794
Self {
95+
actor_mesh_id,
8896
sender,
8997
dest_port,
9098
data: ErasedUnbound::new(data),
@@ -112,6 +120,13 @@ impl CastMessageEnvelope {
112120
pub(crate) fn shape(&self) -> &Shape {
113121
&self.shape
114122
}
123+
124+
/// The unique key used to indicate the stream to which to deliver this message.
125+
/// Concretely, the comm actors along the path should use this key to manage
126+
/// sequence numbers and reorder buffers.
127+
pub(crate) fn stream_key(&self) -> (ActorMeshId, ActorId) {
128+
(self.actor_mesh_id.clone(), self.sender.clone())
129+
}
115130
}
116131

117132
/// Destination port id of a message. It is a `PortId` with the rank masked out,

hyperactor_mesh/src/mesh.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
*/
88

99
use async_trait::async_trait;
10+
use hyperactor::RemoteMessage;
1011
use ndslice::Range;
1112
use ndslice::Shape;
1213
use ndslice::ShapeError;
@@ -18,6 +19,9 @@ pub trait Mesh {
1819
/// The type of the node contained in the mesh.
1920
type Node;
2021

22+
/// The type of identifiers for this mesh.
23+
type Id: RemoteMessage;
24+
2125
/// The type of a slice of this mesh. Slices should not outlive their
2226
/// parent mesh.
2327
type Sliced<'a>: Mesh<Node = Self::Node> + 'a
@@ -43,6 +47,9 @@ pub trait Mesh {
4347
slice_iter: self.shape().slice().iter(),
4448
}
4549
}
50+
51+
/// The global identifier for this mesh.
52+
fn id(&self) -> Self::Id;
4653
}
4754

4855
/// An iterator over the nodes of a mesh.

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ use crate::assign::Ranks;
5454
use crate::comm::CommActorMode;
5555
use crate::proc_mesh::mesh_agent::MeshAgent;
5656
use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
57+
use crate::reference::ProcMeshId;
5758

5859
pub mod mesh_agent;
5960

@@ -574,6 +575,7 @@ impl<D: Deref<Target = ProcMesh> + Send + Sync + 'static> SharedSpawnable for D
574575
#[async_trait]
575576
impl Mesh for ProcMesh {
576577
type Node = ProcId;
578+
type Id = ProcMeshId;
577579
type Sliced<'a> = SlicedProcMesh<'a>;
578580

579581
fn shape(&self) -> &Shape {
@@ -591,6 +593,10 @@ impl Mesh for ProcMesh {
591593
fn get(&self, rank: usize) -> Option<ProcId> {
592594
Some(self.ranks[rank].0.clone())
593595
}
596+
597+
fn id(&self) -> Self::Id {
598+
ProcMeshId(self.world_id().name().to_string())
599+
}
594600
}
595601

596602
impl fmt::Display for ProcMesh {
@@ -616,6 +622,7 @@ pub struct SlicedProcMesh<'a>(&'a ProcMesh, Shape);
616622
#[async_trait]
617623
impl Mesh for SlicedProcMesh<'_> {
618624
type Node = ProcId;
625+
type Id = ProcMeshId;
619626
type Sliced<'b>
620627
= SlicedProcMesh<'b>
621628
where
@@ -636,6 +643,10 @@ impl Mesh for SlicedProcMesh<'_> {
636643
fn get(&self, _index: usize) -> Option<ProcId> {
637644
unimplemented!()
638645
}
646+
647+
fn id(&self) -> Self::Id {
648+
self.0.id()
649+
}
639650
}
640651

641652
#[cfg(test)]

hyperactor_mesh/src/reference.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,7 @@ impl<A: RemoteActor> ActorMeshRef<A> {
134134
{
135135
actor_mesh_cast::<M, A>(
136136
caps,
137+
self.mesh_id.clone(),
137138
self.shape(),
138139
self.proc_mesh_shape(),
139140
self.name(),

0 commit comments

Comments
 (0)