From 18995ce752e8251a2507bb43c7e5bb9b95236435 Mon Sep 17 00:00:00 2001 From: Thomas Wang Date: Thu, 10 Jul 2025 09:53:09 -0700 Subject: [PATCH 1/3] No more operational messages Summary: Operational messages were used in order to signal to the simulator that it should perform certain actions like growing or shrinking the mesh. This was needed since the python and rust were running in separate processes, and messages were needed to communicate between the two, but now everything is on the same process so we can do this in memory. Differential Revision: D77941643 --- hyperactor/src/simnet.rs | 166 ++---------------- monarch_extension/src/simulator_client.rs | 150 ++++++++-------- monarch_simulator/Cargo.toml | 7 +- monarch_simulator/src/bootstrap.rs | 66 ------- monarch_simulator/src/main.rs | 54 ------ monarch_simulator/src/simulator.rs | 10 +- .../monarch_extension/simulator_client.pyi | 13 +- python/monarch/sim_mesh.py | 17 +- 8 files changed, 99 insertions(+), 384 deletions(-) delete mode 100644 monarch_simulator/src/main.rs diff --git a/hyperactor/src/simnet.rs b/hyperactor/src/simnet.rs index 0a87a14f..9ae27d2d 100644 --- a/hyperactor/src/simnet.rs +++ b/hyperactor/src/simnet.rs @@ -33,12 +33,9 @@ use serde::Serialize; use serde::Serializer; use serde_with::serde_as; use tokio::sync::Mutex; -use tokio::sync::SetError; use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::UnboundedReceiver; use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::mpsc::error::SendError; use tokio::task::JoinError; use tokio::task::JoinHandle; use tokio::time::interval; @@ -49,7 +46,6 @@ use crate::ActorId; use crate::Mailbox; use crate::Named; use crate::OncePortRef; -use crate::WorldId; use crate::channel; use crate::channel::ChannelAddr; use crate::channel::Rx; @@ -313,18 +309,6 @@ pub enum SimNetError { #[error("proxy not available: {0}")] ProxyNotAvailable(String), - /// Unable to send message to the simulator. - #[error(transparent)] - OperationalMessageSendError(#[from] SendError), - - /// Setting the operational message sender which is already set. - #[error(transparent)] - OperationalMessageSenderSetError(#[from] SetError>), - - /// Missing OperationalMessageReceiver. - #[error("missing operational message receiver")] - MissingOperationalMessageReceiver, - /// Cannot deliver the message because destination address is missing. #[error("missing destination address")] MissingDestinationAddress, @@ -361,8 +345,6 @@ pub struct SimNetHandle { /// Handle to a running proxy server that forwards external messages /// into the simnet. proxy_handle: ProxyHandle, - /// A sender to forward simulator operational messages. - operational_message_tx: UnboundedSender, /// A receiver to receive simulator operational messages. /// The receiver can be moved out of the simnet handle. training_script_state_tx: tokio::sync::watch::Sender, @@ -489,82 +471,6 @@ impl SimNetHandle { pub(crate) type Topology = DashMap; -/// The message to spawn a simulated mesh. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct SpawnMesh { - /// The system address. - pub system_addr: ChannelAddr, - /// The controller actor ID. - pub controller_actor_id: ActorId, - /// The worker world. - pub worker_world: WorldId, -} - -impl SpawnMesh { - /// Creates a new SpawnMesh. - pub fn new( - system_addr: ChannelAddr, - controller_actor_id: ActorId, - worker_world: WorldId, - ) -> Self { - Self { - system_addr, - controller_actor_id, - worker_world, - } - } -} - -/// An OperationalMessage is a message to control the simulator to do tasks such as -/// spawning or killing actors. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)] -pub enum OperationalMessage { - /// Kill the world with given world_id. - KillWorld(String), - /// Spawn actors in a mesh. - SpawnMesh(SpawnMesh), - /// Update training script state. - SetTrainingScriptState(TrainingScriptState), -} - -/// Message Event that can be sent to the simulator. -#[derive(Debug)] -pub struct SimOperation { - /// Sender to send OperationalMessage to the simulator. - operational_message_tx: UnboundedSender, - operational_message: OperationalMessage, -} - -impl SimOperation { - /// Creates a new SimOperation. - pub fn new( - operational_message_tx: UnboundedSender, - operational_message: OperationalMessage, - ) -> Self { - Self { - operational_message_tx, - operational_message, - } - } -} - -#[async_trait] -impl Event for SimOperation { - async fn handle(&self) -> Result<(), SimNetError> { - self.operational_message_tx - .send(self.operational_message.clone())?; - Ok(()) - } - - fn duration_ms(&self) -> u64 { - 0 - } - - fn summary(&self) -> String { - format!("SimOperation: {:?}", self.operational_message) - } -} - /// A ProxyMessage is a message that SimNet proxy receives. /// The message may requests the SimNet to send the payload in the message field from /// src to dst if addr field exists. @@ -658,7 +564,6 @@ impl ProxyHandle { proxy_addr: ChannelAddr, event_tx: UnboundedSender<(Box, bool, Option)>, pending_event_count: Arc, - operational_message_tx: UnboundedSender, ) -> anyhow::Result { let (addr, mut rx) = channel::serve::(proxy_addr).await?; tracing::info!("SimNet serving external traffic on {}", &addr); @@ -672,26 +577,18 @@ impl ProxyHandle { #[allow(clippy::disallowed_methods)] if let Ok(Ok(msg)) = timeout(Duration::from_millis(100), rx.recv()).await { let proxy_message: ProxyMessage = msg.deserialized().unwrap(); - let event: Box = match proxy_message.dest_addr { - Some(dest_addr) => Box::new(MessageDeliveryEvent::new( + if let Some(dest_addr) = proxy_message.dest_addr { + let event = Box::new(MessageDeliveryEvent::new( proxy_message.sender_addr, dest_addr, proxy_message.data, - )), - None => { - let operational_message: OperationalMessage = - proxy_message.data.deserialized().unwrap(); - Box::new(SimOperation::new( - operational_message_tx.clone(), - operational_message, - )) + )); + if let Err(e) = event_tx.send((event, true, None)) { + tracing::error!("error sending message to simnet: {:?}", e); + } else { + pending_event_count + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); } - }; - - if let Err(e) = event_tx.send((event, true, None)) { - tracing::error!("error sending message to simnet: {:?}", e); - } else { - pending_event_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); } } if stop_signal.load(Ordering::SeqCst) { @@ -729,7 +626,7 @@ pub fn start( private_addr: ChannelAddr, proxy_addr: ChannelAddr, max_duration_ms: u64, -) -> anyhow::Result> { +) -> anyhow::Result<()> { // Construct a topology with one node: the default A. let address_book: DashSet = DashSet::new(); address_book.insert(private_addr.clone()); @@ -775,14 +672,11 @@ pub fn start( .await }) })); - let (operational_message_tx, operational_message_rx) = - mpsc::unbounded_channel::(); let proxy_handle = block_on(ProxyHandle::start( proxy_addr, event_tx.clone(), pending_event_count.clone(), - operational_message_tx.clone(), )) .map_err(|err| SimNetError::ProxyNotAvailable(err.to_string()))?; @@ -792,12 +686,11 @@ pub fn start( config, pending_event_count, proxy_handle, - operational_message_tx, training_script_state_tx, stop_signal, }); - Ok(operational_message_rx) + Ok(()) } impl SimNet { @@ -1488,45 +1381,6 @@ edges: assert_eq!(records.unwrap().first().unwrap(), &expected_record); } - #[cfg(target_os = "linux")] - #[tokio::test] - async fn test_simnet_receive_operational_message() { - use tokio::sync::oneshot; - - use crate::PortId; - use crate::channel::Tx; - - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let mut operational_message_rx = start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); - let tx = crate::channel::dial(proxy_addr.clone()).unwrap(); - let port_id = PortId(id!(test[0].actor0), 0); - let spawn_mesh = SpawnMesh { - system_addr: "unix!@system".parse().unwrap(), - controller_actor_id: id!(controller_world[0].actor), - worker_world: id!(worker_world), - }; - let operational_message = OperationalMessage::SpawnMesh(spawn_mesh.clone()); - let serialized_operational_message = Serialized::serialize(&operational_message).unwrap(); - let proxy_message = ProxyMessage::new(None, None, serialized_operational_message); - let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap(); - let external_message = MessageEnvelope::new_unknown(port_id, serialized_proxy_message); - - // Send the message to the simnet. - tx.try_post(external_message, oneshot::channel().0).unwrap(); - // flush doesn't work here because tx.send() delivers the message through real network. - // We have to wait for the message to enter simnet. - RealClock.sleep(Duration::from_millis(1000)).await; - let received_operational_message = operational_message_rx.recv().await.unwrap(); - - // Check the received message. - assert_eq!(received_operational_message, operational_message); - } - #[tokio::test] async fn test_sim_sleep() { start( diff --git a/monarch_extension/src/simulator_client.rs b/monarch_extension/src/simulator_client.rs index 85657d40..eb93c1c5 100644 --- a/monarch_extension/src/simulator_client.rs +++ b/monarch_extension/src/simulator_client.rs @@ -8,24 +8,21 @@ #![cfg(feature = "tensor_engine")] +use std::sync::Arc; + use anyhow::anyhow; -use hyperactor::PortId; -use hyperactor::attrs::Attrs; +use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::Tx; -use hyperactor::channel::dial; -use hyperactor::data::Serialized; -use hyperactor::id; -use hyperactor::mailbox::MessageEnvelope; -use hyperactor::simnet::OperationalMessage; -use hyperactor::simnet::ProxyMessage; -use hyperactor::simnet::SpawnMesh; +use hyperactor::channel::ChannelTransport; +use hyperactor::simnet; use hyperactor::simnet::TrainingScriptState; +use hyperactor::simnet::simnet_handle; use monarch_hyperactor::runtime::signal_safe_block_on; -use monarch_simulator_lib::bootstrap::bootstrap; +use monarch_simulator_lib::simulator::TensorEngineSimulator; use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use tokio::sync::Mutex; /// A wrapper around [ndslice::Slice] to expose it to python. /// It is a compact representation of indices into the flat @@ -39,102 +36,99 @@ use pyo3::prelude::*; )] #[derive(Clone)] pub(crate) struct SimulatorClient { - proxy_addr: ChannelAddr, -} - -fn wrap_operational_message(operational_message: OperationalMessage) -> MessageEnvelope { - let serialized_operational_message = Serialized::serialize(&operational_message).unwrap(); - let proxy_message = ProxyMessage::new(None, None, serialized_operational_message); - let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap(); - let sender_id = id!(simulator_client[0].sender_actor); - // a dummy port ID. We are delivering message with low level mailbox. - // The port ID is not used. - let port_id = PortId(id!(simulator[0].actor), 0); - MessageEnvelope::new(sender_id, port_id, serialized_proxy_message, Attrs::new()) -} - -#[pyfunction] -fn bootstrap_simulator_backend( - py: Python, - system_addr: String, - proxy_addr: String, - world_size: i32, -) -> PyResult<()> { - signal_safe_block_on(py, async move { - match bootstrap( - system_addr.parse().unwrap(), - proxy_addr.parse().unwrap(), - world_size as usize, - ) - .await - { - Ok(_) => Ok(()), - Err(err) => Err(PyRuntimeError::new_err(err.to_string())), - } - })? + inner: Arc>, + world_size: usize, } -fn set_training_script_state(state: TrainingScriptState, proxy_addr: ChannelAddr) -> PyResult<()> { - let operational_message = OperationalMessage::SetTrainingScriptState(state); - let external_message = wrap_operational_message(operational_message); - let tx = dial(proxy_addr).map_err(|err| anyhow!(err))?; - tx.post(external_message); +fn set_training_script_state(state: TrainingScriptState) -> PyResult<()> { + simnet_handle() + .map_err(|e| anyhow!(e))? + .set_training_script_state(state); Ok(()) } #[pymethods] impl SimulatorClient { #[new] - fn new(proxy_addr: &str) -> PyResult { - Ok(Self { - proxy_addr: proxy_addr + fn new(py: Python, system_addr: String, world_size: i32) -> PyResult { + signal_safe_block_on(py, async move { + let system_addr = system_addr .parse::() - .map_err(|err| PyValueError::new_err(err.to_string()))?, - }) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let ChannelAddr::Sim(system_sim_addr) = &system_addr else { + return Err(PyValueError::new_err(format!( + "bootstrap address should be a sim address: {}", + system_addr + ))); + }; + + simnet::start( + ChannelAddr::any(ChannelTransport::Unix), + system_sim_addr.proxy().clone(), + 1000, + ) + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + + Ok(Self { + inner: Arc::new(Mutex::new( + TensorEngineSimulator::new(system_addr) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, + )), + world_size: world_size as usize, + }) + })? } - fn kill_world(&self, world_name: &str) -> PyResult<()> { - let operational_message = OperationalMessage::KillWorld(world_name.to_string()); - let external_message = wrap_operational_message(operational_message); - let tx = dial(self.proxy_addr.clone()).map_err(|err| anyhow!(err))?; - tx.post(external_message); - Ok(()) + fn kill_world(&self, py: Python, world_name: &str) -> PyResult<()> { + let simulator = self.inner.clone(); + let world_name = world_name.to_string(); + + signal_safe_block_on(py, async move { + simulator + .lock() + .await + .kill_world(&world_name) + .map_err(|err| anyhow!(err))?; + Ok(()) + })? } fn spawn_mesh( &self, + py: Python, system_addr: &str, controller_actor_id: &str, worker_world: &str, ) -> PyResult<()> { - let spawn_mesh = SpawnMesh::new( - system_addr.parse().unwrap(), - controller_actor_id.parse().unwrap(), - worker_world.parse().unwrap(), - ); - let operational_message = OperationalMessage::SpawnMesh(spawn_mesh); - let external_message = wrap_operational_message(operational_message); - let tx = dial(self.proxy_addr.clone()).map_err(|err| anyhow!(err))?; - tx.post(external_message); - Ok(()) + let simulator = self.inner.clone(); + let world_size = self.world_size; + let system_addr = system_addr.parse::().unwrap(); + let worker_world = worker_world.parse::().unwrap(); + let controller_actor_id = controller_actor_id.parse().unwrap(); + + signal_safe_block_on(py, async move { + simulator + .lock() + .await + .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) + .await + .map_err(|err| anyhow!(err))?; + Ok(()) + })? } fn set_training_script_state_running(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Running, self.proxy_addr.clone()) + set_training_script_state(TrainingScriptState::Running) } fn set_training_script_state_waiting(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Waiting, self.proxy_addr.clone()) + set_training_script_state(TrainingScriptState::Waiting) } } pub(crate) fn register_python_bindings(simulator_client_mod: &Bound<'_, PyModule>) -> PyResult<()> { simulator_client_mod.add_class::()?; - let f = wrap_pyfunction!(bootstrap_simulator_backend, simulator_client_mod)?; - f.setattr( - "__module__", - "monarch._rust_bindings.monarch_extension.simulator_client", - )?; - simulator_client_mod.add_function(f)?; Ok(()) } diff --git a/monarch_simulator/Cargo.toml b/monarch_simulator/Cargo.toml index 863b5384..e28372ed 100644 --- a/monarch_simulator/Cargo.toml +++ b/monarch_simulator/Cargo.toml @@ -1,4 +1,4 @@ -# @generated by autocargo from //monarch/monarch_simulator:[monarch_simulator,monarch_simulator_lib] +# @generated by autocargo from //monarch/monarch_simulator:monarch_simulator_lib [package] name = "monarch_simulator_lib" @@ -7,14 +7,9 @@ authors = ["Meta"] edition = "2021" license = "BSD-3-Clause" -[[bin]] -name = "monarch_simulator" -path = "src/main.rs" - [dependencies] anyhow = "1.0.98" async-trait = "0.1.86" -clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] } controller = { version = "0.0.0", path = "../controller" } dashmap = { version = "5.5.3", features = ["rayon", "serde"] } futures = { version = "0.3.30", features = ["async-await", "compat"] } diff --git a/monarch_simulator/src/bootstrap.rs b/monarch_simulator/src/bootstrap.rs index e946fd63..6ba3000d 100644 --- a/monarch_simulator/src/bootstrap.rs +++ b/monarch_simulator/src/bootstrap.rs @@ -20,23 +20,15 @@ use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; -use hyperactor::simnet; -use hyperactor::simnet::OperationalMessage; -use hyperactor::simnet::SpawnMesh; -use hyperactor::simnet::simnet_handle; use hyperactor_multiprocess::System; use hyperactor_multiprocess::proc_actor::ProcActor; use hyperactor_multiprocess::proc_actor::spawn; use hyperactor_multiprocess::system::ServerHandle; use hyperactor_multiprocess::system_actor::ProcLifecycleMode; use monarch_messages::worker::Factory; -use tokio::sync::Mutex; -use tokio::sync::mpsc::UnboundedReceiver; -use tokio::task::JoinHandle; use torch_sys::Layout; use torch_sys::ScalarType; -use crate::simulator::Simulator; use crate::worker::Fabric; use crate::worker::MockWorkerParams; use crate::worker::WorkerActor; @@ -184,61 +176,3 @@ pub async fn spawn_sim_worker( .await?; Ok(bootstrap.proc_actor) } - -/// Bootstrap the simulation. Spawns the system, controllers, and workers. -/// Args: -/// system_addr: The address of the system actor. -pub async fn bootstrap( - system_addr: ChannelAddr, - proxy_addr: ChannelAddr, - world_size: usize, -) -> Result> { - // TODO: enable supervision events. - let mut operational_message_rx = simnet::start(system_addr.clone(), proxy_addr, 1000)?; - let simulator = Arc::new(Mutex::new(Simulator::new(system_addr).await?)); - let operational_listener_handle = { - let simulator = simulator.clone(); - tokio::spawn(async move { - handle_operational_message(&mut operational_message_rx, simulator, world_size).await - }) - }; - - Ok(operational_listener_handle) -} - -async fn handle_operational_message( - operational_message_rx: &mut UnboundedReceiver, - simulator: Arc>, - world_size: usize, -) { - while let Some(msg) = operational_message_rx.recv().await { - tracing::info!("received operational message: {:?}", msg); - match msg { - OperationalMessage::SpawnMesh(SpawnMesh { - system_addr, - controller_actor_id, - worker_world, - }) => { - if let Err(e) = simulator - .lock() - .await - .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) - .await - { - tracing::error!("failed to spawn mesh: {:?}", e); - } - } - OperationalMessage::KillWorld(world_id) => { - if let Err(e) = simulator.lock().await.kill_world(&world_id) { - tracing::error!("failed to kill world: {:?}", e); - } - } - OperationalMessage::SetTrainingScriptState(state) => match simnet_handle() { - Ok(handle) => handle.set_training_script_state(state), - Err(e) => { - tracing::error!("failed to set training script state: {:?}", e); - } - }, - } - } -} diff --git a/monarch_simulator/src/main.rs b/monarch_simulator/src/main.rs deleted file mode 100644 index 12169d91..00000000 --- a/monarch_simulator/src/main.rs +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -//! A binary to launch the simulated Monarch controller along with necessary environment. -use std::process::ExitCode; - -use anyhow::Result; -use clap::Parser; -use hyperactor::channel::ChannelAddr; -use monarch_simulator_lib::bootstrap::bootstrap; - -#[derive(Debug, Parser)] -struct Args { - #[arg(short, long)] - system_addr: ChannelAddr, - #[arg(short, long)] - proxy_addr: ChannelAddr, -} - -const TITLE: &str = r#" -****************************************************** -* * -* ____ ___ __ __ _ _ _ _ _____ ___ ____ * -*/ ___|_ _| \/ | | | | | / \|_ _/ _ \| _ \ * -*\___ \| || |\/| | | | | | / _ \ | || | | | |_) |* -* ___) | || | | | |_| | |___ / ___ \| || |_| | _ < * -*|____/___|_| |_|\___/|_____/_/ \_\_| \___/|_| \_\* -* * -****************************************************** -"#; - -#[tokio::main] -async fn main() -> Result { - eprintln!("{}", TITLE); - hyperactor::initialize_with_current_runtime(); - let args = Args::parse(); - - let system_addr = args.system_addr.clone(); - let proxy_addr = args.proxy_addr.clone(); - tracing::info!("starting Monarch simulation"); - - let operational_listener_handle = bootstrap(system_addr, proxy_addr, 1).await?; - - operational_listener_handle - .await - .expect("simulator exited unexpectedly"); - - Ok(ExitCode::SUCCESS) -} diff --git a/monarch_simulator/src/simulator.rs b/monarch_simulator/src/simulator.rs index 10fa9fd7..97c02b77 100644 --- a/monarch_simulator/src/simulator.rs +++ b/monarch_simulator/src/simulator.rs @@ -26,13 +26,13 @@ use crate::bootstrap::spawn_system; /// The simulator manages all of the meshes and the system handle. #[derive(Debug)] -pub struct Simulator { +pub struct TensorEngineSimulator { /// A map from world name to actor handles in that world. worlds: HashMap>>, system_handle: ServerHandle, } -impl Simulator { +impl TensorEngineSimulator { pub async fn new(system_addr: ChannelAddr) -> Result { Ok(Self { worlds: HashMap::new(), @@ -92,7 +92,7 @@ impl Simulator { /// IntoFuture allows users to await the handle. The future resolves when /// the simulator itself has all of the actor handles stopped. -impl IntoFuture for Simulator { +impl IntoFuture for TensorEngineSimulator { type Output = (); type IntoFuture = BoxFuture<'static, Self::Output>; @@ -145,7 +145,9 @@ mod tests { let system_addr = format!("sim!unix!@system,{}", &proxy) .parse::() .unwrap(); - let mut simulator = super::Simulator::new(system_addr.clone()).await.unwrap(); + let mut simulator = super::TensorEngineSimulator::new(system_addr.clone()) + .await + .unwrap(); let mut controller_actor_ids = vec![]; let mut worker_actor_ids = vec![]; let n_meshes = 2; diff --git a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi index 9cea9d93..6b06c274 100644 --- a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi @@ -13,10 +13,11 @@ class SimulatorClient: It is a client to communicate with the simulator service. Arguments: - - `proxy_addr`: Address of the simulator's proxy server. + - `system_addr`: Address of the system. + - `world_size`: Number of workers in a given mesh. """ - def __init__(self, proxy_addr: str) -> None: ... + def __init__(self, system_addr: str, world_size: int) -> None: ... def kill_world(self, world_name: str) -> None: """ Kill the world with the given name. @@ -51,11 +52,3 @@ class SimulatorClient: backend to resolve a future """ ... - -def bootstrap_simulator_backend( - system_addr: str, proxy_addr: str, world_size: int -) -> None: - """ - Bootstrap the simulator backend on the current process - """ - ... diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index b91361bf..8c304efd 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -31,7 +31,6 @@ ) from monarch._rust_bindings.monarch_extension.simulator_client import ( # @manual=//monarch/monarch_extension:monarch_extension - bootstrap_simulator_backend, SimulatorClient, ) @@ -76,7 +75,6 @@ def sim_mesh( bootstrap: Bootstrap = Bootstrap( n_meshes, mesh_world_state, - proxy_addr=proxy_addr, world_size=hosts * gpus_per_host, ) @@ -181,7 +179,6 @@ def __init__( self, num_meshes: int, mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]], - proxy_addr: Optional[str] = None, world_size: int = 1, ) -> None: """ @@ -199,17 +196,15 @@ def __init__( self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state - proxy_addr = proxy_addr or f"unix!@{_random_id()}-proxy" + proxy_addr = f"unix!@{_random_id()}-proxy" self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}" - client_proxy_addr = f"unix!@{_random_id()}-proxy" - self.client_listen_addr: str = f"sim!unix!@client,{client_proxy_addr}" - self.client_bootstrap_addr: str = ( + self.client_listen_addr = f"sim!unix!@client,{client_proxy_addr}" + self.client_bootstrap_addr = ( f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}" ) - bootstrap_simulator_backend(self.bootstrap_addr, proxy_addr, world_size) - self._simulator_client = SimulatorClient(proxy_addr) + self._simulator_client = SimulatorClient(self.bootstrap_addr, world_size) for i in range(num_meshes): mesh_name: str = f"mesh_{i}" controller_world: str = f"{mesh_name}_controller" @@ -235,7 +230,9 @@ def spawn_mesh(self, mesh_world: MeshWorld) -> None: worker_world, controller_id = mesh_world controller_world = controller_id.world_name self._simulator_client.spawn_mesh( - self.bootstrap_addr, f"{controller_world}[0].root", worker_world + self.bootstrap_addr, + f"{controller_world}[0].root", + worker_world, ) From e5e9fee886ccc229f84708870b815a6b9c393cc9 Mon Sep 17 00:00:00 2001 From: Thomas Wang Date: Thu, 10 Jul 2025 09:53:09 -0700 Subject: [PATCH 2/3] Remove proxy Summary: The proxy was previously used since the simulator and the python were run in separate processes. Since both are now run in the same process we no longer need a proxy Differential Revision: D77941641 --- controller/src/lib.rs | 16 +- hyperactor/src/channel.rs | 11 +- hyperactor/src/channel/sim.rs | 256 +++++---------- hyperactor/src/clock.rs | 7 +- hyperactor/src/mailbox.rs | 18 +- hyperactor/src/simnet.rs | 340 ++------------------ hyperactor_multiprocess/src/ping_pong.rs | 36 +-- hyperactor_multiprocess/src/system_actor.rs | 23 +- monarch_extension/src/simulation_tools.rs | 10 +- monarch_extension/src/simulator_client.rs | 28 +- monarch_hyperactor/src/channel.rs | 6 +- monarch_simulator/src/bootstrap.rs | 21 +- monarch_simulator/src/simulator.rs | 13 +- monarch_simulator/src/worker.rs | 7 +- python/monarch/sim_mesh.py | 15 +- python/tests/test_sim_backend.py | 5 +- 16 files changed, 156 insertions(+), 656 deletions(-) diff --git a/controller/src/lib.rs b/controller/src/lib.rs index a2d1f54d..36755094 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -1559,20 +1559,13 @@ mod tests { #[tokio::test] async fn test_sim_supervision_failure() { // Start system actor. - let system_addr = ChannelAddr::any(ChannelTransport::Unix); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + simnet::start(); simnet::simnet_handle() .unwrap() .set_training_script_state(simnet::TrainingScriptState::Waiting); let system_sim_addr = - ChannelAddr::Sim(SimAddr::new(system_addr, proxy_addr.clone()).unwrap()); + ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); // Set very long supervision_update_timeout let server_handle = System::serve( system_sim_addr.clone(), @@ -1588,9 +1581,8 @@ mod tests { // Bootstrap the controller let controller_id = id!(controller[0].root); let proc_id = id!(world[0]); - let controller_proc_listen_addr = ChannelAddr::Sim( - SimAddr::new(ChannelAddr::any(ChannelTransport::Unix), proxy_addr).unwrap(), - ); + let controller_proc_listen_addr = + ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); let (_, actor_ref) = ControllerActor::bootstrap( controller_id.clone(), diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index b98ec1f7..cdf9e662 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -236,7 +236,7 @@ pub enum ChannelTransport { Local, /// Sim is a simulated channel for testing. - Sim(/*proxy address:*/ ChannelAddr), + Sim(/*simulated transport:*/ Box), /// Transport over unix domain socket. Unix, @@ -368,7 +368,7 @@ impl ChannelAddr { Self::MetaTls(hostname, 0) } ChannelTransport::Local => Self::Local(0), - ChannelTransport::Sim(proxy) => sim::any(proxy), + ChannelTransport::Sim(transport) => sim::any(*transport), // This works because the file will be deleted but we know we have a unique file by this point. ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()), } @@ -380,7 +380,7 @@ impl ChannelAddr { Self::Tcp(_) => ChannelTransport::Tcp, Self::MetaTls(_, _) => ChannelTransport::MetaTls, Self::Local(_) => ChannelTransport::Local, - Self::Sim(addr) => ChannelTransport::Sim(addr.proxy().clone()), + Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())), Self::Unix(_) => ChannelTransport::Unix, } } @@ -637,10 +637,9 @@ mod tests { } for (raw, parsed) in cases_ok.iter().zip(src_ok.clone()).map(|(a, _)| { - let proxy_str = "unix!@proxy_a"; ( - format!("sim!{},{}", a.0, &proxy_str), - ChannelAddr::Sim(SimAddr::new(a.1.clone(), proxy_str.parse().unwrap()).unwrap()), + format!("sim!{}", a.0), + ChannelAddr::Sim(SimAddr::new(a.1.clone()).unwrap()), ) }) { assert_eq!(raw.parse::().unwrap(), parsed); diff --git a/hyperactor/src/channel/sim.rs b/hyperactor/src/channel/sim.rs index be97ca4b..73a663a5 100644 --- a/hyperactor/src/channel/sim.rs +++ b/hyperactor/src/channel/sim.rs @@ -39,25 +39,6 @@ lazy_static! { } static SIM_LINK_BUF_SIZE: usize = 256; -#[derive( - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - Ord, - PartialOrd, - Hash -)] -/// A channel address along with the address of the proxy for the process -pub struct AddressProxyPair { - /// The address. - pub address: ChannelAddr, - /// The address of the proxy for the process - pub proxy: ChannelAddr, -} - /// An address for a simulated channel. #[derive( Clone, @@ -71,36 +52,38 @@ pub struct AddressProxyPair { Hash )] pub struct SimAddr { - src: Option>, + src: Option>, /// The address. addr: Box, - /// The proxy address. - proxy: Box, + /// If source is the client + client: bool, } impl SimAddr { /// Creates a new SimAddr. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. /// Creates a new SimAddr without a source to be served - pub fn new(addr: ChannelAddr, proxy: ChannelAddr) -> Result { - Self::new_impl(None, addr, proxy) + pub fn new(addr: ChannelAddr) -> Result { + Self::new_impl(None, addr, false) } /// Creates a new directional SimAddr meant to convey a channel between two addresses. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. - pub fn new_with_src( - src: AddressProxyPair, - addr: ChannelAddr, - proxy: ChannelAddr, - ) -> Result { - Self::new_impl(Some(Box::new(src)), addr, proxy) + pub fn new_with_src(src: ChannelAddr, addr: ChannelAddr) -> Result { + Self::new_impl(Some(Box::new(src)), addr, false) + } + + /// Creates a new directional SimAddr meant to convey a channel between two addresses. + #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. + fn new_with_client_src(src: ChannelAddr, addr: ChannelAddr) -> Result { + Self::new_impl(Some(Box::new(src)), addr, true) } #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. fn new_impl( - src: Option>, + src: Option>, addr: ChannelAddr, - proxy: ChannelAddr, + client: bool, ) -> Result { if let ChannelAddr::Sim(_) = &addr { return Err(SimNetError::InvalidArg(format!( @@ -108,16 +91,10 @@ impl SimAddr { addr ))); } - if let ChannelAddr::Sim(_) = &proxy { - return Err(SimNetError::InvalidArg(format!( - "proxy cannot be a sim address, found {}", - proxy - ))); - } Ok(Self { src, addr: Box::new(addr), - proxy: Box::new(proxy), + client, }) } @@ -126,26 +103,22 @@ impl SimAddr { &self.addr } - /// Returns the proxy address. - pub fn proxy(&self) -> &ChannelAddr { - &self.proxy + /// Returns the source address + pub fn src(&self) -> &Option> { + &self.src } - /// Returns the source address and proxy. - pub fn src(&self) -> &Option> { - &self.src + /// The underlying transport we are simulating + pub fn transport(&self) -> ChannelTransport { + self.addr.transport() } } impl fmt::Display for SimAddr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.src { - None => write!(f, "{},{}", self.addr, self.proxy), - Some(src) => write!( - f, - "{},{},{},{}", - src.address, src.proxy, self.addr, self.proxy - ), + None => write!(f, "{}", self.addr), + Some(src) => write!(f, "{},{}", src, self.addr), } } } @@ -153,19 +126,15 @@ impl fmt::Display for SimAddr { /// Message Event that can be passed around in the simnet. #[derive(Debug)] pub(crate) struct MessageDeliveryEvent { - src_addr: Option, - dest_addr: AddressProxyPair, + src_addr: Option, + dest_addr: ChannelAddr, data: Serialized, duration_ms: u64, } impl MessageDeliveryEvent { /// Creates a new MessageDeliveryEvent. - pub fn new( - src_addr: Option, - dest_addr: AddressProxyPair, - data: Serialized, - ) -> Self { + pub fn new(src_addr: Option, dest_addr: ChannelAddr, data: Serialized) -> Self { Self { src_addr, dest_addr, @@ -198,16 +167,16 @@ impl Event for MessageDeliveryEvent { "Sending message from {} to {}", self.src_addr .as_ref() - .map_or("unknown".to_string(), |addr| addr.address.to_string()), - self.dest_addr.address.clone() + .map_or("unknown".to_string(), |addr| addr.to_string()), + self.dest_addr.clone() ) } async fn read_simnet_config(&mut self, topology: &Arc>) { if let Some(src_addr) = &self.src_addr { let edge = SimNetEdge { - src: src_addr.address.clone(), - dst: self.dest_addr.address.clone(), + src: src_addr.clone(), + dst: self.dest_addr.clone(), }; self.duration_ms = topology .lock() @@ -232,18 +201,18 @@ pub async fn update_config(config: simnet::NetworkConfig) -> anyhow::Result<(), } /// Returns a simulated channel address that is bound to "any" channel address. -pub(crate) fn any(proxy: ChannelAddr) -> ChannelAddr { +pub(crate) fn any(transport: ChannelTransport) -> ChannelAddr { ChannelAddr::Sim(SimAddr { src: None, - addr: Box::new(ChannelAddr::any(proxy.transport())), - proxy: Box::new(proxy), + addr: Box::new(ChannelAddr::any(transport)), + client: false, }) } /// Parse the sim channel address. It should have two non-sim channel addresses separated by a comma. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`. pub fn parse(addr_string: &str) -> Result { - let re = Regex::new(r"([^,]+),([^,]+)(,([^,]+),([^,]+))?$").map_err(|err| { + let re = Regex::new(r"([^,]+)(,([^,]+))?$").map_err(|err| { ChannelError::InvalidAddress(format!("invalid sim address regex: {}", err)) })?; @@ -261,26 +230,19 @@ pub fn parse(addr_string: &str) -> Result { } match parts.len() { - 2 => { + 1 => { let addr = parts[0].parse::()?; - let proxy = parts[1].parse::()?; - Ok(ChannelAddr::Sim(SimAddr::new(addr, proxy)?)) + Ok(ChannelAddr::Sim(SimAddr::new(addr)?)) } - 5 => { + 3 => { let src_addr = parts[0].parse::()?; - let src_proxy = parts[1].parse::()?; - let addr = parts[3].parse::()?; - let proxy = parts[4].parse::()?; - - Ok(ChannelAddr::Sim(SimAddr::new_with_src( - AddressProxyPair { - address: src_addr, - proxy: src_proxy, - }, - addr, - proxy, - )?)) + let addr = parts[2].parse::()?; + Ok(ChannelAddr::Sim(if parts[0] == "client" { + SimAddr::new_with_client_src(src_addr, addr) + } else { + SimAddr::new_with_src(src_addr, addr) + }?)) } _ => Err(ChannelError::InvalidAddress(addr_string.to_string())), } @@ -310,31 +272,22 @@ fn create_egress_sender( Ok(Arc::new(tx)) } -/// Check if the address is outside of the simulation. -#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. -fn is_external_addr(addr: &AddressProxyPair) -> anyhow::Result { - Ok(simnet_handle()?.proxy_addr() != &addr.proxy) -} - #[async_trait] -impl Dispatcher for SimDispatcher { +impl Dispatcher for SimDispatcher { async fn send( &self, - _src_addr: Option, - addr: AddressProxyPair, + _src_addr: Option, + addr: ChannelAddr, data: Serialized, ) -> Result<(), SimNetError> { self.dispatchers - .get(&addr.address) + .get(&addr) .ok_or_else(|| { - SimNetError::InvalidNode( - addr.address.to_string(), - anyhow::anyhow!("no dispatcher found"), - ) + SimNetError::InvalidNode(addr.to_string(), anyhow::anyhow!("no dispatcher found")) })? .send(data) .await - .map_err(|err| SimNetError::InvalidNode(addr.address.to_string(), err.into())) + .map_err(|err| SimNetError::InvalidNode(addr.to_string(), err.into())) } } @@ -349,9 +302,10 @@ impl Default for SimDispatcher { #[derive(Debug)] pub(crate) struct SimTx { - src_addr: Option, - dst_addr: AddressProxyPair, + src_addr: Option, + dst_addr: ChannelAddr, status: watch::Receiver, // Default impl. Always reports `Active`. + client: bool, _phantom: PhantomData, } @@ -372,15 +326,14 @@ impl Tx for SimTx { }; match simnet_handle() { Ok(handle) => match &self.src_addr { - Some(src_addr) if src_addr.proxy != *handle.proxy_addr() => handle - .send_scheduled_event(ScheduledEvent { - event: Box::new(MessageDeliveryEvent::new( - self.src_addr.clone(), - self.dst_addr.clone(), - data, - )), - time: SimClock.millis_since_start(RealClock.now()), - }), + Some(_) if self.client => handle.send_scheduled_event(ScheduledEvent { + event: Box::new(MessageDeliveryEvent::new( + self.src_addr.clone(), + self.dst_addr.clone(), + data, + )), + time: SimClock.millis_since_start(RealClock.now()), + }), _ => handle.send_event(Box::new(MessageDeliveryEvent::new( self.src_addr.clone(), self.dst_addr.clone(), @@ -393,7 +346,7 @@ impl Tx for SimTx { } fn addr(&self) -> ChannelAddr { - self.dst_addr.address.clone() + self.dst_addr.clone() } fn status(&self) -> &watch::Receiver { @@ -412,11 +365,9 @@ pub(crate) fn dial(addr: SimAddr) -> Result, ChannelE Ok(SimTx { src_addr: dialer, - dst_addr: AddressProxyPair { - address: *addr.addr, - proxy: *addr.proxy, - }, + dst_addr: addr.addr().clone(), status, + client: addr.client, _phantom: PhantomData, }) } @@ -425,13 +376,10 @@ pub(crate) fn dial(addr: SimAddr) -> Result, ChannelE /// The mpsc tx will be used to dispatch messages when it's time while /// the mpsc rx will be used by the above applications to handle received messages /// like any other channel. -/// A sim address has src and dst. Dispatchers are only indexed by dst address. +/// A sim address has a dst and optional src. Dispatchers are only indexed by dst address. pub(crate) fn serve( sim_addr: SimAddr, ) -> anyhow::Result<(ChannelAddr, SimRx)> { - // Serves sim address at sim_addr.src and set up local proxy at sim_addr.src_proxy. - // Reversing the src and dst since the first element in the output tuple is the - // dialing address of this sim channel. So the served address is the dst. let (tx, rx) = mpsc::channel::(SIM_LINK_BUF_SIZE); // Add tx to sender dispatch. SENDER.dispatchers.insert(*sim_addr.addr.clone(), tx); @@ -474,13 +422,7 @@ mod tests { let dst_ok = vec!["[::1]:1234", "tcp!127.0.0.1:8080", "local!123"]; let srcs_ok = vec!["[::2]:1234", "tcp!127.0.0.2:8080", "local!124"]; - let proxy = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); // TODO: New NodeAdd event should do this for you.. for addr in dst_ok.iter().chain(srcs_ok.iter()) { @@ -490,16 +432,10 @@ mod tests { .bind(addr.parse::().unwrap()) .unwrap(); } - // Messages are transferred internally if only there's a local proxy and the - // dst proxy is the same as local proxy. for (src_addr, dst_addr) in zip(srcs_ok, dst_ok) { let dst_addr = SimAddr::new_with_src( - AddressProxyPair { - address: src_addr.parse::().unwrap(), - proxy: proxy.clone(), - }, + src_addr.parse::().unwrap(), dst_addr.parse::().unwrap(), - proxy.clone(), ) .unwrap(); @@ -517,22 +453,14 @@ mod tests { async fn test_invalid_sim_addr() { let src = "sim!src"; let dst = "sim!dst"; - let src_proxy = "sim!src_proxy"; - let dst_proxy = "sim!dst_proxy"; - let sim_addr = format!("{},{},{},{}", src, src_proxy, dst, dst_proxy); + let sim_addr = format!("{},{}", src, dst); let result = parse(&sim_addr); assert!(matches!(result, Err(ChannelError::InvalidAddress(_)))); - - let dst = "unix!dst".parse::().unwrap(); - let dst_proxy = "sim!unix!a,unix!b".parse::().unwrap(); - let result = SimAddr::new(dst, dst_proxy); - // dst_proxy shouldn't be a sim address. - assert!(matches!(result, Err(SimNetError::InvalidArg(_)))); } #[tokio::test] async fn test_parse_sim_addr() { - let sim_addr = "sim!unix!@dst,unix!@proxy"; + let sim_addr = "sim!unix!@dst"; let result = sim_addr.parse(); assert!(result.is_ok()); let ChannelAddr::Sim(sim_addr) = result.unwrap() else { @@ -540,42 +468,26 @@ mod tests { }; assert!(sim_addr.src().is_none()); assert_eq!(sim_addr.addr().to_string(), "unix!@dst"); - assert_eq!(sim_addr.proxy().to_string(), "unix!@proxy"); - let sim_addr = "sim!unix!@src,unix!@proxy,unix!@dst,unix!@proxy"; + let sim_addr = "sim!unix!@src,unix!@dst"; let result = sim_addr.parse(); assert!(result.is_ok()); let ChannelAddr::Sim(sim_addr) = result.unwrap() else { panic!("Expected a sim address"); }; assert!(sim_addr.src().is_some()); - let src_pair = sim_addr.src().clone().unwrap(); - assert_eq!(src_pair.address.to_string(), "unix!@src"); - assert_eq!(src_pair.proxy.to_string(), "unix!@proxy"); assert_eq!(sim_addr.addr().to_string(), "unix!@dst"); - assert_eq!(sim_addr.proxy().to_string(), "unix!@proxy"); } #[tokio::test] async fn test_realtime_frontier() { - let proxy: ChannelAddr = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); tokio::time::pause(); - let sim_addr = - SimAddr::new("unix!@dst".parse::().unwrap(), proxy.clone()).unwrap(); + let sim_addr = SimAddr::new("unix!@dst".parse::().unwrap()).unwrap(); let sim_addr_with_src = SimAddr::new_with_src( - AddressProxyPair { - address: "unix!@src".parse::().unwrap(), - proxy: proxy.clone(), - }, + "unix!@src".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy.clone(), ) .unwrap(); let (_, mut rx) = sim::serve::<()>(sim_addr.clone()).unwrap(); @@ -612,31 +524,17 @@ mod tests { #[tokio::test] async fn test_client_message_scheduled_realtime() { tokio::time::pause(); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + start(); let controller_to_dst = SimAddr::new_with_src( - AddressProxyPair { - address: "unix!@controller".parse::().unwrap(), - proxy: proxy_addr.clone(), - }, + "unix!@controller".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy_addr.clone(), ) .unwrap(); let controller_tx = sim::dial::<()>(controller_to_dst.clone()).unwrap(); - let client_to_dst = SimAddr::new_with_src( - AddressProxyPair { - address: ChannelAddr::any(ChannelTransport::Unix), - proxy: ChannelAddr::any(ChannelTransport::Unix), - }, + let client_to_dst = SimAddr::new_with_client_src( + "unix!@client".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy_addr.clone(), ) .unwrap(); let client_tx = sim::dial::<()>(client_to_dst).unwrap(); diff --git a/hyperactor/src/clock.rs b/hyperactor/src/clock.rs index 9c8fcbc4..b5d45bdb 100644 --- a/hyperactor/src/clock.rs +++ b/hyperactor/src/clock.rs @@ -341,12 +341,7 @@ mod tests { #[tokio::test] async fn test_sim_timeout() { - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + simnet::start(); let res = SimClock .timeout(tokio::time::Duration::from_secs(10), async { SimClock.sleep(tokio::time::Duration::from_secs(5)).await; diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index 0f0c33a3..e2e7d5a2 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -2359,7 +2359,6 @@ mod tests { use crate::channel::ChannelTransport; use crate::channel::dial; use crate::channel::serve; - use crate::channel::sim::AddressProxyPair; use crate::channel::sim::SimAddr; use crate::clock::Clock; use crate::clock::RealClock; @@ -2560,23 +2559,12 @@ mod tests { #[tokio::test] async fn test_sim_client_server() { - let proxy = ChannelAddr::any(channel::ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); - let dst_addr = - SimAddr::new("local!1".parse::().unwrap(), proxy.clone()).unwrap(); + simnet::start(); + let dst_addr = SimAddr::new("local!1".parse::().unwrap()).unwrap(); let src_to_dst = ChannelAddr::Sim( SimAddr::new_with_src( - AddressProxyPair { - address: "local!0".parse::().unwrap(), - proxy: proxy.clone(), - }, + "local!0".parse::().unwrap(), dst_addr.addr().clone(), - dst_addr.proxy().clone(), ) .unwrap(), ); diff --git a/hyperactor/src/simnet.rs b/hyperactor/src/simnet.rs index 9ae27d2d..0e0224fa 100644 --- a/hyperactor/src/simnet.rs +++ b/hyperactor/src/simnet.rs @@ -26,7 +26,6 @@ use async_trait::async_trait; use dashmap::DashMap; use dashmap::DashSet; use enum_as_inner::EnumAsInner; -use futures::executor::block_on; use serde::Deserialize; use serde::Deserializer; use serde::Serialize; @@ -39,23 +38,16 @@ use tokio::sync::mpsc::UnboundedSender; use tokio::task::JoinError; use tokio::task::JoinHandle; use tokio::time::interval; -use tokio::time::timeout; use crate as hyperactor; // for macros use crate::ActorId; use crate::Mailbox; -use crate::Named; use crate::OncePortRef; -use crate::channel; use crate::channel::ChannelAddr; -use crate::channel::Rx; -use crate::channel::sim::AddressProxyPair; -use crate::channel::sim::MessageDeliveryEvent; use crate::clock::Clock; use crate::clock::RealClock; use crate::clock::SimClock; use crate::data::Serialized; -use crate::mailbox::MessageEnvelope; static HANDLE: OnceLock = OnceLock::new(); @@ -305,10 +297,6 @@ pub enum SimNetError { #[error("timeout after {} ms: {}", .0.as_millis(), .1)] Timeout(Duration, String), - /// External node is trying to connect but proxy is not available. - #[error("proxy not available: {0}")] - ProxyNotAvailable(String), - /// Cannot deliver the message because destination address is missing. #[error("missing destination address")] MissingDestinationAddress, @@ -342,9 +330,6 @@ pub struct SimNetHandle { event_tx: UnboundedSender<(Box, bool, Option)>, config: Arc>, pending_event_count: Arc, - /// Handle to a running proxy server that forwards external messages - /// into the simnet. - proxy_handle: ProxyHandle, /// A receiver to receive simulator operational messages. /// The receiver can be moved out of the simnet handle. training_script_state_tx: tokio::sync::watch::Sender, @@ -416,8 +401,6 @@ impl SimNetHandle { /// Close the simulator, processing pending messages before /// completing the returned future. pub async fn close(&self) -> Result, JoinError> { - // Stop the proxy if there is one. - self.proxy_handle.stop().await?; // Signal the simnet loop to stop self.stop_signal.store(true, Ordering::SeqCst); @@ -462,42 +445,10 @@ impl SimNetHandle { "timeout waiting for received events to be scheduled".to_string(), )) } - - /// Returns the external address of the simnet. - pub fn proxy_addr(&self) -> &ChannelAddr { - &self.proxy_handle.addr - } } pub(crate) type Topology = DashMap; -/// A ProxyMessage is a message that SimNet proxy receives. -/// The message may requests the SimNet to send the payload in the message field from -/// src to dst if addr field exists. -/// Or handle the payload in the message field if addr field is None, indicating that -/// this is a self-handlable message. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)] -pub struct ProxyMessage { - sender_addr: Option, - dest_addr: Option, - data: Serialized, -} - -impl ProxyMessage { - /// Creates a new ForwardMessage. - pub fn new( - sender_addr: Option, - dest_addr: Option, - data: Serialized, - ) -> Self { - Self { - sender_addr, - dest_addr, - data, - } - } -} - /// Configure network topology for the simnet pub struct SimNetConfig { // For now, we assume the network is fully connected @@ -510,31 +461,6 @@ pub struct SimNetConfig { /// The network is represented as a graph of nodes. /// The graph is represented as a map of edges. /// The network also has a cloud of inflight messages -/// SimNet also serves a proxy address to receive external traffic. This proxy address can handle -/// [`ProxyMessage`]s and forward the payload from src to dst. -/// -/// Example: -/// In this example, we send a ForwardMessage to the proxy_addr. SimNet will handle the message and -/// forward the payload from src to dst. -/// ```ignore -/// let nw_handle = start("local!0".parse().unwrap(), 1000, true, Some(gen_event_fcn)) -/// .await -/// .unwrap(); -/// let proxy_addr = nw_handle.proxy_addr().clone(); -/// let tx = crate::channel::dial(proxy_addr).unwrap(); -/// let src_to_dst_msg = MessageEnvelope::new_unknown( -/// port_id.clone(), -/// Serialized::serialize(&"hola".to_string()).unwrap(), -/// ); -/// let forward_message = ForwardMessage::new( -/// "unix!@src".parse::().unwrap(), -/// "unix!@dst".parse::().unwrap(), -/// src_to_dst_msg -/// ); -/// let external_message = -/// MessageEnvelope::new_unknown(port_id, Serialized::serialize(&forward_message).unwrap()); -/// tx.send(external_message).await.unwrap(); -/// ``` pub struct SimNet { config: Arc>, address_book: DashSet, @@ -545,103 +471,15 @@ pub struct SimNet { pending_event_count: Arc, } -/// A proxy to bridge external nodes and the SimNet. -struct ProxyHandle { - join_handle: Mutex>>, - stop_signal: Arc, - addr: ChannelAddr, -} - -impl ProxyHandle { - /// Starts an proxy server to handle external [`ForwardMessage`]s. It will forward the payload inside - /// the [`ForwardMessage`] from src to dst in the SimNet. - /// Args: - /// proxy_addr: address to listen - /// event_tx: a channel to send events to the SimNet - /// pending_event_count: a counter to keep track of the number of pending events - /// to_event: a function that specifies how to generate an Event from a forward message - async fn start( - proxy_addr: ChannelAddr, - event_tx: UnboundedSender<(Box, bool, Option)>, - pending_event_count: Arc, - ) -> anyhow::Result { - let (addr, mut rx) = channel::serve::(proxy_addr).await?; - tracing::info!("SimNet serving external traffic on {}", &addr); - let stop_signal = Arc::new(AtomicBool::new(false)); - - let join_handle = { - let stop_signal = stop_signal.clone(); - tokio::spawn(async move { - 'outer: loop { - // timeout the wait to enable stop signal checking at least every 100ms. - #[allow(clippy::disallowed_methods)] - if let Ok(Ok(msg)) = timeout(Duration::from_millis(100), rx.recv()).await { - let proxy_message: ProxyMessage = msg.deserialized().unwrap(); - if let Some(dest_addr) = proxy_message.dest_addr { - let event = Box::new(MessageDeliveryEvent::new( - proxy_message.sender_addr, - dest_addr, - proxy_message.data, - )); - if let Err(e) = event_tx.send((event, true, None)) { - tracing::error!("error sending message to simnet: {:?}", e); - } else { - pending_event_count - .fetch_add(1, std::sync::atomic::Ordering::SeqCst); - } - } - } - if stop_signal.load(Ordering::SeqCst) { - eprintln!("stopping external traffic handler"); - break 'outer; - } - } - }) - }; - Ok(Self { - join_handle: Mutex::new(Some(join_handle)), - stop_signal, - addr, - }) - } - - /// Stop the proxy. - async fn stop(&self) -> Result<(), JoinError> { - self.stop_signal.store(true, Ordering::SeqCst); - let mut guard = self.join_handle.lock().await; - if let Some(handle) = guard.take() { - handle.await - } else { - Ok(()) - } - } -} - /// Starts a sim net. /// Args: -/// private_addr: an internal address to receive operational messages such as NodeJoinEvent /// max_duration_ms: an optional config to override default settings of the network latency -/// enable_record: a flag to enable recording of message delivery records -pub fn start( - private_addr: ChannelAddr, - proxy_addr: ChannelAddr, - max_duration_ms: u64, -) -> anyhow::Result<()> { +pub fn start() { + let max_duration_ms = 1000 * 10; // Construct a topology with one node: the default A. let address_book: DashSet = DashSet::new(); - address_book.insert(private_addr.clone()); let topology = DashMap::new(); - topology.insert( - SimNetEdge { - src: private_addr.clone(), - dst: private_addr, - }, - SimNetEdgeInfo { - latency: Duration::from_millis(1), - }, - ); - let config = Arc::new(Mutex::new(SimNetConfig { topology })); let (training_script_state_tx, training_script_state_rx) = @@ -657,7 +495,7 @@ pub fn start( let stop_signal = stop_signal.clone(); tokio::spawn(async move { - let mut net = SimNet { + SimNet { config, address_book, state: State { @@ -667,30 +505,20 @@ pub fn start( max_latency: Duration::from_millis(max_duration_ms), records: Vec::new(), pending_event_count, - }; - net.run(event_rx, training_script_state_rx, stop_signal) - .await + } + .run(event_rx, training_script_state_rx, stop_signal) + .await }) })); - let proxy_handle = block_on(ProxyHandle::start( - proxy_addr, - event_tx.clone(), - pending_event_count.clone(), - )) - .map_err(|err| SimNetError::ProxyNotAvailable(err.to_string()))?; - HANDLE.get_or_init(|| SimNetHandle { join_handle, event_tx, config, pending_event_count, - proxy_handle, training_script_state_tx, stop_signal, }); - - Ok(()) } impl SimNet { @@ -1081,17 +909,7 @@ mod tests { #[tokio::test] async fn test_handle_instantiation() { - let default_addr = format!("local!{}", 0) - .parse::() - .unwrap(); - assert!( - start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .is_ok() - ); + start(); simnet_handle().unwrap().close().await.unwrap(); } @@ -1099,12 +917,7 @@ mod tests { async fn test_simnet_config() { // Tests that we can create a simnet, config latency between two node and deliver // the message with configured latency. - start( - "local!0".parse::().unwrap(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let alice = "local!1".parse::().unwrap(); let bob = "local!2".parse::().unwrap(); let latency = Duration::from_millis(1000); @@ -1121,9 +934,8 @@ mod tests { .await .unwrap(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let alice = SimAddr::new(alice, proxy_addr.clone()).unwrap(); - let bob = SimAddr::new(bob, proxy_addr.clone()).unwrap(); + let alice = SimAddr::new(alice).unwrap(); + let bob = SimAddr::new(bob).unwrap(); let msg = Box::new(MessageDeliveryEvent::new( alice, bob, @@ -1149,12 +961,7 @@ mod tests { #[tokio::test] async fn test_simnet_debounce() { let default_addr = "local!0".parse::().unwrap(); - start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let alice = "local!1".parse::().unwrap(); let bob = "local!2".parse::().unwrap(); @@ -1171,10 +978,8 @@ mod tests { .await .unwrap(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - - let alice = SimAddr::new(alice, proxy_addr.clone()).unwrap(); - let bob = SimAddr::new(bob, proxy_addr).unwrap(); + let alice = SimAddr::new(alice).unwrap(); + let bob = SimAddr::new(bob).unwrap(); // Rapidly send 10 messages expecting that each one debounces the processing for _ in 0..10 { @@ -1211,13 +1016,7 @@ mod tests { #[tokio::test] async fn test_sim_dispatch() { - let proxy = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); let sender = Some(TestDispatcher::default()); let mut addresses: Vec = Vec::new(); // // Create a simple network of 4 nodes. @@ -1234,11 +1033,10 @@ mod tests { .map(|s| Serialized::serialize(&s.to_string()).unwrap()) .collect(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let addr_0 = SimAddr::new(addresses[0].clone(), proxy_addr.clone()).unwrap(); - let addr_1 = SimAddr::new(addresses[1].clone(), proxy_addr.clone()).unwrap(); - let addr_2 = SimAddr::new(addresses[2].clone(), proxy_addr.clone()).unwrap(); - let addr_3 = SimAddr::new(addresses[3].clone(), proxy_addr.clone()).unwrap(); + let addr_0 = SimAddr::new(addresses[0].clone()).unwrap(); + let addr_1 = SimAddr::new(addresses[1].clone()).unwrap(); + let addr_2 = SimAddr::new(addresses[2].clone()).unwrap(); + let addr_3 = SimAddr::new(addresses[3].clone()).unwrap(); let one = Box::new(MessageDeliveryEvent::new( addr_0.clone(), addr_1.clone(), @@ -1286,20 +1084,20 @@ mod tests { #[tokio::test] async fn test_read_config_from_yaml() { let yaml = r#" -edges: - - src: local!0 - dst: local!1 - metadata: - latency: 1 - - src: local!0 - dst: local!2 - metadata: - latency: 2 - - src: local!1 - dst: local!2 - metadata: - latency: 3 -"#; + edges: + - src: local!0 + dst: local!1 + metadata: + latency: 1 + - src: local!0 + dst: local!2 + metadata: + latency: 2 + - src: local!1 + dst: local!2 + metadata: + latency: 3 + "#; let config = NetworkConfig::from_yaml(yaml).unwrap(); assert_eq!(config.edges.len(), 3); assert_eq!( @@ -1331,74 +1129,9 @@ edges: assert_eq!(config.edges[2].metadata.latency, Duration::from_secs(3)); } - #[cfg(target_os = "linux")] - #[tokio::test] - async fn test_simnet_receive_external_message() { - use tokio::sync::oneshot; - - use crate::PortId; - use crate::channel::Tx; - - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); - let tx = crate::channel::dial(proxy_addr.clone()).unwrap(); - let port_id = PortId(id!(test[0].actor0), 0); - let src_to_dst_msg = Serialized::serialize(&"hola".to_string()).unwrap(); - let src = random_abstract_addr(); - let dst = random_abstract_addr(); - let src_and_proxy = Some(AddressProxyPair { - address: src.clone(), - proxy: proxy_addr.clone(), - }); - let dst_and_proxy = AddressProxyPair { - address: dst.clone(), - proxy: proxy_addr.clone(), - }; - let forward_message = ProxyMessage::new(src_and_proxy, Some(dst_and_proxy), src_to_dst_msg); - let external_message = - MessageEnvelope::new_unknown(port_id, Serialized::serialize(&forward_message).unwrap()); - tx.try_post(external_message, oneshot::channel().0).unwrap(); - // flush doesn't work here because tx.send() delivers the message through real network. - // We have to wait for the message to enter simnet. - RealClock.sleep(Duration::from_millis(1000)).await; - simnet_handle() - .unwrap() - .flush(Duration::from_millis(1000)) - .await - .unwrap(); - let records = simnet_handle().unwrap().close().await; - assert!(records.as_ref().unwrap().len() == 1); - let expected_record = SimulatorEventRecord { - summary: format!("Sending message from {} to {}", src, dst), - start_at: 0, - end_at: 1, - }; - assert_eq!(records.unwrap().first().unwrap(), &expected_record); - } - #[tokio::test] async fn test_sim_sleep() { - start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); - - let default_addr = format!("local!{}", 0) - .parse::() - .unwrap(); - let _ = start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let start = SimClock.now(); assert_eq!(SimClock.millis_since_start(start), 0); @@ -1411,12 +1144,7 @@ edges: #[tokio::test] async fn test_torch_op() { - start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let args_string = "1, 2".to_string(); let kwargs_string = "a=2".to_string(); diff --git a/hyperactor_multiprocess/src/ping_pong.rs b/hyperactor_multiprocess/src/ping_pong.rs index 676259cc..4c659153 100644 --- a/hyperactor_multiprocess/src/ping_pong.rs +++ b/hyperactor_multiprocess/src/ping_pong.rs @@ -14,9 +14,7 @@ mod tests { use hyperactor::ActorRef; use hyperactor::Mailbox; use hyperactor::channel::ChannelAddr; - use hyperactor::channel::ChannelTransport; use hyperactor::channel::sim; - use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; use hyperactor::id; use hyperactor::reference::Index; @@ -36,16 +34,10 @@ mod tests { #[tokio::test] async fn test_sim_ping_pong() { let system_addr = "local!1".parse::().unwrap(); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + simnet::start(); - let system_sim_addr = SimAddr::new(system_addr.clone(), proxy_addr.clone()).unwrap(); + let system_sim_addr = SimAddr::new(system_addr.clone()).unwrap(); let server_handle = System::serve( ChannelAddr::Sim(system_sim_addr.clone()), Duration::from_secs(10), @@ -64,21 +56,14 @@ mod tests { let ping_actor_ref = spawn_proc_actor( 2, - proxy_addr.clone(), system_sim_addr.clone(), sys_mailbox.clone(), world_id.clone(), ) .await; - let pong_actor_ref = spawn_proc_actor( - 3, - proxy_addr.clone(), - system_sim_addr, - sys_mailbox.clone(), - world_id.clone(), - ) - .await; + let pong_actor_ref = + spawn_proc_actor(3, system_sim_addr, sys_mailbox.clone(), world_id.clone()).await; // Configure the simulation network. let simnet_config_yaml = r#" @@ -124,7 +109,6 @@ edges: async fn spawn_proc_actor( actor_index: Index, - proxy_addr: ChannelAddr, system_addr: SimAddr, sys_mailbox: Mailbox, world_id: WorldId, @@ -133,19 +117,11 @@ edges: .parse::() .unwrap(); - let proc_sim_addr = SimAddr::new(proc_addr.clone(), proxy_addr.clone()).unwrap(); + let proc_sim_addr = SimAddr::new(proc_addr.clone()).unwrap(); let proc_listen_addr = ChannelAddr::Sim(proc_sim_addr); let proc_id = world_id.proc_id(actor_index); let proc_to_system = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: proc_addr.clone(), - proxy: proxy_addr.clone(), - }, - system_addr.addr().clone(), - system_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(proc_addr.clone(), system_addr.addr().clone()).unwrap(), ); let bootstrap = ProcActor::bootstrap( proc_id, diff --git a/hyperactor_multiprocess/src/system_actor.rs b/hyperactor_multiprocess/src/system_actor.rs index fd477b94..c80160b7 100644 --- a/hyperactor_multiprocess/src/system_actor.rs +++ b/hyperactor_multiprocess/src/system_actor.rs @@ -38,7 +38,6 @@ use hyperactor::RefClient; use hyperactor::WorldId; use hyperactor::actor::Handler; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; use hyperactor::clock::Clock; use hyperactor::clock::ClockKind; @@ -762,13 +761,9 @@ impl ReportingRouter { ChannelAddr::Sim( SimAddr::new_with_src( // source is the sender - AddressProxyPair { - address: sender_sim_addr.addr().clone(), - proxy: sender_sim_addr.proxy().clone(), - }, + sender_sim_addr.addr().clone(), // dest remains unchanged dest_sim_addr.addr().clone(), - dest_sim_addr.proxy().clone(), ) .unwrap(), ) @@ -2800,19 +2795,11 @@ mod tests { #[tokio::test] async fn test_update_sim_address() { - let proxy = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + simnet::start(); let src_id = id!(proc[0].actor); - let src_addr = - ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap(), proxy.clone()).unwrap()); - let dst_addr = - ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap(), proxy.clone()).unwrap()); + let src_addr = ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap()).unwrap()); + let dst_addr = ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap()).unwrap()); let (_, mut rx) = channel::serve::(src_addr.clone()) .await .unwrap(); @@ -2844,7 +2831,7 @@ mod tests { panic!("Expected sim address"); }; - assert_eq!(addr.src().clone().unwrap().address.to_string(), "unix!@src"); + assert_eq!(addr.src().clone().unwrap().to_string(), "unix!@src"); assert_eq!(addr.addr().to_string(), "unix!@dst"); } } diff --git a/monarch_extension/src/simulation_tools.rs b/monarch_extension/src/simulation_tools.rs index 6913967b..a73855e0 100644 --- a/monarch_extension/src/simulation_tools.rs +++ b/monarch_extension/src/simulation_tools.rs @@ -6,24 +6,16 @@ * LICENSE file in the root directory of this source tree. */ -use hyperactor::channel::ChannelAddr; -use hyperactor::channel::ChannelTransport; use hyperactor::clock::Clock; use hyperactor::clock::SimClock; use hyperactor::simnet; -use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; #[pyfunction] #[pyo3(name = "start_event_loop")] pub fn start_simnet_event_loop(py: Python) -> PyResult> { pyo3_async_runtimes::tokio::future_into_py(py, async move { - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + simnet::start(); Ok(()) }) } diff --git a/monarch_extension/src/simulator_client.rs b/monarch_extension/src/simulator_client.rs index eb93c1c5..46d23e29 100644 --- a/monarch_extension/src/simulator_client.rs +++ b/monarch_extension/src/simulator_client.rs @@ -52,29 +52,17 @@ impl SimulatorClient { #[new] fn new(py: Python, system_addr: String, world_size: i32) -> PyResult { signal_safe_block_on(py, async move { - let system_addr = system_addr - .parse::() - .map_err(|err| PyValueError::new_err(err.to_string()))?; - - let ChannelAddr::Sim(system_sim_addr) = &system_addr else { - return Err(PyValueError::new_err(format!( - "bootstrap address should be a sim address: {}", - system_addr - ))); - }; - - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - system_sim_addr.proxy().clone(), - 1000, - ) - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + simnet::start(); Ok(Self { inner: Arc::new(Mutex::new( - TensorEngineSimulator::new(system_addr) - .await - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, + TensorEngineSimulator::new( + system_addr + .parse::() + .map_err(|err| PyValueError::new_err(err.to_string()))?, + ) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, )), world_size: world_size as usize, }) diff --git a/monarch_hyperactor/src/channel.rs b/monarch_hyperactor/src/channel.rs index ef7522a2..9fd6f055 100644 --- a/monarch_hyperactor/src/channel.rs +++ b/monarch_hyperactor/src/channel.rs @@ -25,7 +25,7 @@ pub enum PyChannelTransport { MetaTls, Local, Unix, - // Sim(/*proxy address:*/ ChannelAddr), TODO kiuk@ add support + // Sim(/*transport:*/ ChannelTransport), TODO kiuk@ add support } #[pyclass( @@ -131,9 +131,7 @@ mod tests { #[test] fn test_channel_unsupported_transport() -> PyResult<()> { - let sim_addr = ChannelAddr::any(ChannelTransport::Sim(ChannelAddr::any( - ChannelTransport::Unix, - ))); + let sim_addr = ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); let addr = PyChannelAddr { inner: sim_addr }; assert!(addr.get_port().is_err()); diff --git a/monarch_simulator/src/bootstrap.rs b/monarch_simulator/src/bootstrap.rs index 6ba3000d..40b01d11 100644 --- a/monarch_simulator/src/bootstrap.rs +++ b/monarch_simulator/src/bootstrap.rs @@ -18,7 +18,6 @@ use hyperactor::ActorRef; use hyperactor::ProcId; use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; use hyperactor_multiprocess::System; use hyperactor_multiprocess::proc_actor::ProcActor; @@ -62,15 +61,7 @@ pub async fn spawn_controller( panic!("bootstrap_addr must be a SimAddr"); }; let bootstrap_addr = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: listen_addr.clone(), - proxy: bootstrap_addr.proxy().clone(), - }, - bootstrap_addr.addr().clone(), - bootstrap_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(listen_addr.clone(), bootstrap_addr.addr().clone()).unwrap(), ); tracing::info!( "controller listen addr: {}, bootstrap addr: {}", @@ -121,15 +112,7 @@ pub async fn spawn_sim_worker( panic!("bootstrap_addr must be a SimAddr"); }; let bootstrap_addr = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: listen_addr.clone(), - proxy: bootstrap_addr.proxy().clone(), - }, - bootstrap_addr.addr().clone(), - bootstrap_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(listen_addr.clone(), bootstrap_addr.addr().clone()).unwrap(), ); tracing::info!( "worker {} listen addr: {}, bootstrap addr: {}", diff --git a/monarch_simulator/src/simulator.rs b/monarch_simulator/src/simulator.rs index 97c02b77..a9b9baf7 100644 --- a/monarch_simulator/src/simulator.rs +++ b/monarch_simulator/src/simulator.rs @@ -133,18 +133,9 @@ mod tests { #[tracing_test::traced_test] #[tokio::test] async fn test_spawn_and_kill_mesh() { - let proxy = format!("unix!@{}", random_str()); + simnet::start(); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.parse().unwrap(), - 1000, - ) - .unwrap(); - - let system_addr = format!("sim!unix!@system,{}", &proxy) - .parse::() - .unwrap(); + let system_addr = "sim!unix!@system".parse::().unwrap(); let mut simulator = super::TensorEngineSimulator::new(system_addr.clone()) .await .unwrap(); diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index 4d662e2d..61114f7c 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -809,12 +809,7 @@ mod tests { #[tokio::test] async fn worker_reduce() -> Result<()> { - simnet::start( - "local!0".parse::().unwrap(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + simnet::start(); let proc = Proc::local(); //let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller")?; let client = proc.attach("client")?; diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index 8c304efd..2e999190 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -58,9 +58,7 @@ logger: logging.Logger = logging.getLogger(__name__) -def sim_mesh( - n_meshes: int, hosts: int, gpus_per_host: int, proxy_addr: Optional[str] = None -) -> List[DeviceMesh]: +def sim_mesh(n_meshes: int, hosts: int, gpus_per_host: int) -> List[DeviceMesh]: """ Creates a single simulated device mesh with the given number of per host. @@ -185,7 +183,6 @@ def __init__( Bootstraps a SimMesh. Args: num_meshes: int - number of meshes to create. - proxy_addr: Option[str] - the proxy address of the simulation process mesh_world_state: a state of the meshes. Keys are the MeshWorld and values are boolean indicating if this mesh is active. """ # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later @@ -196,13 +193,9 @@ def __init__( self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state - proxy_addr = f"unix!@{_random_id()}-proxy" - self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}" - client_proxy_addr = f"unix!@{_random_id()}-proxy" - self.client_listen_addr = f"sim!unix!@client,{client_proxy_addr}" - self.client_bootstrap_addr = ( - f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}" - ) + self.bootstrap_addr: str = "sim!unix!@system" + self.client_listen_addr = "sim!unix!@client" + self.client_bootstrap_addr = "sim!unix!@client,unix!@system" self._simulator_client = SimulatorClient(self.bootstrap_addr, world_size) for i in range(num_meshes): diff --git a/python/tests/test_sim_backend.py b/python/tests/test_sim_backend.py index 1ea89cd5..489bc736 100644 --- a/python/tests/test_sim_backend.py +++ b/python/tests/test_sim_backend.py @@ -24,11 +24,8 @@ def local_sim_mesh( # TODO: support multiple gpus in a mesh. gpu_per_host: int = 1, activate: bool = True, - proxy_addr: Optional[str] = None, ) -> Generator[DeviceMesh, None, None]: - dms = sim_mesh( - n_meshes=1, hosts=hosts, gpus_per_host=gpu_per_host, proxy_addr=proxy_addr - ) + dms = sim_mesh(n_meshes=1, hosts=hosts, gpus_per_host=gpu_per_host) dm = dms[0] try: if activate: From 1f5fe31e4362d97d8ef1bf456f0f803673654370 Mon Sep 17 00:00:00 2001 From: Thomas Wang Date: Thu, 10 Jul 2025 09:59:50 -0700 Subject: [PATCH 3/3] local_proc_mesh with sim channels (#475) Summary: Pull Request resolved: https://github.com/pytorch-labs/monarch/pull/475 We want to instantiate local_proc_meshes that use ChannelTransport::SIm instead of ChannelTranport::Local so that the simulator can intercept and control the delivery of messages To preserve to `allocate()` interface so that we can reuse existing test generation macros we will create a wrapper class for this around `LocalAlloc` Differential Revision: D77941640 --- hyperactor_mesh/src/actor_mesh.rs | 6 + hyperactor_mesh/src/alloc.rs | 1 + hyperactor_mesh/src/alloc/local.rs | 10 +- hyperactor_mesh/src/alloc/sim.rs | 122 ++++++++++++++++++ monarch_hyperactor/src/alloc.rs | 49 +++++++ python/monarch/__init__.py | 4 + .../monarch_hyperactor/alloc.pyi | 20 +++ python/monarch/_src/actor/allocator.py | 23 ++++ python/monarch/_src/actor/proc_mesh.py | 29 ++++- python/monarch/actor/__init__.py | 8 +- 10 files changed, 268 insertions(+), 4 deletions(-) create mode 100644 hyperactor_mesh/src/alloc/sim.rs diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 69bb7ca1..464f8e82 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -987,4 +987,10 @@ mod tests { buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap() ))); } + + mod sim { + use crate::alloc::sim::SimAllocator; + + actor_mesh_test_suite!(SimAllocator::new_and_start_simnet()); + } } diff --git a/hyperactor_mesh/src/alloc.rs b/hyperactor_mesh/src/alloc.rs index 9c36f525..cfd758b4 100644 --- a/hyperactor_mesh/src/alloc.rs +++ b/hyperactor_mesh/src/alloc.rs @@ -13,6 +13,7 @@ pub mod local; pub(crate) mod logtailer; pub mod process; pub mod remoteprocess; +pub mod sim; use std::collections::HashMap; use std::fmt; diff --git a/hyperactor_mesh/src/alloc/local.rs b/hyperactor_mesh/src/alloc/local.rs index 5950eaa3..ceb4915f 100644 --- a/hyperactor_mesh/src/alloc/local.rs +++ b/hyperactor_mesh/src/alloc/local.rs @@ -75,10 +75,15 @@ pub struct LocalAlloc { todo_rx: mpsc::UnboundedReceiver, stopped: bool, failed: bool, + transport: ChannelTransport, } impl LocalAlloc { fn new(spec: AllocSpec) -> Self { + Self::new_with_transport(spec, ChannelTransport::Local) + } + + pub(crate) fn new_with_transport(spec: AllocSpec, transport: ChannelTransport) -> Self { let name = ShortUuid::generate(); let (todo_tx, todo_rx) = mpsc::unbounded_channel(); for rank in 0..spec.shape.slice().len() { @@ -94,6 +99,7 @@ impl LocalAlloc { todo_rx, stopped: false, failed: false, + transport, } } @@ -123,7 +129,7 @@ impl LocalAlloc { &self.name } - fn size(&self) -> usize { + pub(crate) fn size(&self) -> usize { self.spec.shape.slice().len() } } @@ -249,7 +255,7 @@ impl Alloc for LocalAlloc { } fn transport(&self) -> ChannelTransport { - ChannelTransport::Local + self.transport.clone() } async fn stop(&mut self) -> Result<(), AllocatorError> { diff --git a/hyperactor_mesh/src/alloc/sim.rs b/hyperactor_mesh/src/alloc/sim.rs new file mode 100644 index 00000000..f680d5a2 --- /dev/null +++ b/hyperactor_mesh/src/alloc/sim.rs @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! Support for allocating procs in the local process with simulated channels. + +#![allow(dead_code)] // until it is used outside of testing + +use async_trait::async_trait; +use hyperactor::WorldId; +use hyperactor::channel::ChannelAddr; +use hyperactor::channel::ChannelTransport; +use hyperactor::mailbox::MailboxServerHandle; +use hyperactor::proc::Proc; +use ndslice::Shape; + +use super::ProcStopReason; +use crate::alloc::Alloc; +use crate::alloc::AllocSpec; +use crate::alloc::Allocator; +use crate::alloc::AllocatorError; +use crate::alloc::LocalAlloc; +use crate::alloc::ProcState; +use crate::shortuuid::ShortUuid; + +/// An allocator that runs procs in the local process with network traffic going through simulated channels. +/// Other than transport, the underlying implementation is an inner LocalAlloc. +pub struct SimAllocator; + +#[async_trait] +impl Allocator for SimAllocator { + type Alloc = SimAlloc; + + async fn allocate(&mut self, spec: AllocSpec) -> Result { + Ok(SimAlloc::new(spec)) + } +} + +impl SimAllocator { + #[cfg(test)] + pub(crate) fn new_and_start_simnet() -> Self { + hyperactor::simnet::start(); + Self + } +} + +struct SimProc { + proc: Proc, + addr: ChannelAddr, + handle: MailboxServerHandle, +} + +/// A simulated allocation. It is a collection of procs that are running in the local process. +pub struct SimAlloc { + inner: LocalAlloc, +} + +impl SimAlloc { + fn new(spec: AllocSpec) -> Self { + Self { + inner: LocalAlloc::new_with_transport( + spec, + ChannelTransport::Sim(Box::new(ChannelTransport::Unix)), + ), + } + } + /// A chaos monkey that can be used to stop procs at random. + pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static { + self.inner.chaos_monkey() + } + + /// A function to shut down the alloc for testing purposes. + pub(crate) fn stopper(&self) -> impl Fn() + 'static { + self.inner.stopper() + } + + pub(crate) fn name(&self) -> &ShortUuid { + self.inner.name() + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +#[async_trait] +impl Alloc for SimAlloc { + async fn next(&mut self) -> Option { + self.inner.next().await + } + + fn shape(&self) -> &Shape { + self.inner.shape() + } + + fn world_id(&self) -> &WorldId { + self.inner.world_id() + } + + fn transport(&self) -> ChannelTransport { + self.inner.transport() + } + + async fn stop(&mut self) -> Result<(), AllocatorError> { + self.inner.stop().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_allocator_basic() { + hyperactor::simnet::start(); + crate::alloc::testing::test_allocator_basic(SimAllocator).await; + } +} diff --git a/monarch_hyperactor/src/alloc.rs b/monarch_hyperactor/src/alloc.rs index b95b229c..558e20bf 100644 --- a/monarch_hyperactor/src/alloc.rs +++ b/monarch_hyperactor/src/alloc.rs @@ -28,6 +28,7 @@ use hyperactor_mesh::alloc::ProcessAllocator; use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAlloc; use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocHost; use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocInitializer; +use hyperactor_mesh::alloc::sim::SimAllocator; use hyperactor_mesh::shape::Shape; use ndslice::Slice; use pyo3::exceptions::PyRuntimeError; @@ -222,6 +223,53 @@ impl PyLocalAllocator { } } +#[pyclass( + name = "SimAllocatorBase", + module = "monarch._rust_bindings.monarch_hyperactor.alloc", + subclass +)] +pub struct PySimAllocator; + +#[pymethods] +impl PySimAllocator { + #[new] + fn new() -> Self { + PySimAllocator {} + } + + fn allocate_nonblocking<'py>( + &self, + py: Python<'py>, + spec: &PyAllocSpec, + ) -> PyResult> { + // We could use Bound here, and acquire the GIL inside of `future_into_py`, but + // it is rather awkward with the current APIs, and we can anyway support Arc/Mutex + // pretty easily. + let spec = spec.inner.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + SimAllocator + .allocate(spec) + .await + .map(|inner| PyAlloc::new(Box::new(inner))) + .map_err(|e| PyRuntimeError::new_err(format!("{}", e))) + }) + } + + fn allocate_blocking<'py>(&self, py: Python<'py>, spec: &PyAllocSpec) -> PyResult { + // We could use Bound here, and acquire the GIL inside of + // `signal_safe_block_on`, but it is rather awkward with the current + // APIs, and we can anyway support Arc/Mutex pretty easily. + let spec = spec.inner.clone(); + signal_safe_block_on(py, async move { + SimAllocator + .allocate(spec) + .await + .map(|inner| PyAlloc::new(Box::new(inner))) + .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) + })? + } +} + #[pyclass( name = "ProcessAllocatorBase", module = "monarch._rust_bindings.monarch_hyperactor.alloc", @@ -474,6 +522,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) diff --git a/python/monarch/__init__.py b/python/monarch/__init__.py index 7a0b1330..92d1795b 100644 --- a/python/monarch/__init__.py +++ b/python/monarch/__init__.py @@ -113,6 +113,8 @@ "timer": ("monarch.timer", "timer"), "ProcessAllocator": ("monarch._src.actor.allocator", "ProcessAllocator"), "LocalAllocator": ("monarch._src.actor.allocator", "LocalAllocator"), + "SimAllocator": ("monarch._src_actor.allocator", "SimAllocator"), + "ActorFuture": ("monarch.future", "ActorFuture"), "builtins": ("monarch.builtins", "builtins"), } @@ -181,6 +183,8 @@ def __getattr__(name): "timer", "ProcessAllocator", "LocalAllocator", + "SimAllocator", + "ActorFuture", "builtins", ] assert sorted(__all__) == sorted(_public_api) diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi index 06b1e33d..a58cb667 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi @@ -100,6 +100,26 @@ class LocalAllocatorBase: """ ... +class SimAllocatorBase: + async def allocate_nonblocking(self, spec: AllocSpec) -> Alloc: + """ + Allocate a process according to the provided spec. + + Arguments: + - `spec`: The spec to allocate according to. + """ + ... + + def allocate_blocking(self, spec: AllocSpec) -> Alloc: + """ + Allocate a process according to the provided spec, blocking until an + alloc is returned. + + Arguments: + - `spec`: The spec to allocate according to. + """ + ... + class RemoteAllocatorBase: def __new__( cls, diff --git a/python/monarch/_src/actor/allocator.py b/python/monarch/_src/actor/allocator.py index 2df6dc98..e27c83fe 100644 --- a/python/monarch/_src/actor/allocator.py +++ b/python/monarch/_src/actor/allocator.py @@ -16,6 +16,7 @@ LocalAllocatorBase, ProcessAllocatorBase, RemoteAllocatorBase, + SimAllocatorBase, ) from monarch._src.actor.future import Future @@ -69,6 +70,28 @@ def allocate(self, spec: AllocSpec) -> Future[Alloc]: ) +@final +class SimAllocator(SimAllocatorBase): + """ + An allocator that allocates by spawning actors into the current process using simulated channels for transport + """ + + def allocate(self, spec: AllocSpec) -> Future[Alloc]: + """ + Allocate a process according to the provided spec. + + Arguments: + - `spec`: The spec to allocate according to. + + Returns: + - A future that will be fulfilled when the requested allocation is fulfilled. + """ + return Future( + lambda: self.allocate_nonblocking(spec), + lambda: self.allocate_blocking(spec), + ) + + class RemoteAllocInitializer(abc.ABC): """Subclass-able Python interface for `hyperactor_mesh::alloc::remoteprocess:RemoteProcessAllocInitializer`. diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 483ddf00..c14e2499 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -35,7 +35,7 @@ ) from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice from monarch._src.actor.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef -from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator +from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator, SimAllocator from monarch._src.actor.code_sync import RsyncMeshClient, WorkspaceLocation from monarch._src.actor.code_sync.auto_reload import AutoReloadActor @@ -312,6 +312,33 @@ def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[Pro ) +async def sim_proc_mesh_nonblocking( + *, gpus: Optional[int] = None, hosts: int = 1 +) -> ProcMesh: + if gpus is None: + gpus = _local_device_count() + spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) + allocator = SimAllocator() + alloc = await allocator.allocate(spec) + return await ProcMesh.from_alloc(alloc) + + +def sim_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh: + if gpus is None: + gpus = _local_device_count() + spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) + allocator = SimAllocator() + alloc = allocator.allocate(spec).get() + return ProcMesh.from_alloc(alloc).get() + + +def sim_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]: + return Future( + lambda: sim_proc_mesh_nonblocking(gpus=gpus, hosts=hosts), + lambda: sim_proc_mesh_blocking(gpus=gpus, hosts=hosts), + ) + + _BOOTSTRAP_MAIN = "monarch._src.actor.bootstrap_main" diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py index a2720198..78fbcd1d 100644 --- a/python/monarch/actor/__init__.py +++ b/python/monarch/actor/__init__.py @@ -23,7 +23,12 @@ ValueMesh, ) from monarch._src.actor.future import Future -from monarch._src.actor.proc_mesh import local_proc_mesh, proc_mesh, ProcMesh +from monarch._src.actor.proc_mesh import ( + local_proc_mesh, + proc_mesh, + ProcMesh, + sim_proc_mesh, +) __all__ = [ "Accumulator", @@ -41,5 +46,6 @@ "proc_mesh", "ProcMesh", "send", + "sim_proc_mesh", "ValueMesh", ]