diff --git a/hyperactor_mesh/src/alloc.rs b/hyperactor_mesh/src/alloc.rs index 9c36f525..0234970d 100644 --- a/hyperactor_mesh/src/alloc.rs +++ b/hyperactor_mesh/src/alloc.rs @@ -223,15 +223,7 @@ pub trait Alloc { /// It allows remote processes to stream stdout and stderr back to the client. /// A client can connect to the log source to obtain the streamed logs. /// A log source is allocation specific. Each allocator can decide how to stream the logs back. - async fn log_source(&self) -> Result { - // TODO: this should be implemented based on different allocators. - // Having this temporarily here so that the client can connect to the log source. - // But the client will not get anything. - // The following diffs will gradually implement this for different allocators. - LogSource::new_with_local_actor() - .await - .map_err(AllocatorError::from) - } + async fn log_source(&self) -> Result; /// Stop this alloc, shutting down all of its procs. A clean /// shutdown should result in Stop events from all allocs, @@ -367,6 +359,10 @@ pub mod test_utils { self.alloc.transport() } + async fn log_source(&self) -> Result { + self.alloc.log_source().await + } + async fn stop(&mut self) -> Result<(), AllocatorError> { self.alloc.stop().await } diff --git a/hyperactor_mesh/src/alloc/local.rs b/hyperactor_mesh/src/alloc/local.rs index 5950eaa3..63115fa2 100644 --- a/hyperactor_mesh/src/alloc/local.rs +++ b/hyperactor_mesh/src/alloc/local.rs @@ -33,6 +33,7 @@ use crate::alloc::AllocSpec; use crate::alloc::Allocator; use crate::alloc::AllocatorError; use crate::alloc::ProcState; +use crate::log_source::LogSource; use crate::proc_mesh::mesh_agent::MeshAgent; use crate::shortuuid::ShortUuid; @@ -252,6 +253,14 @@ impl Alloc for LocalAlloc { ChannelTransport::Local } + async fn log_source(&self) -> Result { + // Local alloc does not need to stream logs back. + // The client can subscribe to it but local actors will not stream logs into it. 
+ LogSource::new_with_local_actor() + .await + .map_err(AllocatorError::from) + } + async fn stop(&mut self) -> Result<(), AllocatorError> { for rank in 0..self.size() { self.todo_tx diff --git a/hyperactor_mesh/src/alloc/process.rs b/hyperactor_mesh/src/alloc/process.rs index 14807f0f..e2cdc337 100644 --- a/hyperactor_mesh/src/alloc/process.rs +++ b/hyperactor_mesh/src/alloc/process.rs @@ -29,7 +29,6 @@ use hyperactor::channel::ChannelTx; use hyperactor::channel::Rx; use hyperactor::channel::Tx; use hyperactor::channel::TxStatus; -use hyperactor::id; use hyperactor::sync::flag; use hyperactor::sync::monitor; use hyperactor_state::state_actor::StateActor; @@ -53,6 +52,8 @@ use crate::bootstrap; use crate::bootstrap::Allocator2Process; use crate::bootstrap::Process2Allocator; use crate::bootstrap::Process2AllocatorMessage; +use crate::log_source::LogSource; +use crate::log_source::StateServerInfo; use crate::shortuuid::ShortUuid; /// The maximum number of log lines to tail keep for managed processes. 
@@ -89,6 +90,9 @@ impl Allocator for ProcessAllocator { let (bootstrap_addr, rx) = channel::serve(ChannelAddr::any(ChannelTransport::Unix)) .await .map_err(anyhow::Error::from)?; + let log_source = LogSource::new_with_local_actor() + .await + .map_err(AllocatorError::from)?; let name = ShortUuid::generate(); let n = spec.shape.slice().len(); @@ -97,6 +101,7 @@ impl Allocator for ProcessAllocator { world_id: WorldId(name.to_string()), spec: spec.clone(), bootstrap_addr, + log_source, rx, index: 0, active: HashMap::new(), @@ -115,6 +120,7 @@ pub struct ProcessAlloc { world_id: WorldId, // to provide storage spec: AllocSpec, bootstrap_addr: ChannelAddr, + log_source: LogSource, rx: channel::ChannelRx, index: usize, active: HashMap, @@ -145,6 +151,7 @@ struct Child { impl Child { fn monitored( mut process: tokio::process::Child, + state_server_info: StateServerInfo, ) -> (Self, impl Future) { let (group, handle) = monitor::group(); let (exit_flag, exit_guard) = flag::guarded(); @@ -161,24 +168,20 @@ impl Child { // If state actor is enabled, try to set up LogWriter instances if use_state_actor { - let state_actor_ref = ActorRef::::attest(id!(state_server[0].state[0])); - // Parse the state actor address - if let Ok(state_actor_addr) = "tcp![::]:3000".parse::() { - // Use the helper function to create both writers at once - match hyperactor_state::log_writer::create_log_writers( - state_actor_addr, - state_actor_ref, - ) { - Ok((stdout_writer, stderr_writer)) => { - stdout_tee = stdout_writer; - stderr_tee = stderr_writer; - } - Err(e) => { - tracing::error!("failed to create log writers: {}", e); - } + let state_actor_ref = ActorRef::::attest(state_server_info.state_actor_id); + let state_actor_addr = state_server_info.state_proc_addr; + // Use the helper function to create both writers at once + match hyperactor_state::log_writer::create_log_writers( + state_actor_addr, + state_actor_ref, + ) { + Ok((stdout_writer, stderr_writer)) => { + stdout_tee = stdout_writer; + 
stderr_tee = stderr_writer; + } + Err(e) => { + tracing::error!("failed to create log writers: {}", e); } - } else { - tracing::error!("failed to parse state actor address"); } } @@ -394,7 +397,8 @@ impl ProcessAlloc { None } Ok(rank) => { - let (handle, monitor) = Child::monitored(process); + let (handle, monitor) = + Child::monitored(process, self.log_source.server_info()); self.children.spawn(async move { (index, monitor.await) }); self.active.insert(index, handle); // Adjust for shape slice offset for non-zero shapes (sub-shapes). @@ -498,6 +502,10 @@ impl Alloc for ProcessAlloc { ChannelTransport::Unix } + async fn log_source(&self) -> Result { + Ok(self.log_source.clone()) + } + async fn stop(&mut self) -> Result<(), AllocatorError> { // We rely on the teardown here, and that the process should // exit on its own. We should have a hard timeout here as well, diff --git a/hyperactor_mesh/src/alloc/remoteprocess.rs b/hyperactor_mesh/src/alloc/remoteprocess.rs index ebc6780f..87e97a72 100644 --- a/hyperactor_mesh/src/alloc/remoteprocess.rs +++ b/hyperactor_mesh/src/alloc/remoteprocess.rs @@ -16,6 +16,8 @@ use anyhow::Context; use async_trait::async_trait; use futures::FutureExt; use futures::future::select_all; +use hyperactor::ActorRef; +use hyperactor::Mailbox; use hyperactor::Named; use hyperactor::ProcId; use hyperactor::WorldId; @@ -30,11 +32,21 @@ use hyperactor::channel::TxStatus; use hyperactor::clock; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; +use hyperactor::id; +use hyperactor::mailbox; +use hyperactor::mailbox::BoxedMailboxSender; use hyperactor::mailbox::DialMailboxRouter; use hyperactor::mailbox::MailboxServer; use hyperactor::mailbox::monitored_return_handle; +use hyperactor::proc::Proc; use hyperactor::reference::Reference; use hyperactor::serde_json; +use hyperactor_state::client::ClientActor; +use hyperactor_state::client::ClientActorParams; +use hyperactor_state::client::LogHandler; +use 
hyperactor_state::object::GenericStateObject; +use hyperactor_state::state_actor::StateActor; +use hyperactor_state::state_actor::StateMessageClient; use mockall::automock; use ndslice::Shape; use serde::Deserialize; @@ -56,6 +68,9 @@ use crate::alloc::AllocatorError; use crate::alloc::ProcState; use crate::alloc::ProcStopReason; use crate::alloc::ProcessAllocator; +use crate::log_source::LogSource; +use crate::log_source::StateServerInfo; +use crate::shortuuid::ShortUuid; /// Control messages sent from remote process allocator to local allocator. #[derive(Debug, Clone, Serialize, Deserialize, Named)] @@ -67,6 +82,8 @@ pub enum RemoteProcessAllocatorMessage { spec: AllocSpec, /// Bootstrap address to be used for sending updates. bootstrap_addr: ChannelAddr, + /// The location of the state actor. + state_server_info: StateServerInfo, /// Ordered list of hosts in this allocation. Can be used to /// pre-populate the any local configurations such as torch.dist. hosts: Vec, @@ -94,6 +111,26 @@ pub enum RemoteProcessProcStateMessage { HeartBeat, } +#[derive(Debug)] +struct ForwarderLogHandler { + parent_state_actor: ActorRef, + forwarder: Mailbox, +} + +impl LogHandler for ForwarderLogHandler { + fn handle_log(&self, logs: Vec) -> anyhow::Result<()> { + let actor = self.parent_state_actor.clone(); + let forwarder = self.forwarder.clone(); + // TODO: (@jamessun) this is horribly wrong. ClientActor's log handler needs to be type erased + // so that the state actor can be subscribed to by different kinds of ClientActor. + // However, async function and Box do not work well together. + tokio::spawn(async move { + actor.push_logs(&forwarder, logs).await.unwrap(); + }); + Ok(()) + } +} + /// Allocator with a service frontend that wraps ProcessAllocator. pub struct RemoteProcessAllocator { cancel_token: CancellationToken, @@ -164,6 +201,24 @@ impl RemoteProcessAllocator { } } + // Set up a proc and an actor to forward the child process's log to the state actor. 
+ // This child-parent-state actor relay is not needed. But it helps with abstraction. + let router = DialMailboxRouter::new(); + let (forwarder_proc_addr, forwarder_rx) = + channel::serve(ChannelAddr::any(ChannelTransport::Unix)) + .await + .unwrap(); + let forwarder_proc_id = id!(forwarder[0]); + let forwarder_proc = Proc::new( + forwarder_proc_id.clone(), + BoxedMailboxSender::new(router.clone()), + ); + forwarder_proc + .clone() + .serve(forwarder_rx, mailbox::monitored_return_handle()); + router.bind(forwarder_proc_id.into(), forwarder_proc_addr.clone()); + let forwarder = forwarder_proc.attach("forwarder").unwrap(); + let mut active_allocation: Option = None; loop { tokio::select! { @@ -172,15 +227,52 @@ impl RemoteProcessAllocator { Ok(RemoteProcessAllocatorMessage::Allocate { spec, bootstrap_addr, + state_server_info, hosts, heartbeat_interval, }) => { tracing::info!("received allocation request: {:?}", spec); + let parent_state_actor_id = state_server_info.state_actor_id.clone(); + router.bind(parent_state_actor_id.clone().into(), state_server_info.state_proc_addr.clone()); ensure_previous_alloc_stopped(&mut active_allocation).await; match process_allocator.allocate(spec.clone()).await { Ok(alloc) => { + let child_state_server_info = alloc.log_source().await?.server_info(); + let child_state_actor_id = child_state_server_info.state_actor_id.clone(); + if child_state_actor_id == parent_state_actor_id { + // In general, this is unlikely. But if it happens, it will get into an infinite forwarding loop. + anyhow::bail!("found duplicated state actor ids: {}, {}", child_state_actor_id, parent_state_actor_id); + } + + router.bind( + child_state_actor_id.clone().into(), + child_state_server_info.state_proc_addr.clone(), + ); + tracing::info!("receiving log from {} and forwarding to {}", child_state_actor_id, parent_state_actor_id); + + // Spin up the client actor, subscribe to the child process, and forward the log to the state actor. 
+ let log_handler = Box::new(ForwarderLogHandler { + parent_state_actor: ActorRef::attest(parent_state_actor_id), + forwarder: forwarder.clone(), + }); + let params = ClientActorParams { log_handler }; + + // Use UUID as there could be multiple allocations. + let forwarder_client_actor: ActorRef = forwarder_proc + .spawn::(&format!("forwarder_client{}", ShortUuid::generate()), params) + .await? + .bind(); + let child_state_actor_ref: ActorRef = ActorRef::attest(child_state_actor_id); + child_state_actor_ref + .subscribe_logs( + &forwarder, + forwarder_proc_addr.clone(), + forwarder_client_actor.clone(), + ) + .await?; + let cancel_token = CancellationToken::new(); active_allocation = Some(ActiveAllocation { cancel_token: cancel_token.clone(), @@ -437,6 +529,11 @@ struct RemoteProcessAllocHostState { pub trait RemoteProcessAllocInitializer { /// Initializes and returns a list of hosts to be used by this RemoteProcessAlloc. async fn initialize_alloc(&mut self) -> Result, anyhow::Error>; + + async fn initialize_state_actor(&self) -> Result { + // TODO (@lky): this needs to be scheduler specific. Let's implement it for MAST and python initializer. 
+ LogSource::new_with_local_actor().await + } } /// A generalized implementation of an Alloc using one or more hosts running @@ -466,6 +563,8 @@ pub struct RemoteProcessAlloc { bootstrap_addr: ChannelAddr, rx: ChannelRx, + + log_source: LogSource, } impl RemoteProcessAlloc { @@ -491,6 +590,8 @@ impl RemoteProcessAlloc { bootstrap_addr.clone() ); + let log_source = initializer.initialize_state_actor().await?; + let (comm_watcher_tx, comm_watcher_rx) = unbounded_channel(); Ok(Self { @@ -505,6 +606,7 @@ impl RemoteProcessAlloc { hosts_by_offset: HashMap::new(), host_states: HashMap::new(), bootstrap_addr, + log_source, event_queue: VecDeque::new(), comm_watcher_tx, comm_watcher_rx, @@ -633,6 +735,7 @@ impl RemoteProcessAlloc { ))?; tx.post(RemoteProcessAllocatorMessage::Allocate { bootstrap_addr: self.bootstrap_addr.clone(), + state_server_info: self.log_source.server_info().clone(), spec: AllocSpec { shape: host_shape.clone(), constraints: self.spec.constraints.clone(), @@ -981,6 +1084,10 @@ impl Alloc for RemoteProcessAlloc { self.transport.clone() } + async fn log_source(&self) -> Result { + Ok(self.log_source.clone()) + } + async fn stop(&mut self) -> Result<(), AllocatorError> { tracing::info!("stopping alloc"); @@ -1001,7 +1108,9 @@ mod test { use hyperactor::channel::ChannelRx; use hyperactor::clock::ClockKind; use hyperactor::id; + use hyperactor_state::test_utils::log_items; use ndslice::shape; + use tokio::sync::mpsc::Sender; use tokio::sync::oneshot; use super::*; @@ -1012,6 +1121,21 @@ mod test { use crate::alloc::ProcStopReason; use crate::proc_mesh::mesh_agent::MeshAgent; + #[derive(Debug)] + struct MpscLogHandler { + sender: Sender>, + } + + impl LogHandler for MpscLogHandler { + fn handle_log(&self, logs: Vec) -> anyhow::Result<()> { + let sender = self.sender.clone(); + tokio::spawn(async move { + sender.send(logs).await.unwrap(); + }); + Ok(()) + } + } + async fn read_all_created(rx: &mut ChannelRx, alloc_len: usize) { let mut i: usize = 0; while i < 
alloc_len { @@ -1102,6 +1226,7 @@ mod test { let alloc_len = spec.shape.slice().len(); let world_id: WorldId = id!(test_world_id); + let log_source = LogSource::new_with_local_actor().await.unwrap(); let mut alloc = MockAlloc::new(); alloc.expect_world_id().return_const(world_id.clone()); alloc.expect_shape().return_const(spec.shape.clone()); @@ -1110,6 +1235,10 @@ mod test { // final none alloc.expect_next().return_const(None); + alloc + .expect_log_source() + .times(1) + .return_once(move || Ok(log_source)); let mut allocator = MockAllocator::new(); let total_messages = alloc_len * 3 + 1; @@ -1136,6 +1265,10 @@ mod test { tx.send(RemoteProcessAllocatorMessage::Allocate { spec: spec.clone(), bootstrap_addr, + state_server_info: LogSource::new_with_local_actor() + .await + .unwrap() + .server_info(), hosts: vec![], heartbeat_interval: Duration::from_secs(1), }) @@ -1240,6 +1373,7 @@ mod test { let alloc_len = spec.shape.slice().len(); let world_id: WorldId = id!(test_world_id); + let log_source = LogSource::new_with_local_actor().await.unwrap(); let mut alloc = MockAllocWrapper::new_block_next( MockAlloc::new(), // block after all created, all running @@ -1253,6 +1387,11 @@ mod test { alloc.alloc.expect_next().return_const(None); alloc.alloc.expect_stop().times(1).return_once(|| Ok(())); + alloc + .alloc + .expect_log_source() + .times(1) + .return_once(move || Ok(log_source)); let mut allocator = MockAllocator::new(); allocator @@ -1273,6 +1412,10 @@ mod test { tx.send(RemoteProcessAllocatorMessage::Allocate { spec: spec.clone(), bootstrap_addr, + state_server_info: LogSource::new_with_local_actor() + .await + .unwrap() + .server_info(), hosts: vec![], heartbeat_interval: Duration::from_millis(200), }) @@ -1298,6 +1441,156 @@ mod test { handle.await.unwrap().unwrap(); } + #[timed_test::async_timed_test(timeout_secs = 15)] + async fn test_log_streaming() { + hyperactor_telemetry::initialize_logging(ClockKind::default()); + let serve_addr = 
ChannelAddr::any(ChannelTransport::Unix); + let bootstrap_addr = ChannelAddr::any(ChannelTransport::Unix); + let (_, mut rx) = channel::serve(bootstrap_addr.clone()).await.unwrap(); + + let spec = AllocSpec { + shape: shape!(host = 1, gpu = 2), + constraints: Default::default(), + }; + let tx = channel::dial(serve_addr.clone()).unwrap(); + + let alloc_len = spec.shape.slice().len(); + + let world_id: WorldId = id!(test_world_id); + let child_log_source = LogSource::new_with_local_actor().await.unwrap(); + let mut alloc = MockAllocWrapper::new_block_next( + MockAlloc::new(), + // block after all created, all running + alloc_len * 2, + ); + let next_tx = alloc.notify_tx(); + alloc.alloc.expect_world_id().return_const(world_id.clone()); + alloc.alloc.expect_shape().return_const(spec.shape.clone()); + + set_procstate_expectations(&mut alloc.alloc, spec.shape.clone()); + + alloc.alloc.expect_next().return_const(None); + alloc.alloc.expect_stop().times(1).return_once(|| Ok(())); + let child_log_source_clone = child_log_source.clone(); + alloc + .alloc + .expect_log_source() + .times(1) + .return_once(move || Ok(child_log_source_clone)); + + let mut allocator = MockAllocator::new(); + allocator + .expect_allocate() + .times(1) + .return_once(|_| Ok(alloc)); + + let remote_allocator = RemoteProcessAllocator::new(); + let handle = tokio::spawn({ + let remote_allocator = remote_allocator.clone(); + async move { + remote_allocator + .start_with_allocator(serve_addr, allocator) + .await + } + }); + + let parent_log_source = LogSource::new_with_local_actor().await.unwrap(); + tx.send(RemoteProcessAllocatorMessage::Allocate { + spec: spec.clone(), + bootstrap_addr, + state_server_info: parent_log_source.server_info(), + hosts: vec![], + heartbeat_interval: Duration::from_millis(200), + }) + .await + .unwrap(); + + // Allocated + let m = rx.recv().await.unwrap(); + assert_matches!(m, RemoteProcessProcStateMessage::Allocated {world_id, shape} if world_id == world_id && shape == 
spec.shape); + + read_all_created(&mut rx, alloc_len).await; + read_all_running(&mut rx, alloc_len).await; + + // Push some log to child state actor and subscribe to the parent state actor + let router = DialMailboxRouter::new(); + let (client_proc_addr, client_rx) = + channel::serve(ChannelAddr::any(ChannelTransport::Unix)) + .await + .unwrap(); + let client_proc = Proc::new(id!(client[0]), BoxedMailboxSender::new(router.clone())); + client_proc + .clone() + .serve(client_rx, mailbox::monitored_return_handle()); + router.bind(id!(client[0]).into(), client_proc_addr.clone()); + router.bind( + parent_log_source + .server_info() + .state_actor_id + .clone() + .into(), + parent_log_source.server_info().state_proc_addr.clone(), + ); + router.bind( + child_log_source.server_info().state_actor_id.clone().into(), + child_log_source.server_info().state_proc_addr.clone(), + ); + let client = client_proc.attach("client").unwrap(); + + // Spin up the client logging actor and subscribe to the state actor + let (sender, mut receiver) = tokio::sync::mpsc::channel::>(20); + let log_handler = Box::new(MpscLogHandler { sender }); + let params = ClientActorParams { log_handler }; + + let client_logging_actor: ActorRef = client_proc + .spawn::("logging_client", params) + .await + .unwrap() + .bind(); + let parent_state_actor_ref: ActorRef = + ActorRef::attest(parent_log_source.server_info().state_actor_id.clone()); + let child_state_actor_ref: ActorRef = + ActorRef::attest(child_log_source.server_info().state_actor_id.clone()); + + // Listen to the parent + parent_state_actor_ref + .subscribe_logs( + &client, + client_proc_addr.clone(), + client_logging_actor.clone(), + ) + .await + .unwrap(); + + // Write to the child + child_state_actor_ref + .push_logs(&client, log_items(0, 10)) + .await + .unwrap(); + + // Collect received messages with timeout + let fetched_logs = client_proc + .clock() + .timeout(Duration::from_secs(1), receiver.recv()) + .await + .expect("timed out waiting for 
message") + .expect("channel closed unexpectedly"); + + // Verify we received all expected logs + assert_eq!(fetched_logs.len(), 10); + assert_eq!(fetched_logs, log_items(0, 10)); + + // allocation finished. now we stop it. + tx.send(RemoteProcessAllocatorMessage::Stop).await.unwrap(); + // receive all stops + next_tx.send(()).unwrap(); + + read_all_stopped(&mut rx, alloc_len).await; + + remote_allocator.terminate(); + handle.await.unwrap().unwrap(); + } + #[timed_test::async_timed_test(timeout_secs = 15)] async fn test_realloc() { hyperactor_telemetry::initialize_logging(ClockKind::default()); @@ -1314,6 +1607,8 @@ mod test { let alloc_len = spec.shape.slice().len(); let world_id: WorldId = id!(test_world_id); + let log_source1 = LogSource::new_with_local_actor().await.unwrap(); + let log_source2 = LogSource::new_with_local_actor().await.unwrap(); let mut alloc1 = MockAllocWrapper::new_block_next( MockAlloc::new(), // block after all created, all running @@ -1329,6 +1624,12 @@ mod test { set_procstate_expectations(&mut alloc1.alloc, spec.shape.clone()); alloc1.alloc.expect_next().return_const(None); alloc1.alloc.expect_stop().times(1).return_once(|| Ok(())); + alloc1 + .alloc + .expect_log_source() + .times(1) + .return_once(move || Ok(log_source1)); + // second allocation let mut alloc2 = MockAllocWrapper::new_block_next( MockAlloc::new(), @@ -1344,6 +1645,11 @@ mod test { set_procstate_expectations(&mut alloc2.alloc, spec.shape.clone()); alloc2.alloc.expect_next().return_const(None); alloc2.alloc.expect_stop().times(1).return_once(|| Ok(())); + alloc2 + .alloc + .expect_log_source() + .times(1) + .return_once(move || Ok(log_source2)); let mut allocator = MockAllocator::new(); allocator @@ -1369,6 +1675,10 @@ mod test { tx.send(RemoteProcessAllocatorMessage::Allocate { spec: spec.clone(), bootstrap_addr: bootstrap_addr.clone(), + state_server_info: LogSource::new_with_local_actor() + .await + .unwrap() + .server_info(), hosts: vec![], heartbeat_interval: 
Duration::from_millis(200), }) @@ -1386,6 +1696,10 @@ mod test { tx.send(RemoteProcessAllocatorMessage::Allocate { spec: spec.clone(), bootstrap_addr, + state_server_info: LogSource::new_with_local_actor() + .await + .unwrap() + .server_info(), hosts: vec![], heartbeat_interval: Duration::from_millis(200), }) @@ -1437,6 +1751,7 @@ mod test { let alloc_len = spec.shape.slice().len(); let world_id: WorldId = id!(test_world_id); + let log_source = LogSource::new_with_local_actor().await.unwrap(); let mut alloc = MockAllocWrapper::new_block_next( MockAlloc::new(), // block after all created, all running @@ -1445,6 +1760,11 @@ mod test { let next_tx = alloc.notify_tx(); alloc.alloc.expect_world_id().return_const(world_id.clone()); alloc.alloc.expect_shape().return_const(spec.shape.clone()); + alloc + .alloc + .expect_log_source() + .times(1) + .return_once(move || Ok(log_source)); set_procstate_expectations(&mut alloc.alloc, spec.shape.clone()); @@ -1476,6 +1796,10 @@ mod test { tx.send(RemoteProcessAllocatorMessage::Allocate { spec: spec.clone(), bootstrap_addr, + state_server_info: LogSource::new_with_local_actor() + .await + .unwrap() + .server_info(), hosts: vec![], heartbeat_interval: Duration::from_millis(200), }) @@ -1517,6 +1841,7 @@ mod test { let tx = channel::dial(serve_addr.clone()).unwrap(); let test_world_id: WorldId = id!(test_world_id); + let log_source = LogSource::new_with_local_actor().await.unwrap(); let mut alloc = MockAllocWrapper::new_block_next( MockAlloc::new(), // block after the failure update @@ -1539,6 +1864,11 @@ mod test { alloc.alloc.expect_next().times(1).return_const(None); alloc.alloc.expect_stop().times(1).return_once(|| Ok(())); + alloc + .alloc + .expect_log_source() + .times(1) + .return_once(move || Ok(log_source)); let mut allocator = MockAllocator::new(); allocator @@ -1559,6 +1889,10 @@ mod test { tx.send(RemoteProcessAllocatorMessage::Allocate { spec: spec.clone(), bootstrap_addr, + state_server_info: 
LogSource::new_with_local_actor() + .await + .unwrap() + .server_info(), hosts: vec![], heartbeat_interval: Duration::from_secs(60), }) @@ -1650,6 +1984,10 @@ mod test_alloc { }, ]) }); + let log_source = LogSource::new_with_local_actor().await.unwrap(); + initializer + .expect_initialize_state_actor() + .return_once(move || Ok(log_source)); let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, heartbeat, initializer) .await @@ -1775,6 +2113,10 @@ mod test_alloc { }, ]) }); + let log_source = LogSource::new_with_local_actor().await.unwrap(); + initializer + .expect_initialize_state_actor() + .return_once(move || Ok(log_source)); let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, heartbeat, initializer) .await @@ -1894,6 +2236,10 @@ mod test_alloc { }, ]) }); + let log_source = LogSource::new_with_local_actor().await.unwrap(); + initializer + .expect_initialize_state_actor() + .return_once(move || Ok(log_source)); let mut alloc = RemoteProcessAlloc::new(spec.clone(), world_id, transport, 0, heartbeat, initializer) .await diff --git a/hyperactor_mesh/src/log_source.rs b/hyperactor_mesh/src/log_source.rs index 4023e7e9..17935b11 100644 --- a/hyperactor_mesh/src/log_source.rs +++ b/hyperactor_mesh/src/log_source.rs @@ -22,6 +22,8 @@ use hyperactor::mailbox::DialMailboxRouter; use hyperactor::mailbox::MailboxServer; use hyperactor::proc::Proc; use hyperactor_state::state_actor::StateActor; +use serde::Deserialize; +use serde::Serialize; use crate::shortuuid::ShortUuid; @@ -36,7 +38,7 @@ pub struct LogSource { state_actor: ActorRef, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct StateServerInfo { pub state_proc_addr: ChannelAddr, pub state_actor_id: ActorId, diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index 5ac4ad79..95f6a269 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -56,6 +56,7 @@ use 
crate::alloc::ProcState; use crate::alloc::ProcStopReason; use crate::assign::Ranks; use crate::comm::CommActorMode; +use crate::log_source::LogSource; use crate::log_source::StateServerInfo; use crate::proc_mesh::mesh_agent::MeshAgent; use crate::proc_mesh::mesh_agent::MeshAgentMessageClient; @@ -91,6 +92,9 @@ pub struct ProcMesh { client: Mailbox, comm_actors: Vec>, world_id: WorldId, + // Held solely to own the lifecycle of the state actor used for log streaming. + // TODO: implement `stop` so that the proc mesh can stop the state actor + _log_source: LogSource, } struct EventState { @@ -303,10 +307,11 @@ impl ProcMesh { } // Get a reference to the state actor for streaming logs. + let log_source = alloc.log_source().await?; let StateServerInfo { state_proc_addr, state_actor_id, - } = alloc.log_source().await?.server_info(); + } = log_source.server_info(); router.bind(state_actor_id.clone().into(), state_proc_addr.clone()); let log_handler = Box::new(hyperactor_state::client::StdlogHandler {}); @@ -357,6 +362,7 @@ impl ProcMesh { client, comm_actors, world_id, + _log_source: log_source, }) } diff --git a/hyperactor_state/src/state_actor.rs b/hyperactor_state/src/state_actor.rs index 606cf36b..34e3a7e2 100644 --- a/hyperactor_state/src/state_actor.rs +++ b/hyperactor_state/src/state_actor.rs @@ -76,10 +76,15 @@ impl StateMessageHandler for StateActor { async fn subscribe_logs( &mut self, - _cx: &Context, + cx: &Context, addr: ChannelAddr, client_actor_ref: ActorRef, ) -> Result<(), anyhow::Error> { + tracing::info!( + "StateActor {} gets a new subscriber: {}", + cx.self_id(), + client_actor_ref + ); self.subscribers .insert(client_actor_ref, create_remote_client(addr).await?); Ok(()) diff --git a/monarch_hyperactor/src/alloc.rs b/monarch_hyperactor/src/alloc.rs index b95b229c..654783e7 100644 --- a/monarch_hyperactor/src/alloc.rs +++ b/monarch_hyperactor/src/alloc.rs @@ -28,6 +28,7 @@ use hyperactor_mesh::alloc::ProcessAllocator; use 
hyperactor_mesh::alloc::remoteprocess::RemoteProcessAlloc; use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocHost; use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocInitializer; +use hyperactor_mesh::log_source::LogSource; use hyperactor_mesh::shape::Shape; use ndslice::Slice; use pyo3::exceptions::PyRuntimeError; @@ -99,6 +100,10 @@ impl Alloc for PyAllocWrapper { self.inner.transport() } + async fn log_source(&self) -> Result { + self.inner.log_source().await + } + async fn stop(&mut self) -> Result<(), AllocatorError> { self.inner.stop().await } diff --git a/monarch_hyperactor/src/bin/process_allocator/common.rs b/monarch_hyperactor/src/bin/process_allocator/common.rs index f4839e0a..d2325951 100644 --- a/monarch_hyperactor/src/bin/process_allocator/common.rs +++ b/monarch_hyperactor/src/bin/process_allocator/common.rs @@ -64,6 +64,7 @@ mod tests { use hyperactor_mesh::alloc; use hyperactor_mesh::alloc::Alloc; use hyperactor_mesh::alloc::remoteprocess; + use hyperactor_mesh::log_source::LogSource; use ndslice::shape; use super::*; @@ -116,6 +117,10 @@ mod tests { id: serve_address.to_string(), }]) }); + let log_source = LogSource::new_with_local_actor().await.unwrap(); + initializer + .expect_initialize_state_actor() + .return_once(move || Ok(log_source)); let heartbeat = std::time::Duration::from_millis(100); let world_id = WorldId("__unused__".to_string());