Skip to content

Commit 0ab4c85

Browse files
highker authored and facebook-github-bot committed
(2/n) get ready to optionally spin up state actor on the client
Differential Revision: D77848229
1 parent d03d512 commit 0ab4c85

File tree

4 files changed

+253
-7
lines changed

4 files changed

+253
-7
lines changed

hyperactor_mesh/src/alloc.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ use serde::Deserialize;
3636
use serde::Serialize;
3737

3838
use crate::alloc::test_utils::MockAllocWrapper;
39+
use crate::log_source::LogSource;
3940
use crate::proc_mesh::mesh_agent::MeshAgent;
4041

4142
/// Errors that occur during allocation operations.
@@ -219,6 +220,20 @@ pub trait Alloc {
219220
/// The channel transport used by the procs in this alloc.
220221
fn transport(&self) -> ChannelTransport;
221222

223+
/// The log source for this alloc.
224+
/// It allows remote processes to stream stdout and stderr back to the client.
225+
/// A client can connect to the log source to obtain the streamed logs.
226+
/// A log source is allocation specific. Each allocator can decide how to stream the logs back.
227+
async fn log_source(&self) -> Result<LogSource, AllocatorError> {
228+
// TODO: this should be implemented based on different allocators.
229+
// Having this temporarily here so that the client can connect to the log source.
230+
// But the client will not get anything.
231+
// The following diffs will gradually implement this for different allocators.
232+
LogSource::new_with_local_actor()
233+
.await
234+
.map_err(AllocatorError::from)
235+
}
236+
222237
/// Stop this alloc, shutting down all of its procs. A clean
223238
/// shutdown should result in Stop events from all allocs,
224239
/// followed by the end of the event stream.

hyperactor_mesh/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ pub mod bootstrap;
1919
pub mod code_sync;
2020
pub mod comm;
2121
pub mod connect;
22+
pub mod log_source;
2223
pub mod mesh;
2324
pub mod mesh_selection;
2425
mod metrics;

hyperactor_mesh/src/log_source.rs

Lines changed: 230 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,230 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
use std::fmt;
10+
use std::str::FromStr;
11+
12+
use hyperactor::ActorId;
13+
use hyperactor::ActorRef;
14+
use hyperactor::channel;
15+
use hyperactor::channel::ChannelAddr;
16+
use hyperactor::channel::ChannelTransport;
17+
use hyperactor::id;
18+
use hyperactor::mailbox;
19+
use hyperactor::mailbox::BoxedMailboxSender;
20+
use hyperactor::mailbox::DialMailboxRouter;
21+
use hyperactor::mailbox::MailboxServer;
22+
use hyperactor::proc::Proc;
23+
use hyperactor_state::state_actor::StateActor;
24+
25+
/// The log source that remote processes can stream their stdout and stderr to.
26+
/// A log source is allocation specific. Each allocator can decide how to stream the logs back.
27+
/// It holds either a reference to, or the lifecycle of, a state actor that collects all the logs from processes.
28+
#[derive(Clone, Debug)]
29+
pub struct LogSource {
30+
// Optionally hold the lifecycle of the state actor
31+
_state_proc: Option<Proc>,
32+
state_proc_addr: ChannelAddr,
33+
state_actor: ActorRef<StateActor>,
34+
}
35+
36+
#[derive(Clone, Debug)]
37+
pub struct StateServerInfo {
38+
pub state_proc_addr: ChannelAddr,
39+
pub state_actor_id: ActorId,
40+
}
41+
42+
impl fmt::Display for StateServerInfo {
43+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44+
write!(f, "{},{}", self.state_proc_addr, self.state_actor_id)
45+
}
46+
}
47+
48+
impl FromStr for StateServerInfo {
49+
type Err = anyhow::Error;
50+
51+
fn from_str(state_server_info: &str) -> Result<Self, Self::Err> {
52+
match state_server_info.split_once(",") {
53+
Some((addr, id)) => {
54+
let state_proc_addr: ChannelAddr = addr.parse()?;
55+
let state_actor_id: ActorId = id.parse()?;
56+
Ok(StateServerInfo {
57+
state_proc_addr,
58+
state_actor_id,
59+
})
60+
}
61+
_ => Err(anyhow::anyhow!(
62+
"unrecognized state server info: {state_server_info}"
63+
)),
64+
}
65+
}
66+
}
67+
68+
impl LogSource {
69+
/// Spin up the state actor locally to receive the remote logs.
70+
pub async fn new_with_local_actor() -> Result<Self, anyhow::Error> {
71+
let router = DialMailboxRouter::new();
72+
let state_proc_id = id!(local_state[0]);
73+
let (state_proc_addr, state_rx) =
74+
channel::serve(ChannelAddr::any(ChannelTransport::Unix)).await?;
75+
let state_proc = Proc::new(
76+
state_proc_id.clone(),
77+
BoxedMailboxSender::new(router.clone()),
78+
);
79+
state_proc
80+
.clone()
81+
.serve(state_rx, mailbox::monitored_return_handle());
82+
router.bind(state_proc_id.clone().into(), state_proc_addr.clone());
83+
let state_actor = state_proc.spawn::<StateActor>("state", ()).await?.bind();
84+
85+
Ok(Self {
86+
_state_proc: Some(state_proc),
87+
state_proc_addr,
88+
state_actor,
89+
})
90+
}
91+
92+
/// Connect to an existing state actor to receive the remote logs.
93+
/// The lifecycle of the state actor should be maintained by the allocator that creates it.
94+
pub fn new_with_remote_actor(state_proc_id: ActorId, state_proc_addr: ChannelAddr) -> Self {
95+
Self {
96+
_state_proc: None,
97+
state_proc_addr,
98+
state_actor: ActorRef::attest(state_proc_id),
99+
}
100+
}
101+
102+
pub fn server_info(&self) -> StateServerInfo {
103+
StateServerInfo {
104+
state_proc_addr: self.state_proc_addr.clone(),
105+
state_actor_id: self.state_actor.actor_id().clone(),
106+
}
107+
}
108+
}
109+
110+
#[cfg(test)]
111+
mod tests {
112+
use std::str::FromStr;
113+
use std::time::Duration;
114+
115+
use hyperactor::channel;
116+
use hyperactor::channel::ChannelAddr;
117+
use hyperactor::clock::Clock;
118+
use hyperactor_state::client::ClientActor;
119+
use hyperactor_state::client::ClientActorParams;
120+
use hyperactor_state::client::LogHandler;
121+
use hyperactor_state::object::GenericStateObject;
122+
use hyperactor_state::state_actor::StateMessageClient;
123+
use hyperactor_state::test_utils::log_items;
124+
use tokio::sync::mpsc::Sender;
125+
126+
use super::*;
127+
128+
#[test]
129+
fn test_state_server_info_serialization() {
130+
let addr = ChannelAddr::any(channel::ChannelTransport::Unix);
131+
let actor_id: ActorId = id!(test_actor[42].actor[0]);
132+
133+
let info = StateServerInfo {
134+
state_proc_addr: addr.clone(),
135+
state_actor_id: actor_id.clone(),
136+
};
137+
138+
// Test Display implementation
139+
let serialized = format!("{}", info);
140+
assert!(serialized.contains(","));
141+
142+
// Test FromStr implementation
143+
let deserialized = StateServerInfo::from_str(&serialized).unwrap();
144+
assert_eq!(
145+
info.state_proc_addr.to_string(),
146+
deserialized.state_proc_addr.to_string()
147+
);
148+
assert_eq!(
149+
info.state_actor_id.to_string(),
150+
deserialized.state_actor_id.to_string()
151+
);
152+
}
153+
154+
#[derive(Debug)]
155+
struct MpscLogHandler {
156+
sender: Sender<Vec<GenericStateObject>>,
157+
}
158+
159+
impl LogHandler for MpscLogHandler {
160+
fn handle_log(&self, logs: Vec<GenericStateObject>) -> anyhow::Result<()> {
161+
let sender = self.sender.clone();
162+
tokio::spawn(async move {
163+
sender.send(logs).await.unwrap();
164+
});
165+
Ok(())
166+
}
167+
}
168+
169+
#[tokio::test]
170+
async fn test_state_server_pushing_logs() {
171+
// Spin up a new state actor
172+
let log_source = LogSource::new_with_local_actor().await.unwrap();
173+
174+
// Setup the client and connect it to the state actor
175+
let router = DialMailboxRouter::new();
176+
let (client_proc_addr, client_rx) =
177+
channel::serve(ChannelAddr::any(ChannelTransport::Unix))
178+
.await
179+
.unwrap();
180+
let client_proc = Proc::new(id!(client[0]), BoxedMailboxSender::new(router.clone()));
181+
client_proc
182+
.clone()
183+
.serve(client_rx, mailbox::monitored_return_handle());
184+
router.bind(id!(client[0]).into(), client_proc_addr.clone());
185+
router.bind(
186+
log_source.server_info().state_actor_id.clone().into(),
187+
log_source.server_info().state_proc_addr.clone(),
188+
);
189+
let client = client_proc.attach("client").unwrap();
190+
191+
// Spin up the client logging actor and subscribe to the state actor
192+
let (sender, mut receiver) = tokio::sync::mpsc::channel::<Vec<GenericStateObject>>(20);
193+
let log_handler = Box::new(MpscLogHandler { sender });
194+
let params = ClientActorParams { log_handler };
195+
196+
let client_logging_actor: ActorRef<ClientActor> = client_proc
197+
.spawn::<ClientActor>("logging_client", params)
198+
.await
199+
.unwrap()
200+
.bind();
201+
let state_actor_ref: ActorRef<StateActor> =
202+
ActorRef::attest(log_source.server_info().state_actor_id.clone());
203+
state_actor_ref
204+
.subscribe_logs(
205+
&client,
206+
client_proc_addr.clone(),
207+
client_logging_actor.clone(),
208+
)
209+
.await
210+
.unwrap();
211+
212+
// Write some logs
213+
state_actor_ref
214+
.push_logs(&client, log_items(0, 10))
215+
.await
216+
.unwrap();
217+
218+
// Collect received messages with timeout
219+
let fetched_logs = client_proc
220+
.clock()
221+
.timeout(Duration::from_secs(1), receiver.recv())
222+
.await
223+
.expect("timed out waiting for message")
224+
.expect("channel closed unexpectedly");
225+
226+
// Verify we received all expected logs
227+
assert_eq!(fetched_logs.len(), 10);
228+
assert_eq!(fetched_logs, log_items(0, 10));
229+
}
230+
}

