diff --git a/hyperactor/src/simnet.rs b/hyperactor/src/simnet.rs index 0a87a14f..9ae27d2d 100644 --- a/hyperactor/src/simnet.rs +++ b/hyperactor/src/simnet.rs @@ -33,12 +33,9 @@ use serde::Serialize; use serde::Serializer; use serde_with::serde_as; use tokio::sync::Mutex; -use tokio::sync::SetError; use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::UnboundedReceiver; use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::mpsc::error::SendError; use tokio::task::JoinError; use tokio::task::JoinHandle; use tokio::time::interval; @@ -49,7 +46,6 @@ use crate::ActorId; use crate::Mailbox; use crate::Named; use crate::OncePortRef; -use crate::WorldId; use crate::channel; use crate::channel::ChannelAddr; use crate::channel::Rx; @@ -313,18 +309,6 @@ pub enum SimNetError { #[error("proxy not available: {0}")] ProxyNotAvailable(String), - /// Unable to send message to the simulator. - #[error(transparent)] - OperationalMessageSendError(#[from] SendError), - - /// Setting the operational message sender which is already set. - #[error(transparent)] - OperationalMessageSenderSetError(#[from] SetError>), - - /// Missing OperationalMessageReceiver. - #[error("missing operational message receiver")] - MissingOperationalMessageReceiver, - /// Cannot deliver the message because destination address is missing. #[error("missing destination address")] MissingDestinationAddress, @@ -361,8 +345,6 @@ pub struct SimNetHandle { /// Handle to a running proxy server that forwards external messages /// into the simnet. proxy_handle: ProxyHandle, - /// A sender to forward simulator operational messages. - operational_message_tx: UnboundedSender, /// A receiver to receive simulator operational messages. /// The receiver can be moved out of the simnet handle. training_script_state_tx: tokio::sync::watch::Sender, @@ -489,82 +471,6 @@ impl SimNetHandle { pub(crate) type Topology = DashMap; -/// The message to spawn a simulated mesh. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct SpawnMesh { - /// The system address. - pub system_addr: ChannelAddr, - /// The controller actor ID. - pub controller_actor_id: ActorId, - /// The worker world. - pub worker_world: WorldId, -} - -impl SpawnMesh { - /// Creates a new SpawnMesh. - pub fn new( - system_addr: ChannelAddr, - controller_actor_id: ActorId, - worker_world: WorldId, - ) -> Self { - Self { - system_addr, - controller_actor_id, - worker_world, - } - } -} - -/// An OperationalMessage is a message to control the simulator to do tasks such as -/// spawning or killing actors. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)] -pub enum OperationalMessage { - /// Kill the world with given world_id. - KillWorld(String), - /// Spawn actors in a mesh. - SpawnMesh(SpawnMesh), - /// Update training script state. - SetTrainingScriptState(TrainingScriptState), -} - -/// Message Event that can be sent to the simulator. -#[derive(Debug)] -pub struct SimOperation { - /// Sender to send OperationalMessage to the simulator. - operational_message_tx: UnboundedSender, - operational_message: OperationalMessage, -} - -impl SimOperation { - /// Creates a new SimOperation. - pub fn new( - operational_message_tx: UnboundedSender, - operational_message: OperationalMessage, - ) -> Self { - Self { - operational_message_tx, - operational_message, - } - } -} - -#[async_trait] -impl Event for SimOperation { - async fn handle(&self) -> Result<(), SimNetError> { - self.operational_message_tx - .send(self.operational_message.clone())?; - Ok(()) - } - - fn duration_ms(&self) -> u64 { - 0 - } - - fn summary(&self) -> String { - format!("SimOperation: {:?}", self.operational_message) - } -} - /// A ProxyMessage is a message that SimNet proxy receives. /// The message may requests the SimNet to send the payload in the message field from /// src to dst if addr field exists. @@ -658,7 +564,6 @@ impl ProxyHandle { proxy_addr: ChannelAddr, event_tx: UnboundedSender<(Box, bool, Option)>, pending_event_count: Arc, - operational_message_tx: UnboundedSender, ) -> anyhow::Result { let (addr, mut rx) = channel::serve::(proxy_addr).await?; tracing::info!("SimNet serving external traffic on {}", &addr); @@ -672,26 +577,18 @@ impl ProxyHandle { #[allow(clippy::disallowed_methods)] if let Ok(Ok(msg)) = timeout(Duration::from_millis(100), rx.recv()).await { let proxy_message: ProxyMessage = msg.deserialized().unwrap(); - let event: Box = match proxy_message.dest_addr { - Some(dest_addr) => Box::new(MessageDeliveryEvent::new( + if let Some(dest_addr) = proxy_message.dest_addr { + let event = Box::new(MessageDeliveryEvent::new( proxy_message.sender_addr, dest_addr, proxy_message.data, - )), - None => { - let operational_message: OperationalMessage = - proxy_message.data.deserialized().unwrap(); - Box::new(SimOperation::new( - operational_message_tx.clone(), - operational_message, - )) + )); + if let Err(e) = event_tx.send((event, true, None)) { + tracing::error!("error sending message to simnet: {:?}", e); + } else { + pending_event_count + .fetch_add(1, std::sync::atomic::Ordering::SeqCst); } - }; - - if let Err(e) = event_tx.send((event, true, None)) { - tracing::error!("error sending message to simnet: {:?}", e); - } else { - pending_event_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); } } if stop_signal.load(Ordering::SeqCst) { @@ -729,7 +626,7 @@ pub fn start( private_addr: ChannelAddr, proxy_addr: ChannelAddr, max_duration_ms: u64, -) -> anyhow::Result> { +) -> anyhow::Result<()> { // Construct a topology with one node: the default A. let address_book: DashSet = DashSet::new(); address_book.insert(private_addr.clone()); @@ -775,14 +672,11 @@ pub fn start( .await }) })); - let (operational_message_tx, operational_message_rx) = - mpsc::unbounded_channel::(); let proxy_handle = block_on(ProxyHandle::start( proxy_addr, event_tx.clone(), pending_event_count.clone(), - operational_message_tx.clone(), )) .map_err(|err| SimNetError::ProxyNotAvailable(err.to_string()))?; @@ -792,12 +686,11 @@ pub fn start( config, pending_event_count, proxy_handle, - operational_message_tx, training_script_state_tx, stop_signal, }); - Ok(operational_message_rx) + Ok(()) } impl SimNet { @@ -1488,45 +1381,6 @@ edges: assert_eq!(records.unwrap().first().unwrap(), &expected_record); } - #[cfg(target_os = "linux")] - #[tokio::test] - async fn test_simnet_receive_operational_message() { - use tokio::sync::oneshot; - - use crate::PortId; - use crate::channel::Tx; - - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let mut operational_message_rx = start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); - let tx = crate::channel::dial(proxy_addr.clone()).unwrap(); - let port_id = PortId(id!(test[0].actor0), 0); - let spawn_mesh = SpawnMesh { - system_addr: "unix!@system".parse().unwrap(), - controller_actor_id: id!(controller_world[0].actor), - worker_world: id!(worker_world), - }; - let operational_message = OperationalMessage::SpawnMesh(spawn_mesh.clone()); - let serialized_operational_message = Serialized::serialize(&operational_message).unwrap(); - let proxy_message = ProxyMessage::new(None, None, serialized_operational_message); - let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap(); - let external_message = MessageEnvelope::new_unknown(port_id, serialized_proxy_message); - - // Send the message to the simnet. - tx.try_post(external_message, oneshot::channel().0).unwrap(); - // flush doesn't work here because tx.send() delivers the message through real network. - // We have to wait for the message to enter simnet. - RealClock.sleep(Duration::from_millis(1000)).await; - let received_operational_message = operational_message_rx.recv().await.unwrap(); - - // Check the received message. - assert_eq!(received_operational_message, operational_message); - } - #[tokio::test] async fn test_sim_sleep() { start( diff --git a/monarch_extension/src/simulator_client.rs b/monarch_extension/src/simulator_client.rs index 85657d40..eb93c1c5 100644 --- a/monarch_extension/src/simulator_client.rs +++ b/monarch_extension/src/simulator_client.rs @@ -8,24 +8,21 @@ #![cfg(feature = "tensor_engine")] +use std::sync::Arc; + use anyhow::anyhow; -use hyperactor::PortId; -use hyperactor::attrs::Attrs; +use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::Tx; -use hyperactor::channel::dial; -use hyperactor::data::Serialized; -use hyperactor::id; -use hyperactor::mailbox::MessageEnvelope; -use hyperactor::simnet::OperationalMessage; -use hyperactor::simnet::ProxyMessage; -use hyperactor::simnet::SpawnMesh; +use hyperactor::channel::ChannelTransport; +use hyperactor::simnet; use hyperactor::simnet::TrainingScriptState; +use hyperactor::simnet::simnet_handle; use monarch_hyperactor::runtime::signal_safe_block_on; -use monarch_simulator_lib::bootstrap::bootstrap; +use monarch_simulator_lib::simulator::TensorEngineSimulator; use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use tokio::sync::Mutex; /// A wrapper around [ndslice::Slice] to expose it to python. /// It is a compact representation of indices into the flat @@ -39,102 +36,99 @@ use pyo3::prelude::*; )] #[derive(Clone)] pub(crate) struct SimulatorClient { - proxy_addr: ChannelAddr, -} - -fn wrap_operational_message(operational_message: OperationalMessage) -> MessageEnvelope { - let serialized_operational_message = Serialized::serialize(&operational_message).unwrap(); - let proxy_message = ProxyMessage::new(None, None, serialized_operational_message); - let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap(); - let sender_id = id!(simulator_client[0].sender_actor); - // a dummy port ID. We are delivering message with low level mailbox. - // The port ID is not used. - let port_id = PortId(id!(simulator[0].actor), 0); - MessageEnvelope::new(sender_id, port_id, serialized_proxy_message, Attrs::new()) -} - -#[pyfunction] -fn bootstrap_simulator_backend( - py: Python, - system_addr: String, - proxy_addr: String, - world_size: i32, -) -> PyResult<()> { - signal_safe_block_on(py, async move { - match bootstrap( - system_addr.parse().unwrap(), - proxy_addr.parse().unwrap(), - world_size as usize, - ) - .await - { - Ok(_) => Ok(()), - Err(err) => Err(PyRuntimeError::new_err(err.to_string())), - } - })? + inner: Arc>, + world_size: usize, } -fn set_training_script_state(state: TrainingScriptState, proxy_addr: ChannelAddr) -> PyResult<()> { - let operational_message = OperationalMessage::SetTrainingScriptState(state); - let external_message = wrap_operational_message(operational_message); - let tx = dial(proxy_addr).map_err(|err| anyhow!(err))?; - tx.post(external_message); +fn set_training_script_state(state: TrainingScriptState) -> PyResult<()> { + simnet_handle() + .map_err(|e| anyhow!(e))? + .set_training_script_state(state); Ok(()) } #[pymethods] impl SimulatorClient { #[new] - fn new(proxy_addr: &str) -> PyResult { - Ok(Self { - proxy_addr: proxy_addr + fn new(py: Python, system_addr: String, world_size: i32) -> PyResult { + signal_safe_block_on(py, async move { + let system_addr = system_addr .parse::() - .map_err(|err| PyValueError::new_err(err.to_string()))?, - }) + .map_err(|err| PyValueError::new_err(err.to_string()))?; + + let ChannelAddr::Sim(system_sim_addr) = &system_addr else { + return Err(PyValueError::new_err(format!( + "bootstrap address should be a sim address: {}", + system_addr + ))); + }; + + simnet::start( + ChannelAddr::any(ChannelTransport::Unix), + system_sim_addr.proxy().clone(), + 1000, + ) + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + + Ok(Self { + inner: Arc::new(Mutex::new( + TensorEngineSimulator::new(system_addr) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, + )), + world_size: world_size as usize, + }) + })? } - fn kill_world(&self, world_name: &str) -> PyResult<()> { - let operational_message = OperationalMessage::KillWorld(world_name.to_string()); - let external_message = wrap_operational_message(operational_message); - let tx = dial(self.proxy_addr.clone()).map_err(|err| anyhow!(err))?; - tx.post(external_message); - Ok(()) + fn kill_world(&self, py: Python, world_name: &str) -> PyResult<()> { + let simulator = self.inner.clone(); + let world_name = world_name.to_string(); + + signal_safe_block_on(py, async move { + simulator + .lock() + .await + .kill_world(&world_name) + .map_err(|err| anyhow!(err))?; + Ok(()) + })? } fn spawn_mesh( &self, + py: Python, system_addr: &str, controller_actor_id: &str, worker_world: &str, ) -> PyResult<()> { - let spawn_mesh = SpawnMesh::new( - system_addr.parse().unwrap(), - controller_actor_id.parse().unwrap(), - worker_world.parse().unwrap(), - ); - let operational_message = OperationalMessage::SpawnMesh(spawn_mesh); - let external_message = wrap_operational_message(operational_message); - let tx = dial(self.proxy_addr.clone()).map_err(|err| anyhow!(err))?; - tx.post(external_message); - Ok(()) + let simulator = self.inner.clone(); + let world_size = self.world_size; + let system_addr = system_addr.parse::().unwrap(); + let worker_world = worker_world.parse::().unwrap(); + let controller_actor_id = controller_actor_id.parse().unwrap(); + + signal_safe_block_on(py, async move { + simulator + .lock() + .await + .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) + .await + .map_err(|err| anyhow!(err))?; + Ok(()) + })? } fn set_training_script_state_running(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Running, self.proxy_addr.clone()) + set_training_script_state(TrainingScriptState::Running) } fn set_training_script_state_waiting(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Waiting, self.proxy_addr.clone()) + set_training_script_state(TrainingScriptState::Waiting) } } pub(crate) fn register_python_bindings(simulator_client_mod: &Bound<'_, PyModule>) -> PyResult<()> { simulator_client_mod.add_class::()?; - let f = wrap_pyfunction!(bootstrap_simulator_backend, simulator_client_mod)?; - f.setattr( - "__module__", - "monarch._rust_bindings.monarch_extension.simulator_client", - )?; - simulator_client_mod.add_function(f)?; Ok(()) } diff --git a/monarch_simulator/Cargo.toml b/monarch_simulator/Cargo.toml index 863b5384..e28372ed 100644 --- a/monarch_simulator/Cargo.toml +++ b/monarch_simulator/Cargo.toml @@ -1,4 +1,4 @@ -# @generated by autocargo from //monarch/monarch_simulator:[monarch_simulator,monarch_simulator_lib] +# @generated by autocargo from //monarch/monarch_simulator:monarch_simulator_lib [package] name = "monarch_simulator_lib" @@ -7,14 +7,9 @@ authors = ["Meta"] edition = "2021" license = "BSD-3-Clause" -[[bin]] -name = "monarch_simulator" -path = "src/main.rs" - [dependencies] anyhow = "1.0.98" async-trait = "0.1.86" -clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] } controller = { version = "0.0.0", path = "../controller" } dashmap = { version = "5.5.3", features = ["rayon", "serde"] } futures = { version = "0.3.30", features = ["async-await", "compat"] } diff --git a/monarch_simulator/src/bootstrap.rs b/monarch_simulator/src/bootstrap.rs index e946fd63..6ba3000d 100644 --- a/monarch_simulator/src/bootstrap.rs +++ b/monarch_simulator/src/bootstrap.rs @@ -20,23 +20,15 @@ use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; -use hyperactor::simnet; -use hyperactor::simnet::OperationalMessage; -use hyperactor::simnet::SpawnMesh; -use hyperactor::simnet::simnet_handle; use hyperactor_multiprocess::System; use hyperactor_multiprocess::proc_actor::ProcActor; use hyperactor_multiprocess::proc_actor::spawn; use hyperactor_multiprocess::system::ServerHandle; use hyperactor_multiprocess::system_actor::ProcLifecycleMode; use monarch_messages::worker::Factory; -use tokio::sync::Mutex; -use tokio::sync::mpsc::UnboundedReceiver; -use tokio::task::JoinHandle; use torch_sys::Layout; use torch_sys::ScalarType; -use crate::simulator::Simulator; use crate::worker::Fabric; use crate::worker::MockWorkerParams; use crate::worker::WorkerActor; @@ -184,61 +176,3 @@ pub async fn spawn_sim_worker( .await?; Ok(bootstrap.proc_actor) } - -/// Bootstrap the simulation. Spawns the system, controllers, and workers. -/// Args: -/// system_addr: The address of the system actor. -pub async fn bootstrap( - system_addr: ChannelAddr, - proxy_addr: ChannelAddr, - world_size: usize, -) -> Result> { - // TODO: enable supervision events. - let mut operational_message_rx = simnet::start(system_addr.clone(), proxy_addr, 1000)?; - let simulator = Arc::new(Mutex::new(Simulator::new(system_addr).await?)); - let operational_listener_handle = { - let simulator = simulator.clone(); - tokio::spawn(async move { - handle_operational_message(&mut operational_message_rx, simulator, world_size).await - }) - }; - - Ok(operational_listener_handle) -} - -async fn handle_operational_message( - operational_message_rx: &mut UnboundedReceiver, - simulator: Arc>, - world_size: usize, -) { - while let Some(msg) = operational_message_rx.recv().await { - tracing::info!("received operational message: {:?}", msg); - match msg { - OperationalMessage::SpawnMesh(SpawnMesh { - system_addr, - controller_actor_id, - worker_world, - }) => { - if let Err(e) = simulator - .lock() - .await - .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) - .await - { - tracing::error!("failed to spawn mesh: {:?}", e); - } - } - OperationalMessage::KillWorld(world_id) => { - if let Err(e) = simulator.lock().await.kill_world(&world_id) { - tracing::error!("failed to kill world: {:?}", e); - } - } - OperationalMessage::SetTrainingScriptState(state) => match simnet_handle() { - Ok(handle) => handle.set_training_script_state(state), - Err(e) => { - tracing::error!("failed to set training script state: {:?}", e); - } - }, - } - } -} diff --git a/monarch_simulator/src/main.rs b/monarch_simulator/src/main.rs deleted file mode 100644 index 12169d91..00000000 --- a/monarch_simulator/src/main.rs +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -//! A binary to launch the simulated Monarch controller along with necessary environment. -use std::process::ExitCode; - -use anyhow::Result; -use clap::Parser; -use hyperactor::channel::ChannelAddr; -use monarch_simulator_lib::bootstrap::bootstrap; - -#[derive(Debug, Parser)] -struct Args { - #[arg(short, long)] - system_addr: ChannelAddr, - #[arg(short, long)] - proxy_addr: ChannelAddr, -} - -const TITLE: &str = r#" -****************************************************** -* * -* ____ ___ __ __ _ _ _ _ _____ ___ ____ * -*/ ___|_ _| \/ | | | | | / \|_ _/ _ \| _ \ * -*\___ \| || |\/| | | | | | / _ \ | || | | | |_) |* -* ___) | || | | | |_| | |___ / ___ \| || |_| | _ < * -*|____/___|_| |_|\___/|_____/_/ \_\_| \___/|_| \_\* -* * -****************************************************** -"#; - -#[tokio::main] -async fn main() -> Result { - eprintln!("{}", TITLE); - hyperactor::initialize_with_current_runtime(); - let args = Args::parse(); - - let system_addr = args.system_addr.clone(); - let proxy_addr = args.proxy_addr.clone(); - tracing::info!("starting Monarch simulation"); - - let operational_listener_handle = bootstrap(system_addr, proxy_addr, 1).await?; - - operational_listener_handle - .await - .expect("simulator exited unexpectedly"); - - Ok(ExitCode::SUCCESS) -} diff --git a/monarch_simulator/src/simulator.rs b/monarch_simulator/src/simulator.rs index 10fa9fd7..97c02b77 100644 --- a/monarch_simulator/src/simulator.rs +++ b/monarch_simulator/src/simulator.rs @@ -26,13 +26,13 @@ use crate::bootstrap::spawn_system; /// The simulator manages all of the meshes and the system handle. #[derive(Debug)] -pub struct Simulator { +pub struct TensorEngineSimulator { /// A map from world name to actor handles in that world. worlds: HashMap>>, system_handle: ServerHandle, } -impl Simulator { +impl TensorEngineSimulator { pub async fn new(system_addr: ChannelAddr) -> Result { Ok(Self { worlds: HashMap::new(), @@ -92,7 +92,7 @@ impl Simulator { /// IntoFuture allows users to await the handle. The future resolves when /// the simulator itself has all of the actor handles stopped. -impl IntoFuture for Simulator { +impl IntoFuture for TensorEngineSimulator { type Output = (); type IntoFuture = BoxFuture<'static, Self::Output>; @@ -145,7 +145,9 @@ mod tests { let system_addr = format!("sim!unix!@system,{}", &proxy) .parse::() .unwrap(); - let mut simulator = super::Simulator::new(system_addr.clone()).await.unwrap(); + let mut simulator = super::TensorEngineSimulator::new(system_addr.clone()) + .await + .unwrap(); let mut controller_actor_ids = vec![]; let mut worker_actor_ids = vec![]; let n_meshes = 2; diff --git a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi index 9cea9d93..6b06c274 100644 --- a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi @@ -13,10 +13,11 @@ class SimulatorClient: It is a client to communicate with the simulator service. Arguments: - - `proxy_addr`: Address of the simulator's proxy server. + - `system_addr`: Address of the system. + - `world_size`: Number of workers in a given mesh. """ - def __init__(self, proxy_addr: str) -> None: ... + def __init__(self, system_addr: str, world_size: int) -> None: ... def kill_world(self, world_name: str) -> None: """ Kill the world with the given name. @@ -51,11 +52,3 @@ class SimulatorClient: backend to resolve a future """ ... - -def bootstrap_simulator_backend( - system_addr: str, proxy_addr: str, world_size: int -) -> None: - """ - Bootstrap the simulator backend on the current process - """ - ... diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index b91361bf..8c304efd 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -31,7 +31,6 @@ ) from monarch._rust_bindings.monarch_extension.simulator_client import ( # @manual=//monarch/monarch_extension:monarch_extension - bootstrap_simulator_backend, SimulatorClient, ) @@ -76,7 +75,6 @@ def sim_mesh( bootstrap: Bootstrap = Bootstrap( n_meshes, mesh_world_state, - proxy_addr=proxy_addr, world_size=hosts * gpus_per_host, ) @@ -181,7 +179,6 @@ def __init__( self, num_meshes: int, mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]], - proxy_addr: Optional[str] = None, world_size: int = 1, ) -> None: """ @@ -199,17 +196,15 @@ def __init__( self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state - proxy_addr = proxy_addr or f"unix!@{_random_id()}-proxy" + proxy_addr = f"unix!@{_random_id()}-proxy" self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}" - client_proxy_addr = f"unix!@{_random_id()}-proxy" - self.client_listen_addr: str = f"sim!unix!@client,{client_proxy_addr}" - self.client_bootstrap_addr: str = ( + self.client_listen_addr = f"sim!unix!@client,{client_proxy_addr}" + self.client_bootstrap_addr = ( f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}" ) - bootstrap_simulator_backend(self.bootstrap_addr, proxy_addr, world_size) - self._simulator_client = SimulatorClient(proxy_addr) + self._simulator_client = SimulatorClient(self.bootstrap_addr, world_size) for i in range(num_meshes): mesh_name: str = f"mesh_{i}" controller_world: str = f"{mesh_name}_controller" @@ -235,7 +230,9 @@ def spawn_mesh(self, mesh_world: MeshWorld) -> None: worker_world, controller_id = mesh_world controller_world = controller_id.world_name self._simulator_client.spawn_mesh( - self.bootstrap_addr, f"{controller_world}[0].root", worker_world + self.bootstrap_addr, + f"{controller_world}[0].root", + worker_world, )