Skip to content

Commit 5e1f4fe

Browse files
highkerfacebook-github-bot
authored andcommitted
(2/n) get ready to optionally spin up state actor on the client
Differential Revision: D77848229
1 parent 512de73 commit 5e1f4fe

File tree

4 files changed

+257
-7
lines changed

4 files changed

+257
-7
lines changed

hyperactor_mesh/src/alloc.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ use serde::Deserialize;
3535
use serde::Serialize;
3636

3737
use crate::alloc::test_utils::MockAllocWrapper;
38+
use crate::log_source::LogSource;
3839
use crate::proc_mesh::mesh_agent::MeshAgent;
3940

4041
/// Errors that occur during allocation operations.
@@ -218,6 +219,20 @@ pub trait Alloc {
218219
/// The channel transport used the procs in this alloc.
219220
fn transport(&self) -> ChannelTransport;
220221

222+
/// The log source for this alloc.
223+
/// It allows remote processes to stream stdout and stderr back to the client.
224+
/// A client can connect to the log source to obtain the streamed logs.
225+
/// A log source is allocation specific. Each allocator can decide how to stream the logs back.
226+
async fn log_source(&self) -> Result<LogSource, AllocatorError> {
227+
// TODO: this should be implemented based on different allocators.
228+
// Having this temporarily here so that the client can connect to the log source.
229+
// But the client will not get anything.
230+
// The following diffs will gradually implement this for different allocators.
231+
LogSource::new_with_local_actor()
232+
.await
233+
.map_err(AllocatorError::from)
234+
}
235+
221236
/// Stop this alloc, shutting down all of its procs. A clean
222237
/// shutdown should result in Stop events from all allocs,
223238
/// followed by the end of the event stream.

hyperactor_mesh/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub mod bootstrap;
1919
pub mod code_sync;
2020
pub mod comm;
2121
pub mod connect;
22+
pub mod log_source;
2223
pub mod mesh;
2324
pub mod mesh_selection;
2425
mod metrics;

hyperactor_mesh/src/log_source.rs

Lines changed: 234 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,234 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::fmt;
10+
use std::str::FromStr;
11+
12+
use hyperactor::ActorId;
13+
use hyperactor::ActorRef;
14+
use hyperactor::ProcId;
15+
use hyperactor::WorldId;
16+
use hyperactor::channel;
17+
use hyperactor::channel::ChannelAddr;
18+
use hyperactor::channel::ChannelTransport;
19+
use hyperactor::mailbox;
20+
use hyperactor::mailbox::BoxedMailboxSender;
21+
use hyperactor::mailbox::DialMailboxRouter;
22+
use hyperactor::mailbox::MailboxServer;
23+
use hyperactor::proc::Proc;
24+
use hyperactor_state::state_actor::StateActor;
25+
26+
use crate::shortuuid::ShortUuid;
27+
28+
/// The source of the log so that the remote process can stream stdout and stderr to.
29+
/// A log source is allocation specific. Each allocator can decide how to stream the logs back.
30+
/// It holds a reference or the lifecycle of a state actor that collects all the logs from processes.
31+
#[derive(Clone, Debug)]
32+
pub struct LogSource {
33+
// Optionally hold the lifecycle of the state actor
34+
_state_proc: Option<Proc>,
35+
state_proc_addr: ChannelAddr,
36+
state_actor: ActorRef<StateActor>,
37+
}
38+
39+
#[derive(Clone, Debug)]
40+
pub struct StateServerInfo {
41+
pub state_proc_addr: ChannelAddr,
42+
pub state_actor_id: ActorId,
43+
}
44+
45+
impl fmt::Display for StateServerInfo {
46+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47+
write!(f, "{},{}", self.state_proc_addr, self.state_actor_id)
48+
}
49+
}
50+
51+
impl FromStr for StateServerInfo {
52+
type Err = anyhow::Error;
53+
54+
fn from_str(state_server_info: &str) -> Result<Self, Self::Err> {
55+
match state_server_info.split_once(",") {
56+
Some((addr, id)) => {
57+
let state_proc_addr: ChannelAddr = addr.parse()?;
58+
let state_actor_id: ActorId = id.parse()?;
59+
Ok(StateServerInfo {
60+
state_proc_addr,
61+
state_actor_id,
62+
})
63+
}
64+
_ => Err(anyhow::anyhow!(
65+
"unrecognized state server info: {state_server_info}"
66+
)),
67+
}
68+
}
69+
}
70+
71+
impl LogSource {
72+
/// Spin up the state actor locally to receive the remote logs.
73+
pub async fn new_with_local_actor() -> Result<Self, anyhow::Error> {
74+
let router = DialMailboxRouter::new();
75+
let state_proc_id = ProcId(WorldId(format!("local_state_{}", ShortUuid::generate())), 0);
76+
let (state_proc_addr, state_rx) =
77+
channel::serve(ChannelAddr::any(ChannelTransport::Unix)).await?;
78+
let state_proc = Proc::new(
79+
state_proc_id.clone(),
80+
BoxedMailboxSender::new(router.clone()),
81+
);
82+
state_proc
83+
.clone()
84+
.serve(state_rx, mailbox::monitored_return_handle());
85+
router.bind(state_proc_id.clone().into(), state_proc_addr.clone());
86+
let state_actor = state_proc.spawn::<StateActor>("state", ()).await?.bind();
87+
88+
Ok(Self {
89+
_state_proc: Some(state_proc),
90+
state_proc_addr,
91+
state_actor,
92+
})
93+
}
94+
95+
/// Connect to an existing state actor to receive the remote logs.
96+
/// The lifecycle of the state actor should be maintained by the allocator who creates it.
97+
pub fn new_with_remote_actor(state_proc_id: ActorId, state_proc_addr: ChannelAddr) -> Self {
98+
Self {
99+
_state_proc: None,
100+
state_proc_addr,
101+
state_actor: ActorRef::attest(state_proc_id),
102+
}
103+
}
104+
105+
pub fn server_info(&self) -> StateServerInfo {
106+
StateServerInfo {
107+
state_proc_addr: self.state_proc_addr.clone(),
108+
state_actor_id: self.state_actor.actor_id().clone(),
109+
}
110+
}
111+
}
112+
113+
#[cfg(test)]
114+
mod tests {
115+
use std::str::FromStr;
116+
use std::time::Duration;
117+
118+
use hyperactor::channel;
119+
use hyperactor::channel::ChannelAddr;
120+
use hyperactor::clock::Clock;
121+
use hyperactor::id;
122+
use hyperactor_state::client::ClientActor;
123+
use hyperactor_state::client::ClientActorParams;
124+
use hyperactor_state::client::LogHandler;
125+
use hyperactor_state::object::GenericStateObject;
126+
use hyperactor_state::state_actor::StateMessageClient;
127+
use hyperactor_state::test_utils::log_items;
128+
use tokio::sync::mpsc::Sender;
129+
130+
use super::*;
131+
132+
#[test]
133+
fn test_state_server_info_serialization() {
134+
let addr = ChannelAddr::any(channel::ChannelTransport::Unix);
135+
let actor_id: ActorId = id!(test_actor[42].actor[0]);
136+
137+
let info = StateServerInfo {
138+
state_proc_addr: addr.clone(),
139+
state_actor_id: actor_id.clone(),
140+
};
141+
142+
// Test Display implementation
143+
let serialized = format!("{}", info);
144+
assert!(serialized.contains(","));
145+
146+
// Test FromStr implementation
147+
let deserialized = StateServerInfo::from_str(&serialized).unwrap();
148+
assert_eq!(
149+
info.state_proc_addr.to_string(),
150+
deserialized.state_proc_addr.to_string()
151+
);
152+
assert_eq!(
153+
info.state_actor_id.to_string(),
154+
deserialized.state_actor_id.to_string()
155+
);
156+
}
157+
158+
#[derive(Debug)]
159+
struct MpscLogHandler {
160+
sender: Sender<Vec<GenericStateObject>>,
161+
}
162+
163+
impl LogHandler for MpscLogHandler {
164+
fn handle_log(&self, logs: Vec<GenericStateObject>) -> anyhow::Result<()> {
165+
let sender = self.sender.clone();
166+
tokio::spawn(async move {
167+
sender.send(logs).await.unwrap();
168+
});
169+
Ok(())
170+
}
171+
}
172+
173+
#[tokio::test]
174+
async fn test_state_server_pushing_logs() {
175+
// Spin up a new state actor
176+
let log_source = LogSource::new_with_local_actor().await.unwrap();
177+
178+
// Setup the client and connect it to the state actor
179+
let router = DialMailboxRouter::new();
180+
let (client_proc_addr, client_rx) =
181+
channel::serve(ChannelAddr::any(ChannelTransport::Unix))
182+
.await
183+
.unwrap();
184+
let client_proc = Proc::new(id!(client[0]), BoxedMailboxSender::new(router.clone()));
185+
client_proc
186+
.clone()
187+
.serve(client_rx, mailbox::monitored_return_handle());
188+
router.bind(id!(client[0]).into(), client_proc_addr.clone());
189+
router.bind(
190+
log_source.server_info().state_actor_id.clone().into(),
191+
log_source.server_info().state_proc_addr.clone(),
192+
);
193+
let client = client_proc.attach("client").unwrap();
194+
195+
// Spin up the client logging actor and subscribe to the state actor
196+
let (sender, mut receiver) = tokio::sync::mpsc::channel::<Vec<GenericStateObject>>(20);
197+
let log_handler = Box::new(MpscLogHandler { sender });
198+
let params = ClientActorParams { log_handler };
199+
200+
let client_logging_actor: ActorRef<ClientActor> = client_proc
201+
.spawn::<ClientActor>("logging_client", params)
202+
.await
203+
.unwrap()
204+
.bind();
205+
let state_actor_ref: ActorRef<StateActor> =
206+
ActorRef::attest(log_source.server_info().state_actor_id.clone());
207+
state_actor_ref
208+
.subscribe_logs(
209+
&client,
210+
client_proc_addr.clone(),
211+
client_logging_actor.clone(),
212+
)
213+
.await
214+
.unwrap();
215+
216+
// Write some logs
217+
state_actor_ref
218+
.push_logs(&client, log_items(0, 10))
219+
.await
220+
.unwrap();
221+
222+
// Collect received messages with timeout
223+
let fetched_logs = client_proc
224+
.clock()
225+
.timeout(Duration::from_secs(1), receiver.recv())
226+
.await
227+
.expect("timed out waiting for message")
228+
.expect("channel closed unexpectedly");
229+
230+
// Verify we received all expected logs
231+
assert_eq!(fetched_logs.len(), 10);
232+
assert_eq!(fetched_logs, log_items(0, 10));
233+
}
234+
}

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use hyperactor::actor::remote::Remote;
2525
use hyperactor::cap;
2626
use hyperactor::channel;
2727
use hyperactor::channel::ChannelAddr;
28-
use hyperactor::id;
2928
use hyperactor::mailbox;
3029
use hyperactor::mailbox::BoxableMailboxSender;
3130
use hyperactor::mailbox::BoxedMailboxSender;
@@ -57,6 +56,7 @@ use crate::alloc::ProcState;
5756
use crate::alloc::ProcStopReason;
5857
use crate::assign::Ranks;
5958
use crate::comm::CommActorMode;
59+
use crate::log_source::StateServerInfo;
6060
use crate::proc_mesh::mesh_agent::MeshAgent;
6161
use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
6262
use crate::reference::ProcMeshId;
@@ -303,12 +303,11 @@ impl ProcMesh {
303303
}
304304

