From 219ae17cfffa73cebcd854684ab5087119c55e1f Mon Sep 17 00:00:00 2001 From: Thomas Wang Date: Thu, 10 Jul 2025 10:00:24 -0700 Subject: [PATCH] No more operational messages (#473) Summary: Pull Request resolved: https://github.com/pytorch-labs/monarch/pull/473 Operational messages were used in order to signal to the simulator that it should perform certain actions like growing or shrinking the mesh. This was needed since the python and rust were running in separate processes, and messages were needed to communicate between the two, but now everything is on the same process so we can do this in memory. Differential Revision: D77941643 --- hyperactor/src/simnet.rs | 166 ++---------------- monarch_extension/src/simulator_client.rs | 150 ++++++++-------- monarch_simulator/Cargo.toml | 7 +- monarch_simulator/src/bootstrap.rs | 66 ------- monarch_simulator/src/main.rs | 54 ------ monarch_simulator/src/simulator.rs | 10 +- .../monarch_extension/simulator_client.pyi | 13 +- python/monarch/sim_mesh.py | 17 +- 8 files changed, 99 insertions(+), 384 deletions(-) delete mode 100644 monarch_simulator/src/main.rs diff --git a/hyperactor/src/simnet.rs b/hyperactor/src/simnet.rs index 0a87a14f..9ae27d2d 100644 --- a/hyperactor/src/simnet.rs +++ b/hyperactor/src/simnet.rs @@ -33,12 +33,9 @@ use serde::Serialize; use serde::Serializer; use serde_with::serde_as; use tokio::sync::Mutex; -use tokio::sync::SetError; use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::UnboundedReceiver; use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::mpsc::error::SendError; use tokio::task::JoinError; use tokio::task::JoinHandle; use tokio::time::interval; @@ -49,7 +46,6 @@ use crate::ActorId; use crate::Mailbox; use crate::Named; use crate::OncePortRef; -use crate::WorldId; use crate::channel; use crate::channel::ChannelAddr; use crate::channel::Rx; @@ -313,18 +309,6 @@ pub enum SimNetError { #[error("proxy not available: {0}")] ProxyNotAvailable(String), - /// Unable to send message to the simulator. - #[error(transparent)] - OperationalMessageSendError(#[from] SendError), - - /// Setting the operational message sender which is already set. - #[error(transparent)] - OperationalMessageSenderSetError(#[from] SetError>), - - /// Missing OperationalMessageReceiver. - #[error("missing operational message receiver")] - MissingOperationalMessageReceiver, - /// Cannot deliver the message because destination address is missing. #[error("missing destination address")] MissingDestinationAddress, @@ -361,8 +345,6 @@ pub struct SimNetHandle { /// Handle to a running proxy server that forwards external messages /// into the simnet. proxy_handle: ProxyHandle, - /// A sender to forward simulator operational messages. - operational_message_tx: UnboundedSender, /// A receiver to receive simulator operational messages. /// The receiver can be moved out of the simnet handle. training_script_state_tx: tokio::sync::watch::Sender, @@ -489,82 +471,6 @@ impl SimNetHandle { pub(crate) type Topology = DashMap; -/// The message to spawn a simulated mesh. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct SpawnMesh { - /// The system address. - pub system_addr: ChannelAddr, - /// The controller actor ID. - pub controller_actor_id: ActorId, - /// The worker world. - pub worker_world: WorldId, -} - -impl SpawnMesh { - /// Creates a new SpawnMesh. - pub fn new( - system_addr: ChannelAddr, - controller_actor_id: ActorId, - worker_world: WorldId, - ) -> Self { - Self { - system_addr, - controller_actor_id, - worker_world, - } - } -} - -/// An OperationalMessage is a message to control the simulator to do tasks such as -/// spawning or killing actors. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)] -pub enum OperationalMessage { - /// Kill the world with given world_id. - KillWorld(String), - /// Spawn actors in a mesh. - SpawnMesh(SpawnMesh), - /// Update training script state. - SetTrainingScriptState(TrainingScriptState), -} - -/// Message Event that can be sent to the simulator. -#[derive(Debug)] -pub struct SimOperation { - /// Sender to send OperationalMessage to the simulator. - operational_message_tx: UnboundedSender, - operational_message: OperationalMessage, -} - -impl SimOperation { - /// Creates a new SimOperation. - pub fn new( - operational_message_tx: UnboundedSender, - operational_message: OperationalMessage, - ) -> Self { - Self { - operational_message_tx, - operational_message, - } - } -} - -#[async_trait] -impl Event for SimOperation { - async fn handle(&self) -> Result<(), SimNetError> { - self.operational_message_tx - .send(self.operational_message.clone())?; - Ok(()) - } - - fn duration_ms(&self) -> u64 { - 0 - } - - fn summary(&self) -> String { - format!("SimOperation: {:?}", self.operational_message) - } -} - /// A ProxyMessage is a message that SimNet proxy receives. /// The message may requests the SimNet to send the payload in the message field from /// src to dst if addr field exists. @@ -658,7 +564,6 @@ impl ProxyHandle { proxy_addr: ChannelAddr, event_tx: UnboundedSender<(Box, bool, Option)>, pending_event_count: Arc, - operational_message_tx: UnboundedSender, ) -> anyhow::Result { let (addr, mut rx) = channel::serve::(proxy_addr).await?; tracing::info!("SimNet serving external traffic on {}", &addr); @@ -672,26 +577,18 @@ impl ProxyHandle { #[allow(clippy::disallowed_methods)] if let Ok(Ok(msg)) = timeout(Duration::from_millis(100), rx.recv()).await { let proxy_message: ProxyMessage = msg.deserialized().unwrap(); - let event: Box = match proxy_message.dest_addr { - Some(dest_addr) => Box::new(MessageDeliveryEvent::new( + if let Some(dest_addr) = proxy_message.dest_addr { + let event = Box::new(MessageDeliveryEvent::new( proxy_message.sender_addr, dest_addr, proxy_message.data, - )), - None => { - let operational_message: OperationalMessage = - proxy_message.data.deserialized().unwrap(); - Box::new(SimOperation::new( - operational_message_tx.clone(), - operational_message, - )) + )); + if let Err(e) = event_tx.send((event, true, None)) { + tracing::error!("error sending message to simnet: {:?}", e); + } else { + pending_event_count + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); } - }; - - if let Err(e) = event_tx.send((event, true, None)) { - tracing::error!("error sending message to simnet: {:?}", e); - } else { - pending_event_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); } } if stop_signal.load(Ordering::SeqCst) { @@ -729,7 +626,7 @@ pub fn start( private_addr: ChannelAddr, proxy_addr: ChannelAddr, max_duration_ms: u64, -) -> anyhow::Result> { +) -> anyhow::Result<()> { // Construct a topology with one node: the default A. let address_book: DashSet = DashSet::new(); address_book.insert(private_addr.clone()); @@ -775,14 +672,11 @@ pub fn start( .await }) })); - let (operational_message_tx, operational_message_rx) = - mpsc::unbounded_channel::(); let proxy_handle = block_on(ProxyHandle::start( proxy_addr, event_tx.clone(), pending_event_count.clone(), - operational_message_tx.clone(), )) .map_err(|err| SimNetError::ProxyNotAvailable(err.to_string()))?; @@ -792,12 +686,11 @@ pub fn start( config, pending_event_count, proxy_handle, - operational_message_tx, training_script_state_tx, stop_signal, }); - Ok(operational_message_rx) + Ok(()) } impl SimNet { @@ -1488,45 +1381,6 @@ edges: assert_eq!(records.unwrap().first().unwrap(), &expected_record); } - #[cfg(target_os = "linux")] - #[tokio::test] - async fn test_simnet_receive_operational_message() { - use tokio::sync::oneshot; - - use crate::PortId; - use crate::channel::Tx; - - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let mut operational_message_rx = start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); - let tx = crate::channel::dial(proxy_addr.clone()).unwrap(); - let port_id = PortId(id!(test[0].actor0), 0); - let spawn_mesh = SpawnMesh { - system_addr: "unix!@system".parse().unwrap(), - controller_actor_id: id!(controller_world[0].actor), - worker_world: id!(worker_world), - }; - let operational_message = OperationalMessage::SpawnMesh(spawn_mesh.clone()); - let serialized_operational_message = Serialized::serialize(&operational_message).unwrap(); - let proxy_message = ProxyMessage::new(None, None, serialized_operational_message); - let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap(); - let external_message = MessageEnvelope::new_unknown(port_id, serialized_proxy_message); - - // Send the message to the simnet. - tx.try_post(external_message, oneshot::channel().0).unwrap(); - // flush doesn't work here because tx.send() delivers the message through real network. - // We have to wait for the message to enter simnet. - RealClock.sleep(Duration::from_millis(1000)).await; - let received_operational_message = operational_message_rx.recv().await.unwrap(); - - // Check the received message. - assert_eq!(received_operational_message, operational_message); - } - #[tokio::test] async fn test_sim_sleep() { start( diff --git a/monarch_extension/src/simulator_client.rs b/monarch_extension/src/simulator_client.rs index 85657d40..eb93c1c5 100644 --- a/monarch_extension/src/simulator_client.rs +++ b/monarch_extension/src/simulator_client.rs @@ -8,24 +8,21 @@ #![cfg(feature = "tensor_engine")] +use std::sync::Arc; + use anyhow::anyhow; -use hyperactor::PortId; -use hyperactor::attrs::Attrs; +use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::Tx; -use hyperactor::channel::dial; -use hyperactor::data::Serialized; -use hyperactor::id; -use hyperactor::mailbox::MessageEnvelope; -use hyperactor::simnet::OperationalMessage; -use hyperactor::simnet::ProxyMessage; -use hyperactor::simnet::SpawnMesh; +use hyperactor::channel::ChannelTransport; +use hyperactor::simnet; use hyperactor::simnet::TrainingScriptState; +use hyperactor::simnet::simnet_handle; use monarch_hyperactor::runtime::signal_safe_block_on; -use monarch_simulator_lib::bootstrap::bootstrap; +use monarch_simulator_lib::simulator::TensorEngineSimulator; use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use tokio::sync::Mutex; /// A wrapper around [ndslice::Slice] to expose it to python. /// It is a compact representation of indices into the flat @@ -39,102 +36,99 @@ use pyo3::prelude::*; )] #[derive(Clone)] pub(crate) struct SimulatorClient { - proxy_addr: ChannelAddr, -} - -fn wrap_operational_message(operational_message: OperationalMessage) -> MessageEnvelope { - let serialized_operational_message = Serialized::serialize(&operational_message).unwrap(); - let proxy_message = ProxyMessage::new(None, None, serialized_operational_message); - let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap(); - let sender_id = id!(simulator_client[0].sender_actor); - // a dummy port ID. We are delivering message with low level mailbox. - // The port ID is not used. - let port_id = PortId(id!(simulator[0].actor), 0); - MessageEnvelope::new(sender_id, port_id, serialized_proxy_message, Attrs::new()) -} - -#[pyfunction] -fn bootstrap_simulator_backend( - py: Python, - system_addr: String, - proxy_addr: String, - world_size: i32, -) -> PyResult<()> { - signal_safe_block_on(py, async move { - match bootstrap( - system_addr.parse().unwrap(), - proxy_addr.parse().unwrap(), - world_size as usize, - ) - .await - { - Ok(_) => Ok(()), - Err(err) => Err(PyRuntimeError::new_err(err.to_string())), - } - })? + inner: Arc>, + world_size: usize, } -fn set_training_script_state(state: TrainingScriptState, proxy_addr: ChannelAddr) -> PyResult<()> { - let operational_message = OperationalMessage::SetTrainingScriptState(state); - let external_message = wrap_operational_message(operational_message); - let tx = dial(proxy_addr).map_err(|err| anyhow!(err))?; - tx.post(external_message); +fn set_training_script_state(state: TrainingScriptState) -> PyResult<()> { + simnet_handle() + .map_err(|e| anyhow!(e))? + .set_training_script_state(state); Ok(()) } #[pymethods] impl SimulatorClient { #[new] - fn new(proxy_addr: &str) -> PyResult { - Ok(Self { - proxy_addr: proxy_addr + fn new(py: Python, system_addr: String, world_size: i32) -> PyResult { + signal_safe_block_on(py, async move { + let system_addr = system_addr .parse::() - .map_err(|err| PyValueError::new_err(err.to_string()))?, - }) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let ChannelAddr::Sim(system_sim_addr) = &system_addr else { + return Err(PyValueError::new_err(format!( + "bootstrap address should be a sim address: {}", + system_addr + ))); + }; + + simnet::start( + ChannelAddr::any(ChannelTransport::Unix), + system_sim_addr.proxy().clone(), + 1000, + ) + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + + Ok(Self { + inner: Arc::new(Mutex::new( + TensorEngineSimulator::new(system_addr) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, + )), + world_size: world_size as usize, + }) + })? } - fn kill_world(&self, world_name: &str) -> PyResult<()> { - let operational_message = OperationalMessage::KillWorld(world_name.to_string()); - let external_message = wrap_operational_message(operational_message); - let tx = dial(self.proxy_addr.clone()).map_err(|err| anyhow!(err))?; - tx.post(external_message); - Ok(()) + fn kill_world(&self, py: Python, world_name: &str) -> PyResult<()> { + let simulator = self.inner.clone(); + let world_name = world_name.to_string(); + + signal_safe_block_on(py, async move { + simulator + .lock() + .await + .kill_world(&world_name) + .map_err(|err| anyhow!(err))?; + Ok(()) + })? } fn spawn_mesh( &self, + py: Python, system_addr: &str, controller_actor_id: &str, worker_world: &str, ) -> PyResult<()> { - let spawn_mesh = SpawnMesh::new( - system_addr.parse().unwrap(), - controller_actor_id.parse().unwrap(), - worker_world.parse().unwrap(), - ); - let operational_message = OperationalMessage::SpawnMesh(spawn_mesh); - let external_message = wrap_operational_message(operational_message); - let tx = dial(self.proxy_addr.clone()).map_err(|err| anyhow!(err))?; - tx.post(external_message); - Ok(()) + let simulator = self.inner.clone(); + let world_size = self.world_size; + let system_addr = system_addr.parse::().unwrap(); + let worker_world = worker_world.parse::().unwrap(); + let controller_actor_id = controller_actor_id.parse().unwrap(); + + signal_safe_block_on(py, async move { + simulator + .lock() + .await + .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) + .await + .map_err(|err| anyhow!(err))?; + Ok(()) + })? } fn set_training_script_state_running(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Running, self.proxy_addr.clone()) + set_training_script_state(TrainingScriptState::Running) } fn set_training_script_state_waiting(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Waiting, self.proxy_addr.clone()) + set_training_script_state(TrainingScriptState::Waiting) } } pub(crate) fn register_python_bindings(simulator_client_mod: &Bound<'_, PyModule>) -> PyResult<()> { simulator_client_mod.add_class::()?; - let f = wrap_pyfunction!(bootstrap_simulator_backend, simulator_client_mod)?; - f.setattr( - "__module__", - "monarch._rust_bindings.monarch_extension.simulator_client", - )?; - simulator_client_mod.add_function(f)?; Ok(()) } diff --git a/monarch_simulator/Cargo.toml b/monarch_simulator/Cargo.toml index 863b5384..e28372ed 100644 --- a/monarch_simulator/Cargo.toml +++ b/monarch_simulator/Cargo.toml @@ -1,4 +1,4 @@ -# @generated by autocargo from //monarch/monarch_simulator:[monarch_simulator,monarch_simulator_lib] +# @generated by autocargo from //monarch/monarch_simulator:monarch_simulator_lib [package] name = "monarch_simulator_lib" @@ -7,14 +7,9 @@ authors = ["Meta"] edition = "2021" license = "BSD-3-Clause" -[[bin]] -name = "monarch_simulator" -path = "src/main.rs" - [dependencies] anyhow = "1.0.98" async-trait = "0.1.86" -clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] } controller = { version = "0.0.0", path = "../controller" } dashmap = { version = "5.5.3", features = ["rayon", "serde"] } futures = { version = "0.3.30", features = ["async-await", "compat"] } diff --git a/monarch_simulator/src/bootstrap.rs b/monarch_simulator/src/bootstrap.rs index e946fd63..6ba3000d 100644 --- a/monarch_simulator/src/bootstrap.rs +++ b/monarch_simulator/src/bootstrap.rs @@ -20,23 +20,15 @@ use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; -use hyperactor::simnet; -use hyperactor::simnet::OperationalMessage; -use hyperactor::simnet::SpawnMesh; -use hyperactor::simnet::simnet_handle; use hyperactor_multiprocess::System; use hyperactor_multiprocess::proc_actor::ProcActor; use hyperactor_multiprocess::proc_actor::spawn; use hyperactor_multiprocess::system::ServerHandle; use hyperactor_multiprocess::system_actor::ProcLifecycleMode; use monarch_messages::worker::Factory; -use tokio::sync::Mutex; -use tokio::sync::mpsc::UnboundedReceiver; -use tokio::task::JoinHandle; use torch_sys::Layout; use torch_sys::ScalarType; -use crate::simulator::Simulator; use crate::worker::Fabric; use crate::worker::MockWorkerParams; use crate::worker::WorkerActor; @@ -184,61 +176,3 @@ pub async fn spawn_sim_worker( .await?; Ok(bootstrap.proc_actor) } - -/// Bootstrap the simulation. Spawns the system, controllers, and workers. -/// Args: -/// system_addr: The address of the system actor. -pub async fn bootstrap( - system_addr: ChannelAddr, - proxy_addr: ChannelAddr, - world_size: usize, -) -> Result> { - // TODO: enable supervision events. - let mut operational_message_rx = simnet::start(system_addr.clone(), proxy_addr, 1000)?; - let simulator = Arc::new(Mutex::new(Simulator::new(system_addr).await?)); - let operational_listener_handle = { - let simulator = simulator.clone(); - tokio::spawn(async move { - handle_operational_message(&mut operational_message_rx, simulator, world_size).await - }) - }; - - Ok(operational_listener_handle) -} - -async fn handle_operational_message( - operational_message_rx: &mut UnboundedReceiver, - simulator: Arc>, - world_size: usize, -) { - while let Some(msg) = operational_message_rx.recv().await { - tracing::info!("received operational message: {:?}", msg); - match msg { - OperationalMessage::SpawnMesh(SpawnMesh { - system_addr, - controller_actor_id, - worker_world, - }) => { - if let Err(e) = simulator - .lock() - .await - .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) - .await - { - tracing::error!("failed to spawn mesh: {:?}", e); - } - } - OperationalMessage::KillWorld(world_id) => { - if let Err(e) = simulator.lock().await.kill_world(&world_id) { - tracing::error!("failed to kill world: {:?}", e); - } - } - OperationalMessage::SetTrainingScriptState(state) => match simnet_handle() { - Ok(handle) => handle.set_training_script_state(state), - Err(e) => { - tracing::error!("failed to set training script state: {:?}", e); - } - }, - } - } -} diff --git a/monarch_simulator/src/main.rs b/monarch_simulator/src/main.rs deleted file mode 100644 index 12169d91..00000000 --- a/monarch_simulator/src/main.rs +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -//! A binary to launch the simulated Monarch controller along with necessary environment. -use std::process::ExitCode; - -use anyhow::Result; -use clap::Parser; -use hyperactor::channel::ChannelAddr; -use monarch_simulator_lib::bootstrap::bootstrap; - -#[derive(Debug, Parser)] -struct Args { - #[arg(short, long)] - system_addr: ChannelAddr, - #[arg(short, long)] - proxy_addr: ChannelAddr, -} - -const TITLE: &str = r#" -****************************************************** -* * -* ____ ___ __ __ _ _ _ _ _____ ___ ____ * -*/ ___|_ _| \/ | | | | | / \|_ _/ _ \| _ \ * -*\___ \| || |\/| | | | | | / _ \ | || | | | |_) |* -* ___) | || | | | |_| | |___ / ___ \| || |_| | _ < * -*|____/___|_| |_|\___/|_____/_/ \_\_| \___/|_| \_\* -* * -****************************************************** -"#; - -#[tokio::main] -async fn main() -> Result { - eprintln!("{}", TITLE); - hyperactor::initialize_with_current_runtime(); - let args = Args::parse(); - - let system_addr = args.system_addr.clone(); - let proxy_addr = args.proxy_addr.clone(); - tracing::info!("starting Monarch simulation"); - - let operational_listener_handle = bootstrap(system_addr, proxy_addr, 1).await?; - - operational_listener_handle - .await - .expect("simulator exited unexpectedly"); - - Ok(ExitCode::SUCCESS) -} diff --git a/monarch_simulator/src/simulator.rs b/monarch_simulator/src/simulator.rs index 10fa9fd7..97c02b77 100644 --- a/monarch_simulator/src/simulator.rs +++ b/monarch_simulator/src/simulator.rs @@ -26,13 +26,13 @@ use crate::bootstrap::spawn_system; /// The simulator manages all of the meshes and the system handle. #[derive(Debug)] -pub struct Simulator { +pub struct TensorEngineSimulator { /// A map from world name to actor handles in that world. worlds: HashMap>>, system_handle: ServerHandle, } -impl Simulator { +impl TensorEngineSimulator { pub async fn new(system_addr: ChannelAddr) -> Result { Ok(Self { worlds: HashMap::new(), @@ -92,7 +92,7 @@ impl Simulator { /// IntoFuture allows users to await the handle. The future resolves when /// the simulator itself has all of the actor handles stopped. -impl IntoFuture for Simulator { +impl IntoFuture for TensorEngineSimulator { type Output = (); type IntoFuture = BoxFuture<'static, Self::Output>; @@ -145,7 +145,9 @@ mod tests { let system_addr = format!("sim!unix!@system,{}", &proxy) .parse::() .unwrap(); - let mut simulator = super::Simulator::new(system_addr.clone()).await.unwrap(); + let mut simulator = super::TensorEngineSimulator::new(system_addr.clone()) + .await + .unwrap(); let mut controller_actor_ids = vec![]; let mut worker_actor_ids = vec![]; let n_meshes = 2; diff --git a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi index 9cea9d93..6b06c274 100644 --- a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi @@ -13,10 +13,11 @@ class SimulatorClient: It is a client to communicate with the simulator service. Arguments: - - `proxy_addr`: Address of the simulator's proxy server. + - `system_addr`: Address of the system. + - `world_size`: Number of workers in a given mesh. """ - def __init__(self, proxy_addr: str) -> None: ... + def __init__(self, system_addr: str, world_size: int) -> None: ... def kill_world(self, world_name: str) -> None: """ Kill the world with the given name. @@ -51,11 +52,3 @@ class SimulatorClient: backend to resolve a future """ ... - -def bootstrap_simulator_backend( - system_addr: str, proxy_addr: str, world_size: int -) -> None: - """ - Bootstrap the simulator backend on the current process - """ - ... diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index b91361bf..8c304efd 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -31,7 +31,6 @@ ) from monarch._rust_bindings.monarch_extension.simulator_client import ( # @manual=//monarch/monarch_extension:monarch_extension - bootstrap_simulator_backend, SimulatorClient, ) @@ -76,7 +75,6 @@ def sim_mesh( bootstrap: Bootstrap = Bootstrap( n_meshes, mesh_world_state, - proxy_addr=proxy_addr, world_size=hosts * gpus_per_host, ) @@ -181,7 +179,6 @@ def __init__( self, num_meshes: int, mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]], - proxy_addr: Optional[str] = None, world_size: int = 1, ) -> None: """ @@ -199,17 +196,15 @@ def __init__( self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state - proxy_addr = proxy_addr or f"unix!@{_random_id()}-proxy" + proxy_addr = f"unix!@{_random_id()}-proxy" self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}" - client_proxy_addr = f"unix!@{_random_id()}-proxy" - self.client_listen_addr: str = f"sim!unix!@client,{client_proxy_addr}" - self.client_bootstrap_addr: str = ( + self.client_listen_addr = f"sim!unix!@client,{client_proxy_addr}" + self.client_bootstrap_addr = ( f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}" ) - bootstrap_simulator_backend(self.bootstrap_addr, proxy_addr, world_size) - self._simulator_client = SimulatorClient(proxy_addr) + self._simulator_client = SimulatorClient(self.bootstrap_addr, world_size) for i in range(num_meshes): mesh_name: str = f"mesh_{i}" controller_world: str = f"{mesh_name}_controller" @@ -235,7 +230,9 @@ def spawn_mesh(self, mesh_world: MeshWorld) -> None: worker_world, controller_id = mesh_world controller_world = controller_id.world_name self._simulator_client.spawn_mesh( - self.bootstrap_addr, f"{controller_world}[0].root", worker_world + self.bootstrap_addr, + f"{controller_world}[0].root", + worker_world, )