hyperactor_mesh/src/proc_mesh.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ use hyperactor::actor::remote::Remote;
2525
use hyperactor::cap;
2626
use hyperactor::channel;
2727
use hyperactor::channel::ChannelAddr;
28-
use hyperactor::id;
2928
use hyperactor::mailbox;
3029
use hyperactor::mailbox::BoxableMailboxSender;
3130
use hyperactor::mailbox::BoxedMailboxSender;
@@ -57,6 +56,7 @@ use crate::alloc::ProcState;
5756
use crate::alloc::ProcStopReason;
5857
use crate::assign::Ranks;
5958
use crate::comm::CommActorMode;
59+
use crate::log_source::StateServerInfo;
6060
use crate::proc_mesh::mesh_agent::MeshAgent;
6161
use crate::proc_mesh::mesh_agent::MeshAgentMessageClient;
6262

@@ -302,12 +302,11 @@ impl ProcMesh {
302302
}
303303

304304
// Get a reference to the state actor for streaming logs.
305-
// TODO: bind logging options to python API so that users can choose if to stream (optionally aggregated) logs back or not.
306-
// TODO: spin up state actor locally and remotely with names and addresses passed in here.
307-
let state_actor_id = id!(state_server[0].state[0]);
308-
let state_actor_ref = ActorRef::<StateActor>::attest(state_actor_id.clone());
309-
let state_actor_addr = "tcp![::]:3000".parse::<ChannelAddr>().unwrap();
310-
router.bind(state_actor_id.into(), state_actor_addr.clone());
305+
let StateServerInfo {
306+
state_proc_addr,
307+
state_actor_id,
308+
} = alloc.log_source().await?.server_info();
309+
router.bind(state_actor_id.clone().into(), state_proc_addr.clone());
311310

312311
let log_handler = Box::new(hyperactor_state::client::StdlogHandler {});
313312
let params = ClientActorParams { log_handler };
@@ -318,6 +317,7 @@ impl ProcMesh {
318317
.unwrap()
319318
.bind();
320319

320+
let state_actor_ref: ActorRef<StateActor> = ActorRef::attest(state_actor_id);
321321
match state_actor_ref
322322
.subscribe_logs(
323323
&client,

0 commit comments

Comments
 (0)