305305
// Get a reference to the state actor for streaming logs.
306-
// TODO: bind logging options to python API so that users can choose if to stream (optionally aggregated) logs back or not.
307-
// TODO: spin up state actor locally and remotely with names and addresses passed in here.
308-
let state_actor_id = id!(state_server[0].state[0]);
309-
let state_actor_ref = ActorRef::<StateActor>::attest(state_actor_id.clone());
310-
let state_actor_addr = "tcp![::]:3000".parse::<ChannelAddr>().unwrap();
311-
router.bind(state_actor_id.into(), state_actor_addr.clone());
306+
let StateServerInfo {
307+
state_proc_addr,
308+
state_actor_id,
309+
} = alloc.log_source().await?.server_info();
310+
router.bind(state_actor_id.clone().into(), state_proc_addr.clone());
312311

313312
let log_handler = Box::new(hyperactor_state::client::StdlogHandler {});
314313
let params = ClientActorParams { log_handler };
@@ -319,6 +318,7 @@ impl ProcMesh {
319318
.unwrap()
320319
.bind();
321320

321+
let state_actor_ref: ActorRef<StateActor> = ActorRef::attest(state_actor_id);
322322
match state_actor_ref
323323
.subscribe_logs(
324324
&client,

0 commit comments

Comments
 (0)