diff --git a/controller/src/lib.rs b/controller/src/lib.rs index a2d1f54d..36755094 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -1559,20 +1559,13 @@ mod tests { #[tokio::test] async fn test_sim_supervision_failure() { // Start system actor. - let system_addr = ChannelAddr::any(ChannelTransport::Unix); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + simnet::start(); simnet::simnet_handle() .unwrap() .set_training_script_state(simnet::TrainingScriptState::Waiting); let system_sim_addr = - ChannelAddr::Sim(SimAddr::new(system_addr, proxy_addr.clone()).unwrap()); + ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); // Set very long supervision_update_timeout let server_handle = System::serve( system_sim_addr.clone(), @@ -1588,9 +1581,8 @@ mod tests { // Bootstrap the controller let controller_id = id!(controller[0].root); let proc_id = id!(world[0]); - let controller_proc_listen_addr = ChannelAddr::Sim( - SimAddr::new(ChannelAddr::any(ChannelTransport::Unix), proxy_addr).unwrap(), - ); + let controller_proc_listen_addr = + ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); let (_, actor_ref) = ControllerActor::bootstrap( controller_id.clone(), diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index b98ec1f7..cdf9e662 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -236,7 +236,7 @@ pub enum ChannelTransport { Local, /// Sim is a simulated channel for testing. - Sim(/*proxy address:*/ ChannelAddr), + Sim(/*simulated transport:*/ Box), /// Transport over unix domain socket. Unix, @@ -368,7 +368,7 @@ impl ChannelAddr { Self::MetaTls(hostname, 0) } ChannelTransport::Local => Self::Local(0), - ChannelTransport::Sim(proxy) => sim::any(proxy), + ChannelTransport::Sim(transport) => sim::any(*transport), // This works because the file will be deleted but we know we have a unique file by this point. ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()), } @@ -380,7 +380,7 @@ impl ChannelAddr { Self::Tcp(_) => ChannelTransport::Tcp, Self::MetaTls(_, _) => ChannelTransport::MetaTls, Self::Local(_) => ChannelTransport::Local, - Self::Sim(addr) => ChannelTransport::Sim(addr.proxy().clone()), + Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())), Self::Unix(_) => ChannelTransport::Unix, } } @@ -637,10 +637,9 @@ mod tests { } for (raw, parsed) in cases_ok.iter().zip(src_ok.clone()).map(|(a, _)| { - let proxy_str = "unix!@proxy_a"; ( - format!("sim!{},{}", a.0, &proxy_str), - ChannelAddr::Sim(SimAddr::new(a.1.clone(), proxy_str.parse().unwrap()).unwrap()), + format!("sim!{}", a.0), + ChannelAddr::Sim(SimAddr::new(a.1.clone()).unwrap()), ) }) { assert_eq!(raw.parse::().unwrap(), parsed); diff --git a/hyperactor/src/channel/sim.rs b/hyperactor/src/channel/sim.rs index be97ca4b..73a663a5 100644 --- a/hyperactor/src/channel/sim.rs +++ b/hyperactor/src/channel/sim.rs @@ -39,25 +39,6 @@ lazy_static! { } static SIM_LINK_BUF_SIZE: usize = 256; -#[derive( - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - Ord, - PartialOrd, - Hash -)] -/// A channel address along with the address of the proxy for the process -pub struct AddressProxyPair { - /// The address. - pub address: ChannelAddr, - /// The address of the proxy for the process - pub proxy: ChannelAddr, -} - /// An address for a simulated channel. #[derive( Clone, @@ -71,36 +52,38 @@ pub struct AddressProxyPair { Hash )] pub struct SimAddr { - src: Option>, + src: Option>, /// The address. addr: Box, - /// The proxy address. - proxy: Box, + /// If source is the client + client: bool, } impl SimAddr { /// Creates a new SimAddr. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. /// Creates a new SimAddr without a source to be served - pub fn new(addr: ChannelAddr, proxy: ChannelAddr) -> Result { - Self::new_impl(None, addr, proxy) + pub fn new(addr: ChannelAddr) -> Result { + Self::new_impl(None, addr, false) } /// Creates a new directional SimAddr meant to convey a channel between two addresses. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. - pub fn new_with_src( - src: AddressProxyPair, - addr: ChannelAddr, - proxy: ChannelAddr, - ) -> Result { - Self::new_impl(Some(Box::new(src)), addr, proxy) + pub fn new_with_src(src: ChannelAddr, addr: ChannelAddr) -> Result { + Self::new_impl(Some(Box::new(src)), addr, false) + } + + /// Creates a new directional SimAddr meant to convey a channel between two addresses. + #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. + fn new_with_client_src(src: ChannelAddr, addr: ChannelAddr) -> Result { + Self::new_impl(Some(Box::new(src)), addr, true) } #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. fn new_impl( - src: Option>, + src: Option>, addr: ChannelAddr, - proxy: ChannelAddr, + client: bool, ) -> Result { if let ChannelAddr::Sim(_) = &addr { return Err(SimNetError::InvalidArg(format!( @@ -108,16 +91,10 @@ impl SimAddr { addr ))); } - if let ChannelAddr::Sim(_) = &proxy { - return Err(SimNetError::InvalidArg(format!( - "proxy cannot be a sim address, found {}", - proxy - ))); - } Ok(Self { src, addr: Box::new(addr), - proxy: Box::new(proxy), + client, }) } @@ -126,26 +103,22 @@ impl SimAddr { &self.addr } - /// Returns the proxy address. - pub fn proxy(&self) -> &ChannelAddr { - &self.proxy + /// Returns the source address + pub fn src(&self) -> &Option> { + &self.src } - /// Returns the source address and proxy. - pub fn src(&self) -> &Option> { - &self.src + /// The underlying transport we are simulating + pub fn transport(&self) -> ChannelTransport { + self.addr.transport() } } impl fmt::Display for SimAddr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.src { - None => write!(f, "{},{}", self.addr, self.proxy), - Some(src) => write!( - f, - "{},{},{},{}", - src.address, src.proxy, self.addr, self.proxy - ), + None => write!(f, "{}", self.addr), + Some(src) => write!(f, "{},{}", src, self.addr), } } } @@ -153,19 +126,15 @@ impl fmt::Display for SimAddr { /// Message Event that can be passed around in the simnet. #[derive(Debug)] pub(crate) struct MessageDeliveryEvent { - src_addr: Option, - dest_addr: AddressProxyPair, + src_addr: Option, + dest_addr: ChannelAddr, data: Serialized, duration_ms: u64, } impl MessageDeliveryEvent { /// Creates a new MessageDeliveryEvent. - pub fn new( - src_addr: Option, - dest_addr: AddressProxyPair, - data: Serialized, - ) -> Self { + pub fn new(src_addr: Option, dest_addr: ChannelAddr, data: Serialized) -> Self { Self { src_addr, dest_addr, @@ -198,16 +167,16 @@ impl Event for MessageDeliveryEvent { "Sending message from {} to {}", self.src_addr .as_ref() - .map_or("unknown".to_string(), |addr| addr.address.to_string()), - self.dest_addr.address.clone() + .map_or("unknown".to_string(), |addr| addr.to_string()), + self.dest_addr.clone() ) } async fn read_simnet_config(&mut self, topology: &Arc>) { if let Some(src_addr) = &self.src_addr { let edge = SimNetEdge { - src: src_addr.address.clone(), - dst: self.dest_addr.address.clone(), + src: src_addr.clone(), + dst: self.dest_addr.clone(), }; self.duration_ms = topology .lock() @@ -232,18 +201,18 @@ pub async fn update_config(config: simnet::NetworkConfig) -> anyhow::Result<(), } /// Returns a simulated channel address that is bound to "any" channel address. -pub(crate) fn any(proxy: ChannelAddr) -> ChannelAddr { +pub(crate) fn any(transport: ChannelTransport) -> ChannelAddr { ChannelAddr::Sim(SimAddr { src: None, - addr: Box::new(ChannelAddr::any(proxy.transport())), - proxy: Box::new(proxy), + addr: Box::new(ChannelAddr::any(transport)), + client: false, }) } /// Parse the sim channel address. It should have two non-sim channel addresses separated by a comma. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`. pub fn parse(addr_string: &str) -> Result { - let re = Regex::new(r"([^,]+),([^,]+)(,([^,]+),([^,]+))?$").map_err(|err| { + let re = Regex::new(r"([^,]+)(,([^,]+))?$").map_err(|err| { ChannelError::InvalidAddress(format!("invalid sim address regex: {}", err)) })?; @@ -261,26 +230,19 @@ pub fn parse(addr_string: &str) -> Result { } match parts.len() { - 2 => { + 1 => { let addr = parts[0].parse::()?; - let proxy = parts[1].parse::()?; - Ok(ChannelAddr::Sim(SimAddr::new(addr, proxy)?)) + Ok(ChannelAddr::Sim(SimAddr::new(addr)?)) } - 5 => { + 3 => { let src_addr = parts[0].parse::()?; - let src_proxy = parts[1].parse::()?; - let addr = parts[3].parse::()?; - let proxy = parts[4].parse::()?; - - Ok(ChannelAddr::Sim(SimAddr::new_with_src( - AddressProxyPair { - address: src_addr, - proxy: src_proxy, - }, - addr, - proxy, - )?)) + let addr = parts[2].parse::()?; + Ok(ChannelAddr::Sim(if parts[0] == "client" { + SimAddr::new_with_client_src(src_addr, addr) + } else { + SimAddr::new_with_src(src_addr, addr) + }?)) } _ => Err(ChannelError::InvalidAddress(addr_string.to_string())), } @@ -310,31 +272,22 @@ fn create_egress_sender( Ok(Arc::new(tx)) } -/// Check if the address is outside of the simulation. -#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. -fn is_external_addr(addr: &AddressProxyPair) -> anyhow::Result { - Ok(simnet_handle()?.proxy_addr() != &addr.proxy) -} - #[async_trait] -impl Dispatcher for SimDispatcher { +impl Dispatcher for SimDispatcher { async fn send( &self, - _src_addr: Option, - addr: AddressProxyPair, + _src_addr: Option, + addr: ChannelAddr, data: Serialized, ) -> Result<(), SimNetError> { self.dispatchers - .get(&addr.address) + .get(&addr) .ok_or_else(|| { - SimNetError::InvalidNode( - addr.address.to_string(), - anyhow::anyhow!("no dispatcher found"), - ) + SimNetError::InvalidNode(addr.to_string(), anyhow::anyhow!("no dispatcher found")) })? .send(data) .await - .map_err(|err| SimNetError::InvalidNode(addr.address.to_string(), err.into())) + .map_err(|err| SimNetError::InvalidNode(addr.to_string(), err.into())) } } @@ -349,9 +302,10 @@ impl Default for SimDispatcher { #[derive(Debug)] pub(crate) struct SimTx { - src_addr: Option, - dst_addr: AddressProxyPair, + src_addr: Option, + dst_addr: ChannelAddr, status: watch::Receiver, // Default impl. Always reports `Active`. + client: bool, _phantom: PhantomData, } @@ -372,15 +326,14 @@ impl Tx for SimTx { }; match simnet_handle() { Ok(handle) => match &self.src_addr { - Some(src_addr) if src_addr.proxy != *handle.proxy_addr() => handle - .send_scheduled_event(ScheduledEvent { - event: Box::new(MessageDeliveryEvent::new( - self.src_addr.clone(), - self.dst_addr.clone(), - data, - )), - time: SimClock.millis_since_start(RealClock.now()), - }), + Some(_) if self.client => handle.send_scheduled_event(ScheduledEvent { + event: Box::new(MessageDeliveryEvent::new( + self.src_addr.clone(), + self.dst_addr.clone(), + data, + )), + time: SimClock.millis_since_start(RealClock.now()), + }), _ => handle.send_event(Box::new(MessageDeliveryEvent::new( self.src_addr.clone(), self.dst_addr.clone(), @@ -393,7 +346,7 @@ impl Tx for SimTx { } fn addr(&self) -> ChannelAddr { - self.dst_addr.address.clone() + self.dst_addr.clone() } fn status(&self) -> &watch::Receiver { @@ -412,11 +365,9 @@ pub(crate) fn dial(addr: SimAddr) -> Result, ChannelE Ok(SimTx { src_addr: dialer, - dst_addr: AddressProxyPair { - address: *addr.addr, - proxy: *addr.proxy, - }, + dst_addr: addr.addr().clone(), status, + client: addr.client, _phantom: PhantomData, }) } @@ -425,13 +376,10 @@ pub(crate) fn dial(addr: SimAddr) -> Result, ChannelE /// The mpsc tx will be used to dispatch messages when it's time while /// the mpsc rx will be used by the above applications to handle received messages /// like any other channel. -/// A sim address has src and dst. Dispatchers are only indexed by dst address. +/// A sim address has a dst and optional src. Dispatchers are only indexed by dst address. pub(crate) fn serve( sim_addr: SimAddr, ) -> anyhow::Result<(ChannelAddr, SimRx)> { - // Serves sim address at sim_addr.src and set up local proxy at sim_addr.src_proxy. - // Reversing the src and dst since the first element in the output tuple is the - // dialing address of this sim channel. So the served address is the dst. let (tx, rx) = mpsc::channel::(SIM_LINK_BUF_SIZE); // Add tx to sender dispatch. SENDER.dispatchers.insert(*sim_addr.addr.clone(), tx); @@ -474,13 +422,7 @@ mod tests { let dst_ok = vec!["[::1]:1234", "tcp!127.0.0.1:8080", "local!123"]; let srcs_ok = vec!["[::2]:1234", "tcp!127.0.0.2:8080", "local!124"]; - let proxy = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); // TODO: New NodeAdd event should do this for you.. for addr in dst_ok.iter().chain(srcs_ok.iter()) { @@ -490,16 +432,10 @@ mod tests { .bind(addr.parse::().unwrap()) .unwrap(); } - // Messages are transferred internally if only there's a local proxy and the - // dst proxy is the same as local proxy. for (src_addr, dst_addr) in zip(srcs_ok, dst_ok) { let dst_addr = SimAddr::new_with_src( - AddressProxyPair { - address: src_addr.parse::().unwrap(), - proxy: proxy.clone(), - }, + src_addr.parse::().unwrap(), dst_addr.parse::().unwrap(), - proxy.clone(), ) .unwrap(); @@ -517,22 +453,14 @@ mod tests { async fn test_invalid_sim_addr() { let src = "sim!src"; let dst = "sim!dst"; - let src_proxy = "sim!src_proxy"; - let dst_proxy = "sim!dst_proxy"; - let sim_addr = format!("{},{},{},{}", src, src_proxy, dst, dst_proxy); + let sim_addr = format!("{},{}", src, dst); let result = parse(&sim_addr); assert!(matches!(result, Err(ChannelError::InvalidAddress(_)))); - - let dst = "unix!dst".parse::().unwrap(); - let dst_proxy = "sim!unix!a,unix!b".parse::().unwrap(); - let result = SimAddr::new(dst, dst_proxy); - // dst_proxy shouldn't be a sim address. - assert!(matches!(result, Err(SimNetError::InvalidArg(_)))); } #[tokio::test] async fn test_parse_sim_addr() { - let sim_addr = "sim!unix!@dst,unix!@proxy"; + let sim_addr = "sim!unix!@dst"; let result = sim_addr.parse(); assert!(result.is_ok()); let ChannelAddr::Sim(sim_addr) = result.unwrap() else { @@ -540,42 +468,26 @@ mod tests { }; assert!(sim_addr.src().is_none()); assert_eq!(sim_addr.addr().to_string(), "unix!@dst"); - assert_eq!(sim_addr.proxy().to_string(), "unix!@proxy"); - let sim_addr = "sim!unix!@src,unix!@proxy,unix!@dst,unix!@proxy"; + let sim_addr = "sim!unix!@src,unix!@dst"; let result = sim_addr.parse(); assert!(result.is_ok()); let ChannelAddr::Sim(sim_addr) = result.unwrap() else { panic!("Expected a sim address"); }; assert!(sim_addr.src().is_some()); - let src_pair = sim_addr.src().clone().unwrap(); - assert_eq!(src_pair.address.to_string(), "unix!@src"); - assert_eq!(src_pair.proxy.to_string(), "unix!@proxy"); assert_eq!(sim_addr.addr().to_string(), "unix!@dst"); - assert_eq!(sim_addr.proxy().to_string(), "unix!@proxy"); } #[tokio::test] async fn test_realtime_frontier() { - let proxy: ChannelAddr = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); tokio::time::pause(); - let sim_addr = - SimAddr::new("unix!@dst".parse::().unwrap(), proxy.clone()).unwrap(); + let sim_addr = SimAddr::new("unix!@dst".parse::().unwrap()).unwrap(); let sim_addr_with_src = SimAddr::new_with_src( - AddressProxyPair { - address: "unix!@src".parse::().unwrap(), - proxy: proxy.clone(), - }, + "unix!@src".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy.clone(), ) .unwrap(); let (_, mut rx) = sim::serve::<()>(sim_addr.clone()).unwrap(); @@ -612,31 +524,17 @@ mod tests { #[tokio::test] async fn test_client_message_scheduled_realtime() { tokio::time::pause(); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + start(); let controller_to_dst = SimAddr::new_with_src( - AddressProxyPair { - address: "unix!@controller".parse::().unwrap(), - proxy: proxy_addr.clone(), - }, + "unix!@controller".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy_addr.clone(), ) .unwrap(); let controller_tx = sim::dial::<()>(controller_to_dst.clone()).unwrap(); - let client_to_dst = SimAddr::new_with_src( - AddressProxyPair { - address: ChannelAddr::any(ChannelTransport::Unix), - proxy: ChannelAddr::any(ChannelTransport::Unix), - }, + let client_to_dst = SimAddr::new_with_client_src( + "unix!@client".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy_addr.clone(), ) .unwrap(); let client_tx = sim::dial::<()>(client_to_dst).unwrap(); diff --git a/hyperactor/src/clock.rs b/hyperactor/src/clock.rs index 9c8fcbc4..b5d45bdb 100644 --- a/hyperactor/src/clock.rs +++ b/hyperactor/src/clock.rs @@ -341,12 +341,7 @@ mod tests { #[tokio::test] async fn test_sim_timeout() { - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + simnet::start(); let res = SimClock .timeout(tokio::time::Duration::from_secs(10), async { SimClock.sleep(tokio::time::Duration::from_secs(5)).await; diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index 0f0c33a3..e2e7d5a2 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -2359,7 +2359,6 @@ mod tests { use crate::channel::ChannelTransport; use crate::channel::dial; use crate::channel::serve; - use crate::channel::sim::AddressProxyPair; use crate::channel::sim::SimAddr; use crate::clock::Clock; use crate::clock::RealClock; @@ -2560,23 +2559,12 @@ mod tests { #[tokio::test] async fn test_sim_client_server() { - let proxy = ChannelAddr::any(channel::ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); - let dst_addr = - SimAddr::new("local!1".parse::().unwrap(), proxy.clone()).unwrap(); + simnet::start(); + let dst_addr = SimAddr::new("local!1".parse::().unwrap()).unwrap(); let src_to_dst = ChannelAddr::Sim( SimAddr::new_with_src( - AddressProxyPair { - address: "local!0".parse::().unwrap(), - proxy: proxy.clone(), - }, + "local!0".parse::().unwrap(), dst_addr.addr().clone(), - dst_addr.proxy().clone(), ) .unwrap(), ); diff --git a/hyperactor/src/simnet.rs b/hyperactor/src/simnet.rs index 0a87a14f..0e0224fa 100644 --- a/hyperactor/src/simnet.rs +++ b/hyperactor/src/simnet.rs @@ -26,40 +26,28 @@ use async_trait::async_trait; use dashmap::DashMap; use dashmap::DashSet; use enum_as_inner::EnumAsInner; -use futures::executor::block_on; use serde::Deserialize; use serde::Deserializer; use serde::Serialize; use serde::Serializer; use serde_with::serde_as; use tokio::sync::Mutex; -use tokio::sync::SetError; use tokio::sync::mpsc; -use tokio::sync::mpsc::Sender; use tokio::sync::mpsc::UnboundedReceiver; use tokio::sync::mpsc::UnboundedSender; -use tokio::sync::mpsc::error::SendError; use tokio::task::JoinError; use tokio::task::JoinHandle; use tokio::time::interval; -use tokio::time::timeout; use crate as hyperactor; // for macros use crate::ActorId; use crate::Mailbox; -use crate::Named; use crate::OncePortRef; -use crate::WorldId; -use crate::channel; use crate::channel::ChannelAddr; -use crate::channel::Rx; -use crate::channel::sim::AddressProxyPair; -use crate::channel::sim::MessageDeliveryEvent; use crate::clock::Clock; use crate::clock::RealClock; use crate::clock::SimClock; use crate::data::Serialized; -use crate::mailbox::MessageEnvelope; static HANDLE: OnceLock = OnceLock::new(); @@ -309,22 +297,6 @@ pub enum SimNetError { #[error("timeout after {} ms: {}", .0.as_millis(), .1)] Timeout(Duration, String), - /// External node is trying to connect but proxy is not available. - #[error("proxy not available: {0}")] - ProxyNotAvailable(String), - - /// Unable to send message to the simulator. - #[error(transparent)] - OperationalMessageSendError(#[from] SendError), - - /// Setting the operational message sender which is already set. - #[error(transparent)] - OperationalMessageSenderSetError(#[from] SetError>), - - /// Missing OperationalMessageReceiver. - #[error("missing operational message receiver")] - MissingOperationalMessageReceiver, - /// Cannot deliver the message because destination address is missing. #[error("missing destination address")] MissingDestinationAddress, @@ -358,11 +330,6 @@ pub struct SimNetHandle { event_tx: UnboundedSender<(Box, bool, Option)>, config: Arc>, pending_event_count: Arc, - /// Handle to a running proxy server that forwards external messages - /// into the simnet. - proxy_handle: ProxyHandle, - /// A sender to forward simulator operational messages. - operational_message_tx: UnboundedSender, /// A receiver to receive simulator operational messages. /// The receiver can be moved out of the simnet handle. training_script_state_tx: tokio::sync::watch::Sender, @@ -434,8 +401,6 @@ impl SimNetHandle { /// Close the simulator, processing pending messages before /// completing the returned future. pub async fn close(&self) -> Result, JoinError> { - // Stop the proxy if there is one. - self.proxy_handle.stop().await?; // Signal the simnet loop to stop self.stop_signal.store(true, Ordering::SeqCst); @@ -480,118 +445,10 @@ impl SimNetHandle { "timeout waiting for received events to be scheduled".to_string(), )) } - - /// Returns the external address of the simnet. - pub fn proxy_addr(&self) -> &ChannelAddr { - &self.proxy_handle.addr - } } pub(crate) type Topology = DashMap; -/// The message to spawn a simulated mesh. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)] -pub struct SpawnMesh { - /// The system address. - pub system_addr: ChannelAddr, - /// The controller actor ID. - pub controller_actor_id: ActorId, - /// The worker world. - pub worker_world: WorldId, -} - -impl SpawnMesh { - /// Creates a new SpawnMesh. - pub fn new( - system_addr: ChannelAddr, - controller_actor_id: ActorId, - worker_world: WorldId, - ) -> Self { - Self { - system_addr, - controller_actor_id, - worker_world, - } - } -} - -/// An OperationalMessage is a message to control the simulator to do tasks such as -/// spawning or killing actors. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)] -pub enum OperationalMessage { - /// Kill the world with given world_id. - KillWorld(String), - /// Spawn actors in a mesh. - SpawnMesh(SpawnMesh), - /// Update training script state. - SetTrainingScriptState(TrainingScriptState), -} - -/// Message Event that can be sent to the simulator. -#[derive(Debug)] -pub struct SimOperation { - /// Sender to send OperationalMessage to the simulator. - operational_message_tx: UnboundedSender, - operational_message: OperationalMessage, -} - -impl SimOperation { - /// Creates a new SimOperation. - pub fn new( - operational_message_tx: UnboundedSender, - operational_message: OperationalMessage, - ) -> Self { - Self { - operational_message_tx, - operational_message, - } - } -} - -#[async_trait] -impl Event for SimOperation { - async fn handle(&self) -> Result<(), SimNetError> { - self.operational_message_tx - .send(self.operational_message.clone())?; - Ok(()) - } - - fn duration_ms(&self) -> u64 { - 0 - } - - fn summary(&self) -> String { - format!("SimOperation: {:?}", self.operational_message) - } -} - -/// A ProxyMessage is a message that SimNet proxy receives. -/// The message may requests the SimNet to send the payload in the message field from -/// src to dst if addr field exists. -/// Or handle the payload in the message field if addr field is None, indicating that -/// this is a self-handlable message. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)] -pub struct ProxyMessage { - sender_addr: Option, - dest_addr: Option, - data: Serialized, -} - -impl ProxyMessage { - /// Creates a new ForwardMessage. - pub fn new( - sender_addr: Option, - dest_addr: Option, - data: Serialized, - ) -> Self { - Self { - sender_addr, - dest_addr, - data, - } - } -} - /// Configure network topology for the simnet pub struct SimNetConfig { // For now, we assume the network is fully connected @@ -604,31 +461,6 @@ pub struct SimNetConfig { /// The network is represented as a graph of nodes. /// The graph is represented as a map of edges. /// The network also has a cloud of inflight messages -/// SimNet also serves a proxy address to receive external traffic. This proxy address can handle -/// [`ProxyMessage`]s and forward the payload from src to dst. -/// -/// Example: -/// In this example, we send a ForwardMessage to the proxy_addr. SimNet will handle the message and -/// forward the payload from src to dst. -/// ```ignore -/// let nw_handle = start("local!0".parse().unwrap(), 1000, true, Some(gen_event_fcn)) -/// .await -/// .unwrap(); -/// let proxy_addr = nw_handle.proxy_addr().clone(); -/// let tx = crate::channel::dial(proxy_addr).unwrap(); -/// let src_to_dst_msg = MessageEnvelope::new_unknown( -/// port_id.clone(), -/// Serialized::serialize(&"hola".to_string()).unwrap(), -/// ); -/// let forward_message = ForwardMessage::new( -/// "unix!@src".parse::().unwrap(), -/// "unix!@dst".parse::().unwrap(), -/// src_to_dst_msg -/// ); -/// let external_message = -/// MessageEnvelope::new_unknown(port_id, Serialized::serialize(&forward_message).unwrap()); -/// tx.send(external_message).await.unwrap(); -/// ``` pub struct SimNet { config: Arc>, address_book: DashSet, @@ -639,112 +471,15 @@ pub struct SimNet { pending_event_count: Arc, } -/// A proxy to bridge external nodes and the SimNet. -struct ProxyHandle { - join_handle: Mutex>>, - stop_signal: Arc, - addr: ChannelAddr, -} - -impl ProxyHandle { - /// Starts an proxy server to handle external [`ForwardMessage`]s. It will forward the payload inside - /// the [`ForwardMessage`] from src to dst in the SimNet. - /// Args: - /// proxy_addr: address to listen - /// event_tx: a channel to send events to the SimNet - /// pending_event_count: a counter to keep track of the number of pending events - /// to_event: a function that specifies how to generate an Event from a forward message - async fn start( - proxy_addr: ChannelAddr, - event_tx: UnboundedSender<(Box, bool, Option)>, - pending_event_count: Arc, - operational_message_tx: UnboundedSender, - ) -> anyhow::Result { - let (addr, mut rx) = channel::serve::(proxy_addr).await?; - tracing::info!("SimNet serving external traffic on {}", &addr); - let stop_signal = Arc::new(AtomicBool::new(false)); - - let join_handle = { - let stop_signal = stop_signal.clone(); - tokio::spawn(async move { - 'outer: loop { - // timeout the wait to enable stop signal checking at least every 100ms. - #[allow(clippy::disallowed_methods)] - if let Ok(Ok(msg)) = timeout(Duration::from_millis(100), rx.recv()).await { - let proxy_message: ProxyMessage = msg.deserialized().unwrap(); - let event: Box = match proxy_message.dest_addr { - Some(dest_addr) => Box::new(MessageDeliveryEvent::new( - proxy_message.sender_addr, - dest_addr, - proxy_message.data, - )), - None => { - let operational_message: OperationalMessage = - proxy_message.data.deserialized().unwrap(); - Box::new(SimOperation::new( - operational_message_tx.clone(), - operational_message, - )) - } - }; - - if let Err(e) = event_tx.send((event, true, None)) { - tracing::error!("error sending message to simnet: {:?}", e); - } else { - pending_event_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); - } - } - if stop_signal.load(Ordering::SeqCst) { - eprintln!("stopping external traffic handler"); - break 'outer; - } - } - }) - }; - Ok(Self { - join_handle: Mutex::new(Some(join_handle)), - stop_signal, - addr, - }) - } - - /// Stop the proxy. - async fn stop(&self) -> Result<(), JoinError> { - self.stop_signal.store(true, Ordering::SeqCst); - let mut guard = self.join_handle.lock().await; - if let Some(handle) = guard.take() { - handle.await - } else { - Ok(()) - } - } -} - /// Starts a sim net. /// Args: -/// private_addr: an internal address to receive operational messages such as NodeJoinEvent /// max_duration_ms: an optional config to override default settings of the network latency -/// enable_record: a flag to enable recording of message delivery records -pub fn start( - private_addr: ChannelAddr, - proxy_addr: ChannelAddr, - max_duration_ms: u64, -) -> anyhow::Result> { +pub fn start() { + let max_duration_ms = 1000 * 10; // Construct a topology with one node: the default A. let address_book: DashSet = DashSet::new(); - address_book.insert(private_addr.clone()); let topology = DashMap::new(); - topology.insert( - SimNetEdge { - src: private_addr.clone(), - dst: private_addr, - }, - SimNetEdgeInfo { - latency: Duration::from_millis(1), - }, - ); - let config = Arc::new(Mutex::new(SimNetConfig { topology })); let (training_script_state_tx, training_script_state_rx) = @@ -760,7 +495,7 @@ pub fn start( let stop_signal = stop_signal.clone(); tokio::spawn(async move { - let mut net = SimNet { + SimNet { config, address_book, state: State { @@ -770,34 +505,20 @@ pub fn start( max_latency: Duration::from_millis(max_duration_ms), records: Vec::new(), pending_event_count, - }; - net.run(event_rx, training_script_state_rx, stop_signal) - .await + } + .run(event_rx, training_script_state_rx, stop_signal) + .await }) })); - let (operational_message_tx, operational_message_rx) = - mpsc::unbounded_channel::(); - - let proxy_handle = block_on(ProxyHandle::start( - proxy_addr, - event_tx.clone(), - pending_event_count.clone(), - operational_message_tx.clone(), - )) - .map_err(|err| SimNetError::ProxyNotAvailable(err.to_string()))?; HANDLE.get_or_init(|| SimNetHandle { join_handle, event_tx, config, pending_event_count, - proxy_handle, - operational_message_tx, training_script_state_tx, stop_signal, }); - - Ok(operational_message_rx) } impl SimNet { @@ -1188,17 +909,7 @@ mod tests { #[tokio::test] async fn test_handle_instantiation() { - let default_addr = format!("local!{}", 0) - .parse::() - .unwrap(); - assert!( - start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .is_ok() - ); + start(); simnet_handle().unwrap().close().await.unwrap(); } @@ -1206,12 +917,7 @@ mod tests { async fn test_simnet_config() { // Tests that we can create a simnet, config latency between two node and deliver // the message with configured latency. - start( - "local!0".parse::().unwrap(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let alice = "local!1".parse::().unwrap(); let bob = "local!2".parse::().unwrap(); let latency = Duration::from_millis(1000); @@ -1228,9 +934,8 @@ mod tests { .await .unwrap(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let alice = SimAddr::new(alice, proxy_addr.clone()).unwrap(); - let bob = SimAddr::new(bob, proxy_addr.clone()).unwrap(); + let alice = SimAddr::new(alice).unwrap(); + let bob = SimAddr::new(bob).unwrap(); let msg = Box::new(MessageDeliveryEvent::new( alice, bob, @@ -1256,12 +961,7 @@ mod tests { #[tokio::test] async fn test_simnet_debounce() { let default_addr = "local!0".parse::().unwrap(); - start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let alice = "local!1".parse::().unwrap(); let bob = "local!2".parse::().unwrap(); @@ -1278,10 +978,8 @@ mod tests { .await .unwrap(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - - let alice = SimAddr::new(alice, proxy_addr.clone()).unwrap(); - let bob = SimAddr::new(bob, proxy_addr).unwrap(); + let alice = SimAddr::new(alice).unwrap(); + let bob = SimAddr::new(bob).unwrap(); // Rapidly send 10 messages expecting that each one debounces the processing for _ in 0..10 { @@ -1318,13 +1016,7 @@ mod tests { #[tokio::test] async fn test_sim_dispatch() { - let proxy = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); let sender = Some(TestDispatcher::default()); let mut addresses: Vec = Vec::new(); // // Create a simple network of 4 nodes. @@ -1341,11 +1033,10 @@ mod tests { .map(|s| Serialized::serialize(&s.to_string()).unwrap()) .collect(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let addr_0 = SimAddr::new(addresses[0].clone(), proxy_addr.clone()).unwrap(); - let addr_1 = SimAddr::new(addresses[1].clone(), proxy_addr.clone()).unwrap(); - let addr_2 = SimAddr::new(addresses[2].clone(), proxy_addr.clone()).unwrap(); - let addr_3 = SimAddr::new(addresses[3].clone(), proxy_addr.clone()).unwrap(); + let addr_0 = SimAddr::new(addresses[0].clone()).unwrap(); + let addr_1 = SimAddr::new(addresses[1].clone()).unwrap(); + let addr_2 = SimAddr::new(addresses[2].clone()).unwrap(); + let addr_3 = SimAddr::new(addresses[3].clone()).unwrap(); let one = Box::new(MessageDeliveryEvent::new( addr_0.clone(), addr_1.clone(), @@ -1393,20 +1084,20 @@ mod tests { #[tokio::test] async fn test_read_config_from_yaml() { let yaml = r#" -edges: - - src: local!0 - dst: local!1 - metadata: - latency: 1 - - src: local!0 - dst: local!2 - metadata: - latency: 2 - - src: local!1 - dst: local!2 - metadata: - latency: 3 -"#; + edges: + - src: local!0 + dst: local!1 + metadata: + latency: 1 + - src: local!0 + dst: local!2 + metadata: + latency: 2 + - src: local!1 + dst: local!2 + metadata: + latency: 3 + "#; let config = NetworkConfig::from_yaml(yaml).unwrap(); assert_eq!(config.edges.len(), 3); assert_eq!( @@ -1438,113 +1129,9 @@ edges: assert_eq!(config.edges[2].metadata.latency, Duration::from_secs(3)); } - #[cfg(target_os = "linux")] - #[tokio::test] - async fn test_simnet_receive_external_message() { - use tokio::sync::oneshot; - - use crate::PortId; - use crate::channel::Tx; - - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); - let tx = crate::channel::dial(proxy_addr.clone()).unwrap(); - let port_id = PortId(id!(test[0].actor0), 0); - let src_to_dst_msg = Serialized::serialize(&"hola".to_string()).unwrap(); - let src = random_abstract_addr(); - let dst = random_abstract_addr(); - let src_and_proxy = Some(AddressProxyPair { - address: src.clone(), - proxy: proxy_addr.clone(), - }); - let dst_and_proxy = AddressProxyPair { - address: dst.clone(), - proxy: proxy_addr.clone(), - }; - let forward_message = ProxyMessage::new(src_and_proxy, Some(dst_and_proxy), src_to_dst_msg); - let external_message = - MessageEnvelope::new_unknown(port_id, Serialized::serialize(&forward_message).unwrap()); - tx.try_post(external_message, oneshot::channel().0).unwrap(); - // flush doesn't work here because tx.send() delivers the message through real network. - // We have to wait for the message to enter simnet. - RealClock.sleep(Duration::from_millis(1000)).await; - simnet_handle() - .unwrap() - .flush(Duration::from_millis(1000)) - .await - .unwrap(); - let records = simnet_handle().unwrap().close().await; - assert!(records.as_ref().unwrap().len() == 1); - let expected_record = SimulatorEventRecord { - summary: format!("Sending message from {} to {}", src, dst), - start_at: 0, - end_at: 1, - }; - assert_eq!(records.unwrap().first().unwrap(), &expected_record); - } - - #[cfg(target_os = "linux")] - #[tokio::test] - async fn test_simnet_receive_operational_message() { - use tokio::sync::oneshot; - - use crate::PortId; - use crate::channel::Tx; - - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let mut operational_message_rx = start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); - let tx = crate::channel::dial(proxy_addr.clone()).unwrap(); - let port_id = PortId(id!(test[0].actor0), 0); - let spawn_mesh = SpawnMesh { - system_addr: "unix!@system".parse().unwrap(), - controller_actor_id: id!(controller_world[0].actor), - worker_world: id!(worker_world), - }; - let operational_message = OperationalMessage::SpawnMesh(spawn_mesh.clone()); - let serialized_operational_message = Serialized::serialize(&operational_message).unwrap(); - let proxy_message = ProxyMessage::new(None, None, serialized_operational_message); - let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap(); - let external_message = MessageEnvelope::new_unknown(port_id, serialized_proxy_message); - - // Send the message to the simnet. - tx.try_post(external_message, oneshot::channel().0).unwrap(); - // flush doesn't work here because tx.send() delivers the message through real network. - // We have to wait for the message to enter simnet. - RealClock.sleep(Duration::from_millis(1000)).await; - let received_operational_message = operational_message_rx.recv().await.unwrap(); - - // Check the received message. - assert_eq!(received_operational_message, operational_message); - } - #[tokio::test] async fn test_sim_sleep() { - start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); - - let default_addr = format!("local!{}", 0) - .parse::() - .unwrap(); - let _ = start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let start = SimClock.now(); assert_eq!(SimClock.millis_since_start(start), 0); @@ -1557,12 +1144,7 @@ edges: #[tokio::test] async fn test_torch_op() { - start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let args_string = "1, 2".to_string(); let kwargs_string = "a=2".to_string(); diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index 69bb7ca1..464f8e82 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -987,4 +987,10 @@ mod tests { buck_resources::get("monarch/hyperactor_mesh/bootstrap").unwrap() ))); } + + mod sim { + use crate::alloc::sim::SimAllocator; + + actor_mesh_test_suite!(SimAllocator::new_and_start_simnet()); + } } diff --git a/hyperactor_mesh/src/alloc.rs b/hyperactor_mesh/src/alloc.rs index 9c36f525..cfd758b4 100644 --- a/hyperactor_mesh/src/alloc.rs +++ b/hyperactor_mesh/src/alloc.rs @@ -13,6 +13,7 @@ pub mod local; pub(crate) mod logtailer; pub mod process; pub mod remoteprocess; +pub mod sim; use std::collections::HashMap; use std::fmt; diff --git a/hyperactor_mesh/src/alloc/local.rs b/hyperactor_mesh/src/alloc/local.rs index 5950eaa3..ceb4915f 100644 --- a/hyperactor_mesh/src/alloc/local.rs +++ b/hyperactor_mesh/src/alloc/local.rs @@ -75,10 +75,15 @@ pub struct LocalAlloc { todo_rx: mpsc::UnboundedReceiver, stopped: bool, failed: bool, + transport: ChannelTransport, } impl LocalAlloc { fn new(spec: AllocSpec) -> Self { + Self::new_with_transport(spec, ChannelTransport::Local) + } + + pub(crate) fn new_with_transport(spec: AllocSpec, transport: ChannelTransport) -> Self { let name = ShortUuid::generate(); let (todo_tx, todo_rx) = mpsc::unbounded_channel(); for rank in 0..spec.shape.slice().len() { @@ -94,6 +99,7 @@ impl LocalAlloc { todo_rx, stopped: false, failed: false, + transport, } } @@ -123,7 +129,7 @@ impl LocalAlloc { &self.name } - fn size(&self) -> usize { + pub(crate) fn size(&self) -> usize { self.spec.shape.slice().len() } } @@ -249,7 +255,7 @@ impl Alloc for LocalAlloc { } fn transport(&self) -> ChannelTransport { - ChannelTransport::Local + self.transport.clone() } async fn stop(&mut self) -> Result<(), AllocatorError> { diff --git a/hyperactor_mesh/src/alloc/sim.rs b/hyperactor_mesh/src/alloc/sim.rs new file mode 100644 index 00000000..f680d5a2 --- /dev/null +++ b/hyperactor_mesh/src/alloc/sim.rs @@ -0,0 +1,122 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +//! Support for allocating procs in the local process with simulated channels. + +#![allow(dead_code)] // until it is used outside of testing + +use async_trait::async_trait; +use hyperactor::WorldId; +use hyperactor::channel::ChannelAddr; +use hyperactor::channel::ChannelTransport; +use hyperactor::mailbox::MailboxServerHandle; +use hyperactor::proc::Proc; +use ndslice::Shape; + +use super::ProcStopReason; +use crate::alloc::Alloc; +use crate::alloc::AllocSpec; +use crate::alloc::Allocator; +use crate::alloc::AllocatorError; +use crate::alloc::LocalAlloc; +use crate::alloc::ProcState; +use crate::shortuuid::ShortUuid; + +/// An allocator that runs procs in the local process with network traffic going through simulated channels. +/// Other than transport, the underlying implementation is an inner LocalAlloc. +pub struct SimAllocator; + +#[async_trait] +impl Allocator for SimAllocator { + type Alloc = SimAlloc; + + async fn allocate(&mut self, spec: AllocSpec) -> Result { + Ok(SimAlloc::new(spec)) + } +} + +impl SimAllocator { + #[cfg(test)] + pub(crate) fn new_and_start_simnet() -> Self { + hyperactor::simnet::start(); + Self + } +} + +struct SimProc { + proc: Proc, + addr: ChannelAddr, + handle: MailboxServerHandle, +} + +/// A simulated allocation. It is a collection of procs that are running in the local process. +pub struct SimAlloc { + inner: LocalAlloc, +} + +impl SimAlloc { + fn new(spec: AllocSpec) -> Self { + Self { + inner: LocalAlloc::new_with_transport( + spec, + ChannelTransport::Sim(Box::new(ChannelTransport::Unix)), + ), + } + } + /// A chaos monkey that can be used to stop procs at random. + pub(crate) fn chaos_monkey(&self) -> impl Fn(usize, ProcStopReason) + 'static { + self.inner.chaos_monkey() + } + + /// A function to shut down the alloc for testing purposes. + pub(crate) fn stopper(&self) -> impl Fn() + 'static { + self.inner.stopper() + } + + pub(crate) fn name(&self) -> &ShortUuid { + self.inner.name() + } + + fn size(&self) -> usize { + self.inner.size() + } +} + +#[async_trait] +impl Alloc for SimAlloc { + async fn next(&mut self) -> Option { + self.inner.next().await + } + + fn shape(&self) -> &Shape { + self.inner.shape() + } + + fn world_id(&self) -> &WorldId { + self.inner.world_id() + } + + fn transport(&self) -> ChannelTransport { + self.inner.transport() + } + + async fn stop(&mut self) -> Result<(), AllocatorError> { + self.inner.stop().await + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn test_allocator_basic() { + hyperactor::simnet::start(); + crate::alloc::testing::test_allocator_basic(SimAllocator).await; + } +} diff --git a/hyperactor_multiprocess/src/ping_pong.rs b/hyperactor_multiprocess/src/ping_pong.rs index 676259cc..4c659153 100644 --- a/hyperactor_multiprocess/src/ping_pong.rs +++ b/hyperactor_multiprocess/src/ping_pong.rs @@ -14,9 +14,7 @@ mod tests { use hyperactor::ActorRef; use hyperactor::Mailbox; use hyperactor::channel::ChannelAddr; - use hyperactor::channel::ChannelTransport; use hyperactor::channel::sim; - use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; use hyperactor::id; use hyperactor::reference::Index; @@ -36,16 +34,10 @@ mod tests { #[tokio::test] async fn test_sim_ping_pong() { let system_addr = "local!1".parse::().unwrap(); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + simnet::start(); - let system_sim_addr = SimAddr::new(system_addr.clone(), proxy_addr.clone()).unwrap(); + let system_sim_addr = SimAddr::new(system_addr.clone()).unwrap(); let server_handle = System::serve( ChannelAddr::Sim(system_sim_addr.clone()), Duration::from_secs(10), @@ -64,21 +56,14 @@ mod tests { let ping_actor_ref = spawn_proc_actor( 2, - proxy_addr.clone(), system_sim_addr.clone(), sys_mailbox.clone(), world_id.clone(), ) .await; - let pong_actor_ref = spawn_proc_actor( - 3, - proxy_addr.clone(), - system_sim_addr, - sys_mailbox.clone(), - world_id.clone(), - ) - .await; + let pong_actor_ref = + spawn_proc_actor(3, system_sim_addr, sys_mailbox.clone(), world_id.clone()).await; // Configure the simulation network. let simnet_config_yaml = r#" @@ -124,7 +109,6 @@ edges: async fn spawn_proc_actor( actor_index: Index, - proxy_addr: ChannelAddr, system_addr: SimAddr, sys_mailbox: Mailbox, world_id: WorldId, @@ -133,19 +117,11 @@ edges: .parse::() .unwrap(); - let proc_sim_addr = SimAddr::new(proc_addr.clone(), proxy_addr.clone()).unwrap(); + let proc_sim_addr = SimAddr::new(proc_addr.clone()).unwrap(); let proc_listen_addr = ChannelAddr::Sim(proc_sim_addr); let proc_id = world_id.proc_id(actor_index); let proc_to_system = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: proc_addr.clone(), - proxy: proxy_addr.clone(), - }, - system_addr.addr().clone(), - system_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(proc_addr.clone(), system_addr.addr().clone()).unwrap(), ); let bootstrap = ProcActor::bootstrap( proc_id, diff --git a/hyperactor_multiprocess/src/system_actor.rs b/hyperactor_multiprocess/src/system_actor.rs index fd477b94..c80160b7 100644 --- a/hyperactor_multiprocess/src/system_actor.rs +++ b/hyperactor_multiprocess/src/system_actor.rs @@ -38,7 +38,6 @@ use hyperactor::RefClient; use hyperactor::WorldId; use hyperactor::actor::Handler; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; use hyperactor::clock::Clock; use hyperactor::clock::ClockKind; @@ -762,13 +761,9 @@ impl ReportingRouter { ChannelAddr::Sim( SimAddr::new_with_src( // source is the sender - AddressProxyPair { - address: sender_sim_addr.addr().clone(), - proxy: sender_sim_addr.proxy().clone(), - }, + sender_sim_addr.addr().clone(), // dest remains unchanged dest_sim_addr.addr().clone(), - dest_sim_addr.proxy().clone(), ) .unwrap(), ) @@ -2800,19 +2795,11 @@ mod tests { #[tokio::test] async fn test_update_sim_address() { - let proxy = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + simnet::start(); let src_id = id!(proc[0].actor); - let src_addr = - ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap(), proxy.clone()).unwrap()); - let dst_addr = - ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap(), proxy.clone()).unwrap()); + let src_addr = ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap()).unwrap()); + let dst_addr = ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap()).unwrap()); let (_, mut rx) = channel::serve::(src_addr.clone()) .await .unwrap(); @@ -2844,7 +2831,7 @@ mod tests { panic!("Expected sim address"); }; - assert_eq!(addr.src().clone().unwrap().address.to_string(), "unix!@src"); + assert_eq!(addr.src().clone().unwrap().to_string(), "unix!@src"); assert_eq!(addr.addr().to_string(), "unix!@dst"); } } diff --git a/monarch_extension/src/simulation_tools.rs b/monarch_extension/src/simulation_tools.rs index 6913967b..a73855e0 100644 --- a/monarch_extension/src/simulation_tools.rs +++ b/monarch_extension/src/simulation_tools.rs @@ -6,24 +6,16 @@ * LICENSE file in the root directory of this source tree. */ -use hyperactor::channel::ChannelAddr; -use hyperactor::channel::ChannelTransport; use hyperactor::clock::Clock; use hyperactor::clock::SimClock; use hyperactor::simnet; -use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; #[pyfunction] #[pyo3(name = "start_event_loop")] pub fn start_simnet_event_loop(py: Python) -> PyResult> { pyo3_async_runtimes::tokio::future_into_py(py, async move { - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + simnet::start(); Ok(()) }) } diff --git a/monarch_extension/src/simulator_client.rs b/monarch_extension/src/simulator_client.rs index 85657d40..46d23e29 100644 --- a/monarch_extension/src/simulator_client.rs +++ b/monarch_extension/src/simulator_client.rs @@ -8,24 +8,21 @@ #![cfg(feature = "tensor_engine")] +use std::sync::Arc; + use anyhow::anyhow; -use hyperactor::PortId; -use hyperactor::attrs::Attrs; +use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::Tx; -use hyperactor::channel::dial; -use hyperactor::data::Serialized; -use hyperactor::id; -use hyperactor::mailbox::MessageEnvelope; -use hyperactor::simnet::OperationalMessage; -use hyperactor::simnet::ProxyMessage; -use hyperactor::simnet::SpawnMesh; +use hyperactor::channel::ChannelTransport; +use hyperactor::simnet; use hyperactor::simnet::TrainingScriptState; +use hyperactor::simnet::simnet_handle; use monarch_hyperactor::runtime::signal_safe_block_on; -use monarch_simulator_lib::bootstrap::bootstrap; +use monarch_simulator_lib::simulator::TensorEngineSimulator; use pyo3::exceptions::PyRuntimeError; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use tokio::sync::Mutex; /// A wrapper around [ndslice::Slice] to expose it to python. /// It is a compact representation of indices into the flat @@ -39,102 +36,87 @@ use pyo3::prelude::*; )] #[derive(Clone)] pub(crate) struct SimulatorClient { - proxy_addr: ChannelAddr, -} - -fn wrap_operational_message(operational_message: OperationalMessage) -> MessageEnvelope { - let serialized_operational_message = Serialized::serialize(&operational_message).unwrap(); - let proxy_message = ProxyMessage::new(None, None, serialized_operational_message); - let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap(); - let sender_id = id!(simulator_client[0].sender_actor); - // a dummy port ID. We are delivering message with low level mailbox. - // The port ID is not used. - let port_id = PortId(id!(simulator[0].actor), 0); - MessageEnvelope::new(sender_id, port_id, serialized_proxy_message, Attrs::new()) -} - -#[pyfunction] -fn bootstrap_simulator_backend( - py: Python, - system_addr: String, - proxy_addr: String, - world_size: i32, -) -> PyResult<()> { - signal_safe_block_on(py, async move { - match bootstrap( - system_addr.parse().unwrap(), - proxy_addr.parse().unwrap(), - world_size as usize, - ) - .await - { - Ok(_) => Ok(()), - Err(err) => Err(PyRuntimeError::new_err(err.to_string())), - } - })? + inner: Arc>, + world_size: usize, } -fn set_training_script_state(state: TrainingScriptState, proxy_addr: ChannelAddr) -> PyResult<()> { - let operational_message = OperationalMessage::SetTrainingScriptState(state); - let external_message = wrap_operational_message(operational_message); - let tx = dial(proxy_addr).map_err(|err| anyhow!(err))?; - tx.post(external_message); +fn set_training_script_state(state: TrainingScriptState) -> PyResult<()> { + simnet_handle() + .map_err(|e| anyhow!(e))? + .set_training_script_state(state); Ok(()) } #[pymethods] impl SimulatorClient { #[new] - fn new(proxy_addr: &str) -> PyResult { - Ok(Self { - proxy_addr: proxy_addr - .parse::() - .map_err(|err| PyValueError::new_err(err.to_string()))?, - }) + fn new(py: Python, system_addr: String, world_size: i32) -> PyResult { + signal_safe_block_on(py, async move { + simnet::start(); + + Ok(Self { + inner: Arc::new(Mutex::new( + TensorEngineSimulator::new( + system_addr + .parse::() + .map_err(|err| PyValueError::new_err(err.to_string()))?, + ) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, + )), + world_size: world_size as usize, + }) + })? } - fn kill_world(&self, world_name: &str) -> PyResult<()> { - let operational_message = OperationalMessage::KillWorld(world_name.to_string()); - let external_message = wrap_operational_message(operational_message); - let tx = dial(self.proxy_addr.clone()).map_err(|err| anyhow!(err))?; - tx.post(external_message); - Ok(()) + fn kill_world(&self, py: Python, world_name: &str) -> PyResult<()> { + let simulator = self.inner.clone(); + let world_name = world_name.to_string(); + + signal_safe_block_on(py, async move { + simulator + .lock() + .await + .kill_world(&world_name) + .map_err(|err| anyhow!(err))?; + Ok(()) + })? } fn spawn_mesh( &self, + py: Python, system_addr: &str, controller_actor_id: &str, worker_world: &str, ) -> PyResult<()> { - let spawn_mesh = SpawnMesh::new( - system_addr.parse().unwrap(), - controller_actor_id.parse().unwrap(), - worker_world.parse().unwrap(), - ); - let operational_message = OperationalMessage::SpawnMesh(spawn_mesh); - let external_message = wrap_operational_message(operational_message); - let tx = dial(self.proxy_addr.clone()).map_err(|err| anyhow!(err))?; - tx.post(external_message); - Ok(()) + let simulator = self.inner.clone(); + let world_size = self.world_size; + let system_addr = system_addr.parse::().unwrap(); + let worker_world = worker_world.parse::().unwrap(); + let controller_actor_id = controller_actor_id.parse().unwrap(); + + signal_safe_block_on(py, async move { + simulator + .lock() + .await + .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) + .await + .map_err(|err| anyhow!(err))?; + Ok(()) + })? } fn set_training_script_state_running(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Running, self.proxy_addr.clone()) + set_training_script_state(TrainingScriptState::Running) } fn set_training_script_state_waiting(&self) -> PyResult<()> { - set_training_script_state(TrainingScriptState::Waiting, self.proxy_addr.clone()) + set_training_script_state(TrainingScriptState::Waiting) } } pub(crate) fn register_python_bindings(simulator_client_mod: &Bound<'_, PyModule>) -> PyResult<()> { simulator_client_mod.add_class::()?; - let f = wrap_pyfunction!(bootstrap_simulator_backend, simulator_client_mod)?; - f.setattr( - "__module__", - "monarch._rust_bindings.monarch_extension.simulator_client", - )?; - simulator_client_mod.add_function(f)?; Ok(()) } diff --git a/monarch_hyperactor/src/alloc.rs b/monarch_hyperactor/src/alloc.rs index b95b229c..558e20bf 100644 --- a/monarch_hyperactor/src/alloc.rs +++ b/monarch_hyperactor/src/alloc.rs @@ -28,6 +28,7 @@ use hyperactor_mesh::alloc::ProcessAllocator; use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAlloc; use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocHost; use hyperactor_mesh::alloc::remoteprocess::RemoteProcessAllocInitializer; +use hyperactor_mesh::alloc::sim::SimAllocator; use hyperactor_mesh::shape::Shape; use ndslice::Slice; use pyo3::exceptions::PyRuntimeError; @@ -222,6 +223,53 @@ impl PyLocalAllocator { } } +#[pyclass( + name = "SimAllocatorBase", + module = "monarch._rust_bindings.monarch_hyperactor.alloc", + subclass +)] +pub struct PySimAllocator; + +#[pymethods] +impl PySimAllocator { + #[new] + fn new() -> Self { + PySimAllocator {} + } + + fn allocate_nonblocking<'py>( + &self, + py: Python<'py>, + spec: &PyAllocSpec, + ) -> PyResult> { + // We could use Bound here, and acquire the GIL inside of `future_into_py`, but + // it is rather awkward with the current APIs, and we can anyway support Arc/Mutex + // pretty easily. + let spec = spec.inner.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + SimAllocator + .allocate(spec) + .await + .map(|inner| PyAlloc::new(Box::new(inner))) + .map_err(|e| PyRuntimeError::new_err(format!("{}", e))) + }) + } + + fn allocate_blocking<'py>(&self, py: Python<'py>, spec: &PyAllocSpec) -> PyResult { + // We could use Bound here, and acquire the GIL inside of + // `signal_safe_block_on`, but it is rather awkward with the current + // APIs, and we can anyway support Arc/Mutex pretty easily. + let spec = spec.inner.clone(); + signal_safe_block_on(py, async move { + SimAllocator + .allocate(spec) + .await + .map(|inner| PyAlloc::new(Box::new(inner))) + .map_err(|e| PyRuntimeError::new_err(format!("{:?}", e))) + })? + } +} + #[pyclass( name = "ProcessAllocatorBase", module = "monarch._rust_bindings.monarch_hyperactor.alloc", @@ -474,6 +522,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) diff --git a/monarch_hyperactor/src/channel.rs b/monarch_hyperactor/src/channel.rs index ef7522a2..9fd6f055 100644 --- a/monarch_hyperactor/src/channel.rs +++ b/monarch_hyperactor/src/channel.rs @@ -25,7 +25,7 @@ pub enum PyChannelTransport { MetaTls, Local, Unix, - // Sim(/*proxy address:*/ ChannelAddr), TODO kiuk@ add support + // Sim(/*transport:*/ ChannelTransport), TODO kiuk@ add support } #[pyclass( @@ -131,9 +131,7 @@ mod tests { #[test] fn test_channel_unsupported_transport() -> PyResult<()> { - let sim_addr = ChannelAddr::any(ChannelTransport::Sim(ChannelAddr::any( - ChannelTransport::Unix, - ))); + let sim_addr = ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); let addr = PyChannelAddr { inner: sim_addr }; assert!(addr.get_port().is_err()); diff --git a/monarch_simulator/Cargo.toml b/monarch_simulator/Cargo.toml index 863b5384..e28372ed 100644 --- a/monarch_simulator/Cargo.toml +++ b/monarch_simulator/Cargo.toml @@ -1,4 +1,4 @@ -# @generated by autocargo from //monarch/monarch_simulator:[monarch_simulator,monarch_simulator_lib] +# @generated by autocargo from //monarch/monarch_simulator:monarch_simulator_lib [package] name = "monarch_simulator_lib" @@ -7,14 +7,9 @@ authors = ["Meta"] edition = "2021" license = "BSD-3-Clause" -[[bin]] -name = "monarch_simulator" -path = "src/main.rs" - [dependencies] anyhow = "1.0.98" async-trait = "0.1.86" -clap = { version = "4.5.38", features = ["derive", "env", "string", "unicode", "wrap_help"] } controller = { version = "0.0.0", path = "../controller" } dashmap = { version = "5.5.3", features = ["rayon", "serde"] } futures = { version = "0.3.30", features = ["async-await", "compat"] } diff --git a/monarch_simulator/src/bootstrap.rs b/monarch_simulator/src/bootstrap.rs index e946fd63..40b01d11 100644 --- a/monarch_simulator/src/bootstrap.rs +++ b/monarch_simulator/src/bootstrap.rs @@ -18,25 +18,16 @@ use hyperactor::ActorRef; use hyperactor::ProcId; use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; -use hyperactor::simnet; -use hyperactor::simnet::OperationalMessage; -use hyperactor::simnet::SpawnMesh; -use hyperactor::simnet::simnet_handle; use hyperactor_multiprocess::System; use hyperactor_multiprocess::proc_actor::ProcActor; use hyperactor_multiprocess::proc_actor::spawn; use hyperactor_multiprocess::system::ServerHandle; use hyperactor_multiprocess::system_actor::ProcLifecycleMode; use monarch_messages::worker::Factory; -use tokio::sync::Mutex; -use tokio::sync::mpsc::UnboundedReceiver; -use tokio::task::JoinHandle; use torch_sys::Layout; use torch_sys::ScalarType; -use crate::simulator::Simulator; use crate::worker::Fabric; use crate::worker::MockWorkerParams; use crate::worker::WorkerActor; @@ -70,15 +61,7 @@ pub async fn spawn_controller( panic!("bootstrap_addr must be a SimAddr"); }; let bootstrap_addr = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: listen_addr.clone(), - proxy: bootstrap_addr.proxy().clone(), - }, - bootstrap_addr.addr().clone(), - bootstrap_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(listen_addr.clone(), bootstrap_addr.addr().clone()).unwrap(), ); tracing::info!( "controller listen addr: {}, bootstrap addr: {}", @@ -129,15 +112,7 @@ pub async fn spawn_sim_worker( panic!("bootstrap_addr must be a SimAddr"); }; let bootstrap_addr = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: listen_addr.clone(), - proxy: bootstrap_addr.proxy().clone(), - }, - bootstrap_addr.addr().clone(), - bootstrap_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(listen_addr.clone(), bootstrap_addr.addr().clone()).unwrap(), ); tracing::info!( "worker {} listen addr: {}, bootstrap addr: {}", @@ -184,61 +159,3 @@ pub async fn spawn_sim_worker( .await?; Ok(bootstrap.proc_actor) } - -/// Bootstrap the simulation. Spawns the system, controllers, and workers. -/// Args: -/// system_addr: The address of the system actor. -pub async fn bootstrap( - system_addr: ChannelAddr, - proxy_addr: ChannelAddr, - world_size: usize, -) -> Result> { - // TODO: enable supervision events. - let mut operational_message_rx = simnet::start(system_addr.clone(), proxy_addr, 1000)?; - let simulator = Arc::new(Mutex::new(Simulator::new(system_addr).await?)); - let operational_listener_handle = { - let simulator = simulator.clone(); - tokio::spawn(async move { - handle_operational_message(&mut operational_message_rx, simulator, world_size).await - }) - }; - - Ok(operational_listener_handle) -} - -async fn handle_operational_message( - operational_message_rx: &mut UnboundedReceiver, - simulator: Arc>, - world_size: usize, -) { - while let Some(msg) = operational_message_rx.recv().await { - tracing::info!("received operational message: {:?}", msg); - match msg { - OperationalMessage::SpawnMesh(SpawnMesh { - system_addr, - controller_actor_id, - worker_world, - }) => { - if let Err(e) = simulator - .lock() - .await - .spawn_mesh(system_addr, controller_actor_id, worker_world, world_size) - .await - { - tracing::error!("failed to spawn mesh: {:?}", e); - } - } - OperationalMessage::KillWorld(world_id) => { - if let Err(e) = simulator.lock().await.kill_world(&world_id) { - tracing::error!("failed to kill world: {:?}", e); - } - } - OperationalMessage::SetTrainingScriptState(state) => match simnet_handle() { - Ok(handle) => handle.set_training_script_state(state), - Err(e) => { - tracing::error!("failed to set training script state: {:?}", e); - } - }, - } - } -} diff --git a/monarch_simulator/src/main.rs b/monarch_simulator/src/main.rs deleted file mode 100644 index 12169d91..00000000 --- a/monarch_simulator/src/main.rs +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -//! A binary to launch the simulated Monarch controller along with necessary environment. -use std::process::ExitCode; - -use anyhow::Result; -use clap::Parser; -use hyperactor::channel::ChannelAddr; -use monarch_simulator_lib::bootstrap::bootstrap; - -#[derive(Debug, Parser)] -struct Args { - #[arg(short, long)] - system_addr: ChannelAddr, - #[arg(short, long)] - proxy_addr: ChannelAddr, -} - -const TITLE: &str = r#" -****************************************************** -* * -* ____ ___ __ __ _ _ _ _ _____ ___ ____ * -*/ ___|_ _| \/ | | | | | / \|_ _/ _ \| _ \ * -*\___ \| || |\/| | | | | | / _ \ | || | | | |_) |* -* ___) | || | | | |_| | |___ / ___ \| || |_| | _ < * -*|____/___|_| |_|\___/|_____/_/ \_\_| \___/|_| \_\* -* * -****************************************************** -"#; - -#[tokio::main] -async fn main() -> Result { - eprintln!("{}", TITLE); - hyperactor::initialize_with_current_runtime(); - let args = Args::parse(); - - let system_addr = args.system_addr.clone(); - let proxy_addr = args.proxy_addr.clone(); - tracing::info!("starting Monarch simulation"); - - let operational_listener_handle = bootstrap(system_addr, proxy_addr, 1).await?; - - operational_listener_handle - .await - .expect("simulator exited unexpectedly"); - - Ok(ExitCode::SUCCESS) -} diff --git a/monarch_simulator/src/simulator.rs b/monarch_simulator/src/simulator.rs index 10fa9fd7..a9b9baf7 100644 --- a/monarch_simulator/src/simulator.rs +++ b/monarch_simulator/src/simulator.rs @@ -26,13 +26,13 @@ use crate::bootstrap::spawn_system; /// The simulator manages all of the meshes and the system handle. #[derive(Debug)] -pub struct Simulator { +pub struct TensorEngineSimulator { /// A map from world name to actor handles in that world. worlds: HashMap>>, system_handle: ServerHandle, } -impl Simulator { +impl TensorEngineSimulator { pub async fn new(system_addr: ChannelAddr) -> Result { Ok(Self { worlds: HashMap::new(), @@ -92,7 +92,7 @@ impl Simulator { /// IntoFuture allows users to await the handle. The future resolves when /// the simulator itself has all of the actor handles stopped. -impl IntoFuture for Simulator { +impl IntoFuture for TensorEngineSimulator { type Output = (); type IntoFuture = BoxFuture<'static, Self::Output>; @@ -133,19 +133,12 @@ mod tests { #[tracing_test::traced_test] #[tokio::test] async fn test_spawn_and_kill_mesh() { - let proxy = format!("unix!@{}", random_str()); + simnet::start(); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.parse().unwrap(), - 1000, - ) - .unwrap(); - - let system_addr = format!("sim!unix!@system,{}", &proxy) - .parse::() + let system_addr = "sim!unix!@system".parse::().unwrap(); + let mut simulator = super::TensorEngineSimulator::new(system_addr.clone()) + .await .unwrap(); - let mut simulator = super::Simulator::new(system_addr.clone()).await.unwrap(); let mut controller_actor_ids = vec![]; let mut worker_actor_ids = vec![]; let n_meshes = 2; diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index 4d662e2d..61114f7c 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -809,12 +809,7 @@ mod tests { #[tokio::test] async fn worker_reduce() -> Result<()> { - simnet::start( - "local!0".parse::().unwrap(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + simnet::start(); let proc = Proc::local(); //let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller")?; let client = proc.attach("client")?; diff --git a/python/monarch/__init__.py b/python/monarch/__init__.py index 7a0b1330..92d1795b 100644 --- a/python/monarch/__init__.py +++ b/python/monarch/__init__.py @@ -113,6 +113,8 @@ "timer": ("monarch.timer", "timer"), "ProcessAllocator": ("monarch._src.actor.allocator", "ProcessAllocator"), "LocalAllocator": ("monarch._src.actor.allocator", "LocalAllocator"), + "SimAllocator": ("monarch._src_actor.allocator", "SimAllocator"), + "ActorFuture": ("monarch.future", "ActorFuture"), "builtins": ("monarch.builtins", "builtins"), } @@ -181,6 +183,8 @@ def __getattr__(name): "timer", "ProcessAllocator", "LocalAllocator", + "SimAllocator", + "ActorFuture", "builtins", ] assert sorted(__all__) == sorted(_public_api) diff --git a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi index 9cea9d93..6b06c274 100644 --- a/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/simulator_client.pyi @@ -13,10 +13,11 @@ class SimulatorClient: It is a client to communicate with the simulator service. Arguments: - - `proxy_addr`: Address of the simulator's proxy server. + - `system_addr`: Address of the system. + - `world_size`: Number of workers in a given mesh. """ - def __init__(self, proxy_addr: str) -> None: ... + def __init__(self, system_addr: str, world_size: int) -> None: ... def kill_world(self, world_name: str) -> None: """ Kill the world with the given name. @@ -51,11 +52,3 @@ class SimulatorClient: backend to resolve a future """ ... - -def bootstrap_simulator_backend( - system_addr: str, proxy_addr: str, world_size: int -) -> None: - """ - Bootstrap the simulator backend on the current process - """ - ... diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi index 06b1e33d..a58cb667 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi @@ -100,6 +100,26 @@ class LocalAllocatorBase: """ ... +class SimAllocatorBase: + async def allocate_nonblocking(self, spec: AllocSpec) -> Alloc: + """ + Allocate a process according to the provided spec. + + Arguments: + - `spec`: The spec to allocate according to. + """ + ... + + def allocate_blocking(self, spec: AllocSpec) -> Alloc: + """ + Allocate a process according to the provided spec, blocking until an + alloc is returned. + + Arguments: + - `spec`: The spec to allocate according to. + """ + ... + class RemoteAllocatorBase: def __new__( cls, diff --git a/python/monarch/_src/actor/allocator.py b/python/monarch/_src/actor/allocator.py index 2df6dc98..e27c83fe 100644 --- a/python/monarch/_src/actor/allocator.py +++ b/python/monarch/_src/actor/allocator.py @@ -16,6 +16,7 @@ LocalAllocatorBase, ProcessAllocatorBase, RemoteAllocatorBase, + SimAllocatorBase, ) from monarch._src.actor.future import Future @@ -69,6 +70,28 @@ def allocate(self, spec: AllocSpec) -> Future[Alloc]: ) +@final +class SimAllocator(SimAllocatorBase): + """ + An allocator that allocates by spawning actors into the current process using simulated channels for transport + """ + + def allocate(self, spec: AllocSpec) -> Future[Alloc]: + """ + Allocate a process according to the provided spec. + + Arguments: + - `spec`: The spec to allocate according to. + + Returns: + - A future that will be fulfilled when the requested allocation is fulfilled. + """ + return Future( + lambda: self.allocate_nonblocking(spec), + lambda: self.allocate_blocking(spec), + ) + + class RemoteAllocInitializer(abc.ABC): """Subclass-able Python interface for `hyperactor_mesh::alloc::remoteprocess:RemoteProcessAllocInitializer`. diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 483ddf00..c14e2499 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -35,7 +35,7 @@ ) from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice from monarch._src.actor.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef -from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator +from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator, SimAllocator from monarch._src.actor.code_sync import RsyncMeshClient, WorkspaceLocation from monarch._src.actor.code_sync.auto_reload import AutoReloadActor @@ -312,6 +312,33 @@ def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[Pro ) +async def sim_proc_mesh_nonblocking( + *, gpus: Optional[int] = None, hosts: int = 1 +) -> ProcMesh: + if gpus is None: + gpus = _local_device_count() + spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) + allocator = SimAllocator() + alloc = await allocator.allocate(spec) + return await ProcMesh.from_alloc(alloc) + + +def sim_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh: + if gpus is None: + gpus = _local_device_count() + spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) + allocator = SimAllocator() + alloc = allocator.allocate(spec).get() + return ProcMesh.from_alloc(alloc).get() + + +def sim_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]: + return Future( + lambda: sim_proc_mesh_nonblocking(gpus=gpus, hosts=hosts), + lambda: sim_proc_mesh_blocking(gpus=gpus, hosts=hosts), + ) + + _BOOTSTRAP_MAIN = "monarch._src.actor.bootstrap_main" diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py index a2720198..78fbcd1d 100644 --- a/python/monarch/actor/__init__.py +++ b/python/monarch/actor/__init__.py @@ -23,7 +23,12 @@ ValueMesh, ) from monarch._src.actor.future import Future -from monarch._src.actor.proc_mesh import local_proc_mesh, proc_mesh, ProcMesh +from monarch._src.actor.proc_mesh import ( + local_proc_mesh, + proc_mesh, + ProcMesh, + sim_proc_mesh, +) __all__ = [ "Accumulator", @@ -41,5 +46,6 @@ "proc_mesh", "ProcMesh", "send", + "sim_proc_mesh", "ValueMesh", ] diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index b91361bf..2e999190 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -31,7 +31,6 @@ ) from monarch._rust_bindings.monarch_extension.simulator_client import ( # @manual=//monarch/monarch_extension:monarch_extension - bootstrap_simulator_backend, SimulatorClient, ) @@ -59,9 +58,7 @@ logger: logging.Logger = logging.getLogger(__name__) -def sim_mesh( - n_meshes: int, hosts: int, gpus_per_host: int, proxy_addr: Optional[str] = None -) -> List[DeviceMesh]: +def sim_mesh(n_meshes: int, hosts: int, gpus_per_host: int) -> List[DeviceMesh]: """ Creates a single simulated device mesh with the given number of per host. @@ -76,7 +73,6 @@ def sim_mesh( bootstrap: Bootstrap = Bootstrap( n_meshes, mesh_world_state, - proxy_addr=proxy_addr, world_size=hosts * gpus_per_host, ) @@ -181,14 +177,12 @@ def __init__( self, num_meshes: int, mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]], - proxy_addr: Optional[str] = None, world_size: int = 1, ) -> None: """ Bootstraps a SimMesh. Args: num_meshes: int - number of meshes to create. - proxy_addr: Option[str] - the proxy address of the simulation process mesh_world_state: a state of the meshes. Keys are the MeshWorld and values are boolean indicating if this mesh is active. """ # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later @@ -199,17 +193,11 @@ def __init__( self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state - proxy_addr = proxy_addr or f"unix!@{_random_id()}-proxy" - self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}" + self.bootstrap_addr: str = "sim!unix!@system" + self.client_listen_addr = "sim!unix!@client" + self.client_bootstrap_addr = "sim!unix!@client,unix!@system" - client_proxy_addr = f"unix!@{_random_id()}-proxy" - self.client_listen_addr: str = f"sim!unix!@client,{client_proxy_addr}" - self.client_bootstrap_addr: str = ( - f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}" - ) - bootstrap_simulator_backend(self.bootstrap_addr, proxy_addr, world_size) - - self._simulator_client = SimulatorClient(proxy_addr) + self._simulator_client = SimulatorClient(self.bootstrap_addr, world_size) for i in range(num_meshes): mesh_name: str = f"mesh_{i}" controller_world: str = f"{mesh_name}_controller" @@ -235,7 +223,9 @@ def spawn_mesh(self, mesh_world: MeshWorld) -> None: worker_world, controller_id = mesh_world controller_world = controller_id.world_name self._simulator_client.spawn_mesh( - self.bootstrap_addr, f"{controller_world}[0].root", worker_world + self.bootstrap_addr, + f"{controller_world}[0].root", + worker_world, ) diff --git a/python/tests/test_sim_backend.py b/python/tests/test_sim_backend.py index 1ea89cd5..489bc736 100644 --- a/python/tests/test_sim_backend.py +++ b/python/tests/test_sim_backend.py @@ -24,11 +24,8 @@ def local_sim_mesh( # TODO: support multiple gpus in a mesh. gpu_per_host: int = 1, activate: bool = True, - proxy_addr: Optional[str] = None, ) -> Generator[DeviceMesh, None, None]: - dms = sim_mesh( - n_meshes=1, hosts=hosts, gpus_per_host=gpu_per_host, proxy_addr=proxy_addr - ) + dms = sim_mesh(n_meshes=1, hosts=hosts, gpus_per_host=gpu_per_host) dm = dms[0] try: if activate: