diff --git a/controller/src/lib.rs b/controller/src/lib.rs index a2d1f54d..ce21cc50 100644 --- a/controller/src/lib.rs +++ b/controller/src/lib.rs @@ -624,7 +624,6 @@ mod tests { use hyperactor::RefClient; use hyperactor::channel; use hyperactor::channel::ChannelTransport; - use hyperactor::channel::sim::SimAddr; use hyperactor::clock::Clock; use hyperactor::clock::RealClock; use hyperactor::data::Named; @@ -1559,20 +1558,13 @@ mod tests { #[tokio::test] async fn test_sim_supervision_failure() { // Start system actor. - let system_addr = ChannelAddr::any(ChannelTransport::Unix); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + simnet::start(); simnet::simnet_handle() .unwrap() .set_training_script_state(simnet::TrainingScriptState::Waiting); let system_sim_addr = - ChannelAddr::Sim(SimAddr::new(system_addr, proxy_addr.clone()).unwrap()); + ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); // Set very long supervision_update_timeout let server_handle = System::serve( system_sim_addr.clone(), @@ -1588,9 +1580,8 @@ mod tests { // Bootstrap the controller let controller_id = id!(controller[0].root); let proc_id = id!(world[0]); - let controller_proc_listen_addr = ChannelAddr::Sim( - SimAddr::new(ChannelAddr::any(ChannelTransport::Unix), proxy_addr).unwrap(), - ); + let controller_proc_listen_addr = + ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); let (_, actor_ref) = ControllerActor::bootstrap( controller_id.clone(), diff --git a/hyperactor/src/channel.rs b/hyperactor/src/channel.rs index b98ec1f7..cdf9e662 100644 --- a/hyperactor/src/channel.rs +++ b/hyperactor/src/channel.rs @@ -236,7 +236,7 @@ pub enum ChannelTransport { Local, /// Sim is a simulated channel for testing. - Sim(/*proxy address:*/ ChannelAddr), + Sim(/*simulated transport:*/ Box), /// Transport over unix domain socket. Unix, @@ -368,7 +368,7 @@ impl ChannelAddr { Self::MetaTls(hostname, 0) } ChannelTransport::Local => Self::Local(0), - ChannelTransport::Sim(proxy) => sim::any(proxy), + ChannelTransport::Sim(transport) => sim::any(*transport), // This works because the file will be deleted but we know we have a unique file by this point. ChannelTransport::Unix => Self::Unix(net::unix::SocketAddr::from_str("").unwrap()), } @@ -380,7 +380,7 @@ impl ChannelAddr { Self::Tcp(_) => ChannelTransport::Tcp, Self::MetaTls(_, _) => ChannelTransport::MetaTls, Self::Local(_) => ChannelTransport::Local, - Self::Sim(addr) => ChannelTransport::Sim(addr.proxy().clone()), + Self::Sim(addr) => ChannelTransport::Sim(Box::new(addr.transport())), Self::Unix(_) => ChannelTransport::Unix, } } @@ -637,10 +637,9 @@ mod tests { } for (raw, parsed) in cases_ok.iter().zip(src_ok.clone()).map(|(a, _)| { - let proxy_str = "unix!@proxy_a"; ( - format!("sim!{},{}", a.0, &proxy_str), - ChannelAddr::Sim(SimAddr::new(a.1.clone(), proxy_str.parse().unwrap()).unwrap()), + format!("sim!{}", a.0), + ChannelAddr::Sim(SimAddr::new(a.1.clone()).unwrap()), ) }) { assert_eq!(raw.parse::().unwrap(), parsed); diff --git a/hyperactor/src/channel/sim.rs b/hyperactor/src/channel/sim.rs index be97ca4b..73a663a5 100644 --- a/hyperactor/src/channel/sim.rs +++ b/hyperactor/src/channel/sim.rs @@ -39,25 +39,6 @@ lazy_static! { } static SIM_LINK_BUF_SIZE: usize = 256; -#[derive( - Clone, - Debug, - PartialEq, - Eq, - Serialize, - Deserialize, - Ord, - PartialOrd, - Hash -)] -/// A channel address along with the address of the proxy for the process -pub struct AddressProxyPair { - /// The address. - pub address: ChannelAddr, - /// The address of the proxy for the process - pub proxy: ChannelAddr, -} - /// An address for a simulated channel. #[derive( Clone, @@ -71,36 +52,38 @@ pub struct AddressProxyPair { Hash )] pub struct SimAddr { - src: Option>, + src: Option>, /// The address. addr: Box, - /// The proxy address. - proxy: Box, + /// If source is the client + client: bool, } impl SimAddr { /// Creates a new SimAddr. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. /// Creates a new SimAddr without a source to be served - pub fn new(addr: ChannelAddr, proxy: ChannelAddr) -> Result { - Self::new_impl(None, addr, proxy) + pub fn new(addr: ChannelAddr) -> Result { + Self::new_impl(None, addr, false) } /// Creates a new directional SimAddr meant to convey a channel between two addresses. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. - pub fn new_with_src( - src: AddressProxyPair, - addr: ChannelAddr, - proxy: ChannelAddr, - ) -> Result { - Self::new_impl(Some(Box::new(src)), addr, proxy) + pub fn new_with_src(src: ChannelAddr, addr: ChannelAddr) -> Result { + Self::new_impl(Some(Box::new(src)), addr, false) + } + + /// Creates a new directional SimAddr meant to convey a channel between two addresses. + #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. + fn new_with_client_src(src: ChannelAddr, addr: ChannelAddr) -> Result { + Self::new_impl(Some(Box::new(src)), addr, true) } #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. fn new_impl( - src: Option>, + src: Option>, addr: ChannelAddr, - proxy: ChannelAddr, + client: bool, ) -> Result { if let ChannelAddr::Sim(_) = &addr { return Err(SimNetError::InvalidArg(format!( @@ -108,16 +91,10 @@ impl SimAddr { addr ))); } - if let ChannelAddr::Sim(_) = &proxy { - return Err(SimNetError::InvalidArg(format!( - "proxy cannot be a sim address, found {}", - proxy - ))); - } Ok(Self { src, addr: Box::new(addr), - proxy: Box::new(proxy), + client, }) } @@ -126,26 +103,22 @@ impl SimAddr { &self.addr } - /// Returns the proxy address. - pub fn proxy(&self) -> &ChannelAddr { - &self.proxy + /// Returns the source address + pub fn src(&self) -> &Option> { + &self.src } - /// Returns the source address and proxy. - pub fn src(&self) -> &Option> { - &self.src + /// The underlying transport we are simulating + pub fn transport(&self) -> ChannelTransport { + self.addr.transport() } } impl fmt::Display for SimAddr { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match &self.src { - None => write!(f, "{},{}", self.addr, self.proxy), - Some(src) => write!( - f, - "{},{},{},{}", - src.address, src.proxy, self.addr, self.proxy - ), + None => write!(f, "{}", self.addr), + Some(src) => write!(f, "{},{}", src, self.addr), } } } @@ -153,19 +126,15 @@ impl fmt::Display for SimAddr { /// Message Event that can be passed around in the simnet. #[derive(Debug)] pub(crate) struct MessageDeliveryEvent { - src_addr: Option, - dest_addr: AddressProxyPair, + src_addr: Option, + dest_addr: ChannelAddr, data: Serialized, duration_ms: u64, } impl MessageDeliveryEvent { /// Creates a new MessageDeliveryEvent. - pub fn new( - src_addr: Option, - dest_addr: AddressProxyPair, - data: Serialized, - ) -> Self { + pub fn new(src_addr: Option, dest_addr: ChannelAddr, data: Serialized) -> Self { Self { src_addr, dest_addr, @@ -198,16 +167,16 @@ impl Event for MessageDeliveryEvent { "Sending message from {} to {}", self.src_addr .as_ref() - .map_or("unknown".to_string(), |addr| addr.address.to_string()), - self.dest_addr.address.clone() + .map_or("unknown".to_string(), |addr| addr.to_string()), + self.dest_addr.clone() ) } async fn read_simnet_config(&mut self, topology: &Arc>) { if let Some(src_addr) = &self.src_addr { let edge = SimNetEdge { - src: src_addr.address.clone(), - dst: self.dest_addr.address.clone(), + src: src_addr.clone(), + dst: self.dest_addr.clone(), }; self.duration_ms = topology .lock() @@ -232,18 +201,18 @@ pub async fn update_config(config: simnet::NetworkConfig) -> anyhow::Result<(), } /// Returns a simulated channel address that is bound to "any" channel address. -pub(crate) fn any(proxy: ChannelAddr) -> ChannelAddr { +pub(crate) fn any(transport: ChannelTransport) -> ChannelAddr { ChannelAddr::Sim(SimAddr { src: None, - addr: Box::new(ChannelAddr::any(proxy.transport())), - proxy: Box::new(proxy), + addr: Box::new(ChannelAddr::any(transport)), + client: false, }) } /// Parse the sim channel address. It should have two non-sim channel addresses separated by a comma. #[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `ChannelError`. pub fn parse(addr_string: &str) -> Result { - let re = Regex::new(r"([^,]+),([^,]+)(,([^,]+),([^,]+))?$").map_err(|err| { + let re = Regex::new(r"([^,]+)(,([^,]+))?$").map_err(|err| { ChannelError::InvalidAddress(format!("invalid sim address regex: {}", err)) })?; @@ -261,26 +230,19 @@ pub fn parse(addr_string: &str) -> Result { } match parts.len() { - 2 => { + 1 => { let addr = parts[0].parse::()?; - let proxy = parts[1].parse::()?; - Ok(ChannelAddr::Sim(SimAddr::new(addr, proxy)?)) + Ok(ChannelAddr::Sim(SimAddr::new(addr)?)) } - 5 => { + 3 => { let src_addr = parts[0].parse::()?; - let src_proxy = parts[1].parse::()?; - let addr = parts[3].parse::()?; - let proxy = parts[4].parse::()?; - - Ok(ChannelAddr::Sim(SimAddr::new_with_src( - AddressProxyPair { - address: src_addr, - proxy: src_proxy, - }, - addr, - proxy, - )?)) + let addr = parts[2].parse::()?; + Ok(ChannelAddr::Sim(if parts[0] == "client" { + SimAddr::new_with_client_src(src_addr, addr) + } else { + SimAddr::new_with_src(src_addr, addr) + }?)) } _ => Err(ChannelError::InvalidAddress(addr_string.to_string())), } @@ -310,31 +272,22 @@ fn create_egress_sender( Ok(Arc::new(tx)) } -/// Check if the address is outside of the simulation. -#[allow(clippy::result_large_err)] // TODO: Consider reducing the size of `SimNetError`. -fn is_external_addr(addr: &AddressProxyPair) -> anyhow::Result { - Ok(simnet_handle()?.proxy_addr() != &addr.proxy) -} - #[async_trait] -impl Dispatcher for SimDispatcher { +impl Dispatcher for SimDispatcher { async fn send( &self, - _src_addr: Option, - addr: AddressProxyPair, + _src_addr: Option, + addr: ChannelAddr, data: Serialized, ) -> Result<(), SimNetError> { self.dispatchers - .get(&addr.address) + .get(&addr) .ok_or_else(|| { - SimNetError::InvalidNode( - addr.address.to_string(), - anyhow::anyhow!("no dispatcher found"), - ) + SimNetError::InvalidNode(addr.to_string(), anyhow::anyhow!("no dispatcher found")) })? .send(data) .await - .map_err(|err| SimNetError::InvalidNode(addr.address.to_string(), err.into())) + .map_err(|err| SimNetError::InvalidNode(addr.to_string(), err.into())) } } @@ -349,9 +302,10 @@ impl Default for SimDispatcher { #[derive(Debug)] pub(crate) struct SimTx { - src_addr: Option, - dst_addr: AddressProxyPair, + src_addr: Option, + dst_addr: ChannelAddr, status: watch::Receiver, // Default impl. Always reports `Active`. + client: bool, _phantom: PhantomData, } @@ -372,15 +326,14 @@ impl Tx for SimTx { }; match simnet_handle() { Ok(handle) => match &self.src_addr { - Some(src_addr) if src_addr.proxy != *handle.proxy_addr() => handle - .send_scheduled_event(ScheduledEvent { - event: Box::new(MessageDeliveryEvent::new( - self.src_addr.clone(), - self.dst_addr.clone(), - data, - )), - time: SimClock.millis_since_start(RealClock.now()), - }), + Some(_) if self.client => handle.send_scheduled_event(ScheduledEvent { + event: Box::new(MessageDeliveryEvent::new( + self.src_addr.clone(), + self.dst_addr.clone(), + data, + )), + time: SimClock.millis_since_start(RealClock.now()), + }), _ => handle.send_event(Box::new(MessageDeliveryEvent::new( self.src_addr.clone(), self.dst_addr.clone(), @@ -393,7 +346,7 @@ impl Tx for SimTx { } fn addr(&self) -> ChannelAddr { - self.dst_addr.address.clone() + self.dst_addr.clone() } fn status(&self) -> &watch::Receiver { @@ -412,11 +365,9 @@ pub(crate) fn dial(addr: SimAddr) -> Result, ChannelE Ok(SimTx { src_addr: dialer, - dst_addr: AddressProxyPair { - address: *addr.addr, - proxy: *addr.proxy, - }, + dst_addr: addr.addr().clone(), status, + client: addr.client, _phantom: PhantomData, }) } @@ -425,13 +376,10 @@ pub(crate) fn dial(addr: SimAddr) -> Result, ChannelE /// The mpsc tx will be used to dispatch messages when it's time while /// the mpsc rx will be used by the above applications to handle received messages /// like any other channel. -/// A sim address has src and dst. Dispatchers are only indexed by dst address. +/// A sim address has a dst and optional src. Dispatchers are only indexed by dst address. pub(crate) fn serve( sim_addr: SimAddr, ) -> anyhow::Result<(ChannelAddr, SimRx)> { - // Serves sim address at sim_addr.src and set up local proxy at sim_addr.src_proxy. - // Reversing the src and dst since the first element in the output tuple is the - // dialing address of this sim channel. So the served address is the dst. let (tx, rx) = mpsc::channel::(SIM_LINK_BUF_SIZE); // Add tx to sender dispatch. SENDER.dispatchers.insert(*sim_addr.addr.clone(), tx); @@ -474,13 +422,7 @@ mod tests { let dst_ok = vec!["[::1]:1234", "tcp!127.0.0.1:8080", "local!123"]; let srcs_ok = vec!["[::2]:1234", "tcp!127.0.0.2:8080", "local!124"]; - let proxy = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); // TODO: New NodeAdd event should do this for you.. for addr in dst_ok.iter().chain(srcs_ok.iter()) { @@ -490,16 +432,10 @@ mod tests { .bind(addr.parse::().unwrap()) .unwrap(); } - // Messages are transferred internally if only there's a local proxy and the - // dst proxy is the same as local proxy. for (src_addr, dst_addr) in zip(srcs_ok, dst_ok) { let dst_addr = SimAddr::new_with_src( - AddressProxyPair { - address: src_addr.parse::().unwrap(), - proxy: proxy.clone(), - }, + src_addr.parse::().unwrap(), dst_addr.parse::().unwrap(), - proxy.clone(), ) .unwrap(); @@ -517,22 +453,14 @@ mod tests { async fn test_invalid_sim_addr() { let src = "sim!src"; let dst = "sim!dst"; - let src_proxy = "sim!src_proxy"; - let dst_proxy = "sim!dst_proxy"; - let sim_addr = format!("{},{},{},{}", src, src_proxy, dst, dst_proxy); + let sim_addr = format!("{},{}", src, dst); let result = parse(&sim_addr); assert!(matches!(result, Err(ChannelError::InvalidAddress(_)))); - - let dst = "unix!dst".parse::().unwrap(); - let dst_proxy = "sim!unix!a,unix!b".parse::().unwrap(); - let result = SimAddr::new(dst, dst_proxy); - // dst_proxy shouldn't be a sim address. - assert!(matches!(result, Err(SimNetError::InvalidArg(_)))); } #[tokio::test] async fn test_parse_sim_addr() { - let sim_addr = "sim!unix!@dst,unix!@proxy"; + let sim_addr = "sim!unix!@dst"; let result = sim_addr.parse(); assert!(result.is_ok()); let ChannelAddr::Sim(sim_addr) = result.unwrap() else { @@ -540,42 +468,26 @@ mod tests { }; assert!(sim_addr.src().is_none()); assert_eq!(sim_addr.addr().to_string(), "unix!@dst"); - assert_eq!(sim_addr.proxy().to_string(), "unix!@proxy"); - let sim_addr = "sim!unix!@src,unix!@proxy,unix!@dst,unix!@proxy"; + let sim_addr = "sim!unix!@src,unix!@dst"; let result = sim_addr.parse(); assert!(result.is_ok()); let ChannelAddr::Sim(sim_addr) = result.unwrap() else { panic!("Expected a sim address"); }; assert!(sim_addr.src().is_some()); - let src_pair = sim_addr.src().clone().unwrap(); - assert_eq!(src_pair.address.to_string(), "unix!@src"); - assert_eq!(src_pair.proxy.to_string(), "unix!@proxy"); assert_eq!(sim_addr.addr().to_string(), "unix!@dst"); - assert_eq!(sim_addr.proxy().to_string(), "unix!@proxy"); } #[tokio::test] async fn test_realtime_frontier() { - let proxy: ChannelAddr = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); tokio::time::pause(); - let sim_addr = - SimAddr::new("unix!@dst".parse::().unwrap(), proxy.clone()).unwrap(); + let sim_addr = SimAddr::new("unix!@dst".parse::().unwrap()).unwrap(); let sim_addr_with_src = SimAddr::new_with_src( - AddressProxyPair { - address: "unix!@src".parse::().unwrap(), - proxy: proxy.clone(), - }, + "unix!@src".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy.clone(), ) .unwrap(); let (_, mut rx) = sim::serve::<()>(sim_addr.clone()).unwrap(); @@ -612,31 +524,17 @@ mod tests { #[tokio::test] async fn test_client_message_scheduled_realtime() { tokio::time::pause(); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + start(); let controller_to_dst = SimAddr::new_with_src( - AddressProxyPair { - address: "unix!@controller".parse::().unwrap(), - proxy: proxy_addr.clone(), - }, + "unix!@controller".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy_addr.clone(), ) .unwrap(); let controller_tx = sim::dial::<()>(controller_to_dst.clone()).unwrap(); - let client_to_dst = SimAddr::new_with_src( - AddressProxyPair { - address: ChannelAddr::any(ChannelTransport::Unix), - proxy: ChannelAddr::any(ChannelTransport::Unix), - }, + let client_to_dst = SimAddr::new_with_client_src( + "unix!@client".parse::().unwrap(), "unix!@dst".parse::().unwrap(), - proxy_addr.clone(), ) .unwrap(); let client_tx = sim::dial::<()>(client_to_dst).unwrap(); diff --git a/hyperactor/src/clock.rs b/hyperactor/src/clock.rs index 9c8fcbc4..d9863e01 100644 --- a/hyperactor/src/clock.rs +++ b/hyperactor/src/clock.rs @@ -309,8 +309,7 @@ impl Clock for RealClock { #[cfg(test)] mod tests { - use crate::channel::ChannelAddr; - use crate::channel::ChannelTransport; + use crate::clock::Clock; use crate::clock::SimClock; use crate::simnet; @@ -341,12 +340,7 @@ mod tests { #[tokio::test] async fn test_sim_timeout() { - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + simnet::start(); let res = SimClock .timeout(tokio::time::Duration::from_secs(10), async { SimClock.sleep(tokio::time::Duration::from_secs(5)).await; diff --git a/hyperactor/src/mailbox.rs b/hyperactor/src/mailbox.rs index f661edca..d5207342 100644 --- a/hyperactor/src/mailbox.rs +++ b/hyperactor/src/mailbox.rs @@ -2360,7 +2360,6 @@ mod tests { use crate::channel::ChannelTransport; use crate::channel::dial; use crate::channel::serve; - use crate::channel::sim::AddressProxyPair; use crate::channel::sim::SimAddr; use crate::clock::Clock; use crate::clock::RealClock; @@ -2561,23 +2560,12 @@ mod tests { #[tokio::test] async fn test_sim_client_server() { - let proxy = ChannelAddr::any(channel::ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); - let dst_addr = - SimAddr::new("local!1".parse::().unwrap(), proxy.clone()).unwrap(); + simnet::start(); + let dst_addr = SimAddr::new("local!1".parse::().unwrap()).unwrap(); let src_to_dst = ChannelAddr::Sim( SimAddr::new_with_src( - AddressProxyPair { - address: "local!0".parse::().unwrap(), - proxy: proxy.clone(), - }, + "local!0".parse::().unwrap(), dst_addr.addr().clone(), - dst_addr.proxy().clone(), ) .unwrap(), ); diff --git a/hyperactor/src/simnet.rs b/hyperactor/src/simnet.rs index 9ae27d2d..28b1d1c2 100644 --- a/hyperactor/src/simnet.rs +++ b/hyperactor/src/simnet.rs @@ -26,7 +26,6 @@ use async_trait::async_trait; use dashmap::DashMap; use dashmap::DashSet; use enum_as_inner::EnumAsInner; -use futures::executor::block_on; use serde::Deserialize; use serde::Deserializer; use serde::Serialize; @@ -39,23 +38,16 @@ use tokio::sync::mpsc::UnboundedSender; use tokio::task::JoinError; use tokio::task::JoinHandle; use tokio::time::interval; -use tokio::time::timeout; -use crate as hyperactor; // for macros +// for macros use crate::ActorId; use crate::Mailbox; -use crate::Named; use crate::OncePortRef; -use crate::channel; use crate::channel::ChannelAddr; -use crate::channel::Rx; -use crate::channel::sim::AddressProxyPair; -use crate::channel::sim::MessageDeliveryEvent; use crate::clock::Clock; use crate::clock::RealClock; use crate::clock::SimClock; use crate::data::Serialized; -use crate::mailbox::MessageEnvelope; static HANDLE: OnceLock = OnceLock::new(); @@ -305,10 +297,6 @@ pub enum SimNetError { #[error("timeout after {} ms: {}", .0.as_millis(), .1)] Timeout(Duration, String), - /// External node is trying to connect but proxy is not available. - #[error("proxy not available: {0}")] - ProxyNotAvailable(String), - /// Cannot deliver the message because destination address is missing. #[error("missing destination address")] MissingDestinationAddress, @@ -342,9 +330,6 @@ pub struct SimNetHandle { event_tx: UnboundedSender<(Box, bool, Option)>, config: Arc>, pending_event_count: Arc, - /// Handle to a running proxy server that forwards external messages - /// into the simnet. - proxy_handle: ProxyHandle, /// A receiver to receive simulator operational messages. /// The receiver can be moved out of the simnet handle. training_script_state_tx: tokio::sync::watch::Sender, @@ -416,8 +401,6 @@ impl SimNetHandle { /// Close the simulator, processing pending messages before /// completing the returned future. pub async fn close(&self) -> Result, JoinError> { - // Stop the proxy if there is one. - self.proxy_handle.stop().await?; // Signal the simnet loop to stop self.stop_signal.store(true, Ordering::SeqCst); @@ -462,42 +445,10 @@ impl SimNetHandle { "timeout waiting for received events to be scheduled".to_string(), )) } - - /// Returns the external address of the simnet. - pub fn proxy_addr(&self) -> &ChannelAddr { - &self.proxy_handle.addr - } } pub(crate) type Topology = DashMap; -/// A ProxyMessage is a message that SimNet proxy receives. -/// The message may requests the SimNet to send the payload in the message field from -/// src to dst if addr field exists. -/// Or handle the payload in the message field if addr field is None, indicating that -/// this is a self-handlable message. -#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)] -pub struct ProxyMessage { - sender_addr: Option, - dest_addr: Option, - data: Serialized, -} - -impl ProxyMessage { - /// Creates a new ForwardMessage. - pub fn new( - sender_addr: Option, - dest_addr: Option, - data: Serialized, - ) -> Self { - Self { - sender_addr, - dest_addr, - data, - } - } -} - /// Configure network topology for the simnet pub struct SimNetConfig { // For now, we assume the network is fully connected @@ -510,31 +461,6 @@ pub struct SimNetConfig { /// The network is represented as a graph of nodes. /// The graph is represented as a map of edges. /// The network also has a cloud of inflight messages -/// SimNet also serves a proxy address to receive external traffic. This proxy address can handle -/// [`ProxyMessage`]s and forward the payload from src to dst. -/// -/// Example: -/// In this example, we send a ForwardMessage to the proxy_addr. SimNet will handle the message and -/// forward the payload from src to dst. -/// ```ignore -/// let nw_handle = start("local!0".parse().unwrap(), 1000, true, Some(gen_event_fcn)) -/// .await -/// .unwrap(); -/// let proxy_addr = nw_handle.proxy_addr().clone(); -/// let tx = crate::channel::dial(proxy_addr).unwrap(); -/// let src_to_dst_msg = MessageEnvelope::new_unknown( -/// port_id.clone(), -/// Serialized::serialize(&"hola".to_string()).unwrap(), -/// ); -/// let forward_message = ForwardMessage::new( -/// "unix!@src".parse::().unwrap(), -/// "unix!@dst".parse::().unwrap(), -/// src_to_dst_msg -/// ); -/// let external_message = -/// MessageEnvelope::new_unknown(port_id, Serialized::serialize(&forward_message).unwrap()); -/// tx.send(external_message).await.unwrap(); -/// ``` pub struct SimNet { config: Arc>, address_book: DashSet, @@ -545,103 +471,15 @@ pub struct SimNet { pending_event_count: Arc, } -/// A proxy to bridge external nodes and the SimNet. -struct ProxyHandle { - join_handle: Mutex>>, - stop_signal: Arc, - addr: ChannelAddr, -} - -impl ProxyHandle { - /// Starts an proxy server to handle external [`ForwardMessage`]s. It will forward the payload inside - /// the [`ForwardMessage`] from src to dst in the SimNet. - /// Args: - /// proxy_addr: address to listen - /// event_tx: a channel to send events to the SimNet - /// pending_event_count: a counter to keep track of the number of pending events - /// to_event: a function that specifies how to generate an Event from a forward message - async fn start( - proxy_addr: ChannelAddr, - event_tx: UnboundedSender<(Box, bool, Option)>, - pending_event_count: Arc, - ) -> anyhow::Result { - let (addr, mut rx) = channel::serve::(proxy_addr).await?; - tracing::info!("SimNet serving external traffic on {}", &addr); - let stop_signal = Arc::new(AtomicBool::new(false)); - - let join_handle = { - let stop_signal = stop_signal.clone(); - tokio::spawn(async move { - 'outer: loop { - // timeout the wait to enable stop signal checking at least every 100ms. - #[allow(clippy::disallowed_methods)] - if let Ok(Ok(msg)) = timeout(Duration::from_millis(100), rx.recv()).await { - let proxy_message: ProxyMessage = msg.deserialized().unwrap(); - if let Some(dest_addr) = proxy_message.dest_addr { - let event = Box::new(MessageDeliveryEvent::new( - proxy_message.sender_addr, - dest_addr, - proxy_message.data, - )); - if let Err(e) = event_tx.send((event, true, None)) { - tracing::error!("error sending message to simnet: {:?}", e); - } else { - pending_event_count - .fetch_add(1, std::sync::atomic::Ordering::SeqCst); - } - } - } - if stop_signal.load(Ordering::SeqCst) { - eprintln!("stopping external traffic handler"); - break 'outer; - } - } - }) - }; - Ok(Self { - join_handle: Mutex::new(Some(join_handle)), - stop_signal, - addr, - }) - } - - /// Stop the proxy. - async fn stop(&self) -> Result<(), JoinError> { - self.stop_signal.store(true, Ordering::SeqCst); - let mut guard = self.join_handle.lock().await; - if let Some(handle) = guard.take() { - handle.await - } else { - Ok(()) - } - } -} - /// Starts a sim net. /// Args: -/// private_addr: an internal address to receive operational messages such as NodeJoinEvent /// max_duration_ms: an optional config to override default settings of the network latency -/// enable_record: a flag to enable recording of message delivery records -pub fn start( - private_addr: ChannelAddr, - proxy_addr: ChannelAddr, - max_duration_ms: u64, -) -> anyhow::Result<()> { +pub fn start() { + let max_duration_ms = 1000 * 10; // Construct a topology with one node: the default A. let address_book: DashSet = DashSet::new(); - address_book.insert(private_addr.clone()); let topology = DashMap::new(); - topology.insert( - SimNetEdge { - src: private_addr.clone(), - dst: private_addr, - }, - SimNetEdgeInfo { - latency: Duration::from_millis(1), - }, - ); - let config = Arc::new(Mutex::new(SimNetConfig { topology })); let (training_script_state_tx, training_script_state_rx) = @@ -657,7 +495,7 @@ pub fn start( let stop_signal = stop_signal.clone(); tokio::spawn(async move { - let mut net = SimNet { + SimNet { config, address_book, state: State { @@ -667,30 +505,20 @@ pub fn start( max_latency: Duration::from_millis(max_duration_ms), records: Vec::new(), pending_event_count, - }; - net.run(event_rx, training_script_state_rx, stop_signal) - .await + } + .run(event_rx, training_script_state_rx, stop_signal) + .await }) })); - let proxy_handle = block_on(ProxyHandle::start( - proxy_addr, - event_tx.clone(), - pending_event_count.clone(), - )) - .map_err(|err| SimNetError::ProxyNotAvailable(err.to_string()))?; - HANDLE.get_or_init(|| SimNetHandle { join_handle, event_tx, config, pending_event_count, - proxy_handle, training_script_state_tx, stop_signal, }); - - Ok(()) } impl SimNet { @@ -961,7 +789,6 @@ mod tests { use tokio::sync::Mutex; use super::*; - use crate::channel::ChannelTransport; use crate::channel::sim::SimAddr; use crate::clock::Clock; use crate::clock::RealClock; @@ -1081,17 +908,7 @@ mod tests { #[tokio::test] async fn test_handle_instantiation() { - let default_addr = format!("local!{}", 0) - .parse::() - .unwrap(); - assert!( - start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .is_ok() - ); + start(); simnet_handle().unwrap().close().await.unwrap(); } @@ -1099,12 +916,7 @@ mod tests { async fn test_simnet_config() { // Tests that we can create a simnet, config latency between two node and deliver // the message with configured latency. - start( - "local!0".parse::().unwrap(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let alice = "local!1".parse::().unwrap(); let bob = "local!2".parse::().unwrap(); let latency = Duration::from_millis(1000); @@ -1121,9 +933,8 @@ mod tests { .await .unwrap(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let alice = SimAddr::new(alice, proxy_addr.clone()).unwrap(); - let bob = SimAddr::new(bob, proxy_addr.clone()).unwrap(); + let alice = SimAddr::new(alice).unwrap(); + let bob = SimAddr::new(bob).unwrap(); let msg = Box::new(MessageDeliveryEvent::new( alice, bob, @@ -1148,13 +959,7 @@ mod tests { #[tokio::test] async fn test_simnet_debounce() { - let default_addr = "local!0".parse::().unwrap(); - start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let alice = "local!1".parse::().unwrap(); let bob = "local!2".parse::().unwrap(); @@ -1171,10 +976,8 @@ mod tests { .await .unwrap(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - - let alice = SimAddr::new(alice, proxy_addr.clone()).unwrap(); - let bob = SimAddr::new(bob, proxy_addr).unwrap(); + let alice = SimAddr::new(alice).unwrap(); + let bob = SimAddr::new(bob).unwrap(); // Rapidly send 10 messages expecting that each one debounces the processing for _ in 0..10 { @@ -1211,13 +1014,7 @@ mod tests { #[tokio::test] async fn test_sim_dispatch() { - let proxy = ChannelAddr::any(ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + start(); let sender = Some(TestDispatcher::default()); let mut addresses: Vec = Vec::new(); // // Create a simple network of 4 nodes. @@ -1234,11 +1031,10 @@ mod tests { .map(|s| Serialized::serialize(&s.to_string()).unwrap()) .collect(); - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - let addr_0 = SimAddr::new(addresses[0].clone(), proxy_addr.clone()).unwrap(); - let addr_1 = SimAddr::new(addresses[1].clone(), proxy_addr.clone()).unwrap(); - let addr_2 = SimAddr::new(addresses[2].clone(), proxy_addr.clone()).unwrap(); - let addr_3 = SimAddr::new(addresses[3].clone(), proxy_addr.clone()).unwrap(); + let addr_0 = SimAddr::new(addresses[0].clone()).unwrap(); + let addr_1 = SimAddr::new(addresses[1].clone()).unwrap(); + let addr_2 = SimAddr::new(addresses[2].clone()).unwrap(); + let addr_3 = SimAddr::new(addresses[3].clone()).unwrap(); let one = Box::new(MessageDeliveryEvent::new( addr_0.clone(), addr_1.clone(), @@ -1286,20 +1082,20 @@ mod tests { #[tokio::test] async fn test_read_config_from_yaml() { let yaml = r#" -edges: - - src: local!0 - dst: local!1 - metadata: - latency: 1 - - src: local!0 - dst: local!2 - metadata: - latency: 2 - - src: local!1 - dst: local!2 - metadata: - latency: 3 -"#; + edges: + - src: local!0 + dst: local!1 + metadata: + latency: 1 + - src: local!0 + dst: local!2 + metadata: + latency: 2 + - src: local!1 + dst: local!2 + metadata: + latency: 3 + "#; let config = NetworkConfig::from_yaml(yaml).unwrap(); assert_eq!(config.edges.len(), 3); assert_eq!( @@ -1331,74 +1127,9 @@ edges: assert_eq!(config.edges[2].metadata.latency, Duration::from_secs(3)); } - #[cfg(target_os = "linux")] - #[tokio::test] - async fn test_simnet_receive_external_message() { - use tokio::sync::oneshot; - - use crate::PortId; - use crate::channel::Tx; - - let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix); - start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); - let tx = crate::channel::dial(proxy_addr.clone()).unwrap(); - let port_id = PortId(id!(test[0].actor0), 0); - let src_to_dst_msg = Serialized::serialize(&"hola".to_string()).unwrap(); - let src = random_abstract_addr(); - let dst = random_abstract_addr(); - let src_and_proxy = Some(AddressProxyPair { - address: src.clone(), - proxy: proxy_addr.clone(), - }); - let dst_and_proxy = AddressProxyPair { - address: dst.clone(), - proxy: proxy_addr.clone(), - }; - let forward_message = ProxyMessage::new(src_and_proxy, Some(dst_and_proxy), src_to_dst_msg); - let external_message = - MessageEnvelope::new_unknown(port_id, Serialized::serialize(&forward_message).unwrap()); - tx.try_post(external_message, oneshot::channel().0).unwrap(); - // flush doesn't work here because tx.send() delivers the message through real network. - // We have to wait for the message to enter simnet. - RealClock.sleep(Duration::from_millis(1000)).await; - simnet_handle() - .unwrap() - .flush(Duration::from_millis(1000)) - .await - .unwrap(); - let records = simnet_handle().unwrap().close().await; - assert!(records.as_ref().unwrap().len() == 1); - let expected_record = SimulatorEventRecord { - summary: format!("Sending message from {} to {}", src, dst), - start_at: 0, - end_at: 1, - }; - assert_eq!(records.unwrap().first().unwrap(), &expected_record); - } - #[tokio::test] async fn test_sim_sleep() { - start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); - - let default_addr = format!("local!{}", 0) - .parse::() - .unwrap(); - let _ = start( - default_addr.clone(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let start = SimClock.now(); assert_eq!(SimClock.millis_since_start(start), 0); @@ -1411,12 +1142,7 @@ edges: #[tokio::test] async fn test_torch_op() { - start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + start(); let args_string = "1, 2".to_string(); let kwargs_string = "a=2".to_string(); diff --git a/hyperactor_multiprocess/src/ping_pong.rs b/hyperactor_multiprocess/src/ping_pong.rs index 676259cc..4c659153 100644 --- a/hyperactor_multiprocess/src/ping_pong.rs +++ b/hyperactor_multiprocess/src/ping_pong.rs @@ -14,9 +14,7 @@ mod tests { use hyperactor::ActorRef; use hyperactor::Mailbox; use hyperactor::channel::ChannelAddr; - use hyperactor::channel::ChannelTransport; use hyperactor::channel::sim; - use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; use hyperactor::id; use hyperactor::reference::Index; @@ -36,16 +34,10 @@ mod tests { #[tokio::test] async fn test_sim_ping_pong() { let system_addr = "local!1".parse::().unwrap(); - let proxy_addr = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy_addr.clone(), - 1000, - ) - .unwrap(); + simnet::start(); - let system_sim_addr = SimAddr::new(system_addr.clone(), proxy_addr.clone()).unwrap(); + let system_sim_addr = SimAddr::new(system_addr.clone()).unwrap(); let server_handle = System::serve( ChannelAddr::Sim(system_sim_addr.clone()), Duration::from_secs(10), @@ -64,21 +56,14 @@ mod tests { let ping_actor_ref = spawn_proc_actor( 2, - proxy_addr.clone(), system_sim_addr.clone(), sys_mailbox.clone(), world_id.clone(), ) .await; - let pong_actor_ref = spawn_proc_actor( - 3, - proxy_addr.clone(), - system_sim_addr, - sys_mailbox.clone(), - world_id.clone(), - ) - .await; + let pong_actor_ref = + spawn_proc_actor(3, system_sim_addr, sys_mailbox.clone(), world_id.clone()).await; // Configure the simulation network. let simnet_config_yaml = r#" @@ -124,7 +109,6 @@ edges: async fn spawn_proc_actor( actor_index: Index, - proxy_addr: ChannelAddr, system_addr: SimAddr, sys_mailbox: Mailbox, world_id: WorldId, @@ -133,19 +117,11 @@ edges: .parse::() .unwrap(); - let proc_sim_addr = SimAddr::new(proc_addr.clone(), proxy_addr.clone()).unwrap(); + let proc_sim_addr = SimAddr::new(proc_addr.clone()).unwrap(); let proc_listen_addr = ChannelAddr::Sim(proc_sim_addr); let proc_id = world_id.proc_id(actor_index); let proc_to_system = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: proc_addr.clone(), - proxy: proxy_addr.clone(), - }, - system_addr.addr().clone(), - system_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(proc_addr.clone(), system_addr.addr().clone()).unwrap(), ); let bootstrap = ProcActor::bootstrap( proc_id, diff --git a/hyperactor_multiprocess/src/system_actor.rs b/hyperactor_multiprocess/src/system_actor.rs index 12d354ea..5202b56f 100644 --- a/hyperactor_multiprocess/src/system_actor.rs +++ b/hyperactor_multiprocess/src/system_actor.rs @@ -38,7 +38,6 @@ use hyperactor::RefClient; use hyperactor::WorldId; use hyperactor::actor::Handler; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; use hyperactor::clock::Clock; use hyperactor::clock::ClockKind; @@ -762,13 +761,9 @@ impl ReportingRouter { ChannelAddr::Sim( SimAddr::new_with_src( // source is the sender - AddressProxyPair { - address: sender_sim_addr.addr().clone(), - proxy: sender_sim_addr.proxy().clone(), - }, + sender_sim_addr.addr().clone(), // dest remains unchanged dest_sim_addr.addr().clone(), - dest_sim_addr.proxy().clone(), ) .unwrap(), ) @@ -2800,19 +2795,11 @@ mod tests { #[tokio::test] async fn test_update_sim_address() { - let proxy = ChannelAddr::any(ChannelTransport::Unix); - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.clone(), - 1000, - ) - .unwrap(); + simnet::start(); let src_id = id!(proc[0].actor); - let src_addr = - ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap(), proxy.clone()).unwrap()); - let dst_addr = - ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap(), proxy.clone()).unwrap()); + let src_addr = ChannelAddr::Sim(SimAddr::new("unix!@src".parse().unwrap()).unwrap()); + let dst_addr = ChannelAddr::Sim(SimAddr::new("unix!@dst".parse().unwrap()).unwrap()); let (_, mut rx) = channel::serve::(src_addr.clone()) .await .unwrap(); @@ -2844,7 +2831,7 @@ mod tests { panic!("Expected sim address"); }; - assert_eq!(addr.src().clone().unwrap().address.to_string(), "unix!@src"); + assert_eq!(addr.src().clone().unwrap().to_string(), "unix!@src"); assert_eq!(addr.addr().to_string(), "unix!@dst"); } } diff --git a/monarch_extension/src/simulation_tools.rs b/monarch_extension/src/simulation_tools.rs index 6913967b..a73855e0 100644 --- a/monarch_extension/src/simulation_tools.rs +++ b/monarch_extension/src/simulation_tools.rs @@ -6,24 +6,16 @@ * LICENSE file in the root directory of this source tree. */ -use hyperactor::channel::ChannelAddr; -use hyperactor::channel::ChannelTransport; use hyperactor::clock::Clock; use hyperactor::clock::SimClock; use hyperactor::simnet; -use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; #[pyfunction] #[pyo3(name = "start_event_loop")] pub fn start_simnet_event_loop(py: Python) -> PyResult> { pyo3_async_runtimes::tokio::future_into_py(py, async move { - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + simnet::start(); Ok(()) }) } diff --git a/monarch_extension/src/simulator_client.rs b/monarch_extension/src/simulator_client.rs index eb93c1c5..64834e06 100644 --- a/monarch_extension/src/simulator_client.rs +++ b/monarch_extension/src/simulator_client.rs @@ -13,7 +13,6 @@ use std::sync::Arc; use anyhow::anyhow; use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::ChannelTransport; use hyperactor::simnet; use hyperactor::simnet::TrainingScriptState; use hyperactor::simnet::simnet_handle; @@ -52,29 +51,17 @@ impl SimulatorClient { #[new] fn new(py: Python, system_addr: String, world_size: i32) -> PyResult { signal_safe_block_on(py, async move { - let system_addr = system_addr - .parse::() - .map_err(|err| PyValueError::new_err(err.to_string()))?; - - let ChannelAddr::Sim(system_sim_addr) = &system_addr else { - return Err(PyValueError::new_err(format!( - "bootstrap address should be a sim address: {}", - system_addr - ))); - }; - - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - system_sim_addr.proxy().clone(), - 1000, - ) - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?; + simnet::start(); Ok(Self { inner: Arc::new(Mutex::new( - TensorEngineSimulator::new(system_addr) - .await - .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, + TensorEngineSimulator::new( + system_addr + .parse::() + .map_err(|err| PyValueError::new_err(err.to_string()))?, + ) + .await + .map_err(|err| PyRuntimeError::new_err(err.to_string()))?, )), world_size: world_size as usize, }) diff --git a/monarch_hyperactor/src/channel.rs b/monarch_hyperactor/src/channel.rs index ef7522a2..9fd6f055 100644 --- a/monarch_hyperactor/src/channel.rs +++ b/monarch_hyperactor/src/channel.rs @@ -25,7 +25,7 @@ pub enum PyChannelTransport { MetaTls, Local, Unix, - // Sim(/*proxy address:*/ ChannelAddr), TODO kiuk@ add support + // Sim(/*transport:*/ ChannelTransport), TODO kiuk@ add support } #[pyclass( @@ -131,9 +131,7 @@ mod tests { #[test] fn test_channel_unsupported_transport() -> PyResult<()> { - let sim_addr = ChannelAddr::any(ChannelTransport::Sim(ChannelAddr::any( - ChannelTransport::Unix, - ))); + let sim_addr = ChannelAddr::any(ChannelTransport::Sim(Box::new(ChannelTransport::Unix))); let addr = PyChannelAddr { inner: sim_addr }; assert!(addr.get_port().is_err()); diff --git a/monarch_simulator/Cargo.toml b/monarch_simulator/Cargo.toml index e28372ed..318f34f5 100644 --- a/monarch_simulator/Cargo.toml +++ b/monarch_simulator/Cargo.toml @@ -29,5 +29,4 @@ torch-sys-cuda = { version = "0.0.0", path = "../torch-sys-cuda" } tracing = { version = "0.1.41", features = ["attributes", "valuable"] } [dev-dependencies] -rand = { version = "0.8", features = ["small_rng"] } tracing-test = { version = "0.2.3", features = ["no-env-filter"] } diff --git a/monarch_simulator/src/bootstrap.rs b/monarch_simulator/src/bootstrap.rs index 6ba3000d..40b01d11 100644 --- a/monarch_simulator/src/bootstrap.rs +++ b/monarch_simulator/src/bootstrap.rs @@ -18,7 +18,6 @@ use hyperactor::ActorRef; use hyperactor::ProcId; use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; -use hyperactor::channel::sim::AddressProxyPair; use hyperactor::channel::sim::SimAddr; use hyperactor_multiprocess::System; use hyperactor_multiprocess::proc_actor::ProcActor; @@ -62,15 +61,7 @@ pub async fn spawn_controller( panic!("bootstrap_addr must be a SimAddr"); }; let bootstrap_addr = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: listen_addr.clone(), - proxy: bootstrap_addr.proxy().clone(), - }, - bootstrap_addr.addr().clone(), - bootstrap_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(listen_addr.clone(), bootstrap_addr.addr().clone()).unwrap(), ); tracing::info!( "controller listen addr: {}, bootstrap addr: {}", @@ -121,15 +112,7 @@ pub async fn spawn_sim_worker( panic!("bootstrap_addr must be a SimAddr"); }; let bootstrap_addr = ChannelAddr::Sim( - SimAddr::new_with_src( - AddressProxyPair { - address: listen_addr.clone(), - proxy: bootstrap_addr.proxy().clone(), - }, - bootstrap_addr.addr().clone(), - bootstrap_addr.proxy().clone(), - ) - .unwrap(), + SimAddr::new_with_src(listen_addr.clone(), bootstrap_addr.addr().clone()).unwrap(), ); tracing::info!( "worker {} listen addr: {}, bootstrap addr: {}", diff --git a/monarch_simulator/src/simulator.rs b/monarch_simulator/src/simulator.rs index 97c02b77..aa9f9da6 100644 --- a/monarch_simulator/src/simulator.rs +++ b/monarch_simulator/src/simulator.rs @@ -116,35 +116,14 @@ mod tests { use hyperactor::ProcId; use hyperactor::WorldId; use hyperactor::channel::ChannelAddr; - use hyperactor::channel::ChannelTransport; use hyperactor::simnet; - use rand::Rng; - use rand::distributions::Alphanumeric; - - #[cfg(target_os = "linux")] - fn random_str() -> String { - rand::thread_rng() - .sample_iter(&Alphanumeric) - .take(24) - .map(char::from) - .collect::() - } #[tracing_test::traced_test] #[tokio::test] async fn test_spawn_and_kill_mesh() { - let proxy = format!("unix!@{}", random_str()); - - simnet::start( - ChannelAddr::any(ChannelTransport::Unix), - proxy.parse().unwrap(), - 1000, - ) - .unwrap(); + simnet::start(); - let system_addr = format!("sim!unix!@system,{}", &proxy) - .parse::() - .unwrap(); + let system_addr = "sim!unix!@system".parse::().unwrap(); let mut simulator = super::TensorEngineSimulator::new(system_addr.clone()) .await .unwrap(); diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index d3c90c60..17f9847d 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -753,8 +753,6 @@ mod tests { use anyhow::Result; use futures::future::try_join_all; - use hyperactor::channel::ChannelAddr; - use hyperactor::channel::ChannelTransport; use hyperactor::id; use hyperactor::proc::Proc; use hyperactor::simnet; @@ -808,12 +806,7 @@ mod tests { #[tokio::test] async fn worker_reduce() -> Result<()> { - simnet::start( - "local!0".parse::().unwrap(), - ChannelAddr::any(ChannelTransport::Unix), - 1000, - ) - .unwrap(); + simnet::start(); let proc = Proc::local(); //let (client, controller_ref, mut controller_rx) = proc.attach_actor("controller")?; let client = proc.attach("client")?; diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index 8c304efd..2e999190 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -58,9 +58,7 @@ logger: logging.Logger = logging.getLogger(__name__) -def sim_mesh( - n_meshes: int, hosts: int, gpus_per_host: int, proxy_addr: Optional[str] = None -) -> List[DeviceMesh]: +def sim_mesh(n_meshes: int, hosts: int, gpus_per_host: int) -> List[DeviceMesh]: """ Creates a single simulated device mesh with the given number of per host. @@ -185,7 +183,6 @@ def __init__( Bootstraps a SimMesh. Args: num_meshes: int - number of meshes to create. - proxy_addr: Option[str] - the proxy address of the simulation process mesh_world_state: a state of the meshes. Keys are the MeshWorld and values are boolean indicating if this mesh is active. """ # do a fake call to instantiate ThreadPoolExecutor so we don't block GIL later @@ -196,13 +193,9 @@ def __init__( self._mesh_world_state: Dict[MeshWorld, Optional[DeviceMesh]] = mesh_world_state - proxy_addr = f"unix!@{_random_id()}-proxy" - self.bootstrap_addr: str = f"sim!unix!@system,{proxy_addr}" - client_proxy_addr = f"unix!@{_random_id()}-proxy" - self.client_listen_addr = f"sim!unix!@client,{client_proxy_addr}" - self.client_bootstrap_addr = ( - f"sim!unix!@client,{client_proxy_addr},unix!@system,{proxy_addr}" - ) + self.bootstrap_addr: str = "sim!unix!@system" + self.client_listen_addr = "sim!unix!@client" + self.client_bootstrap_addr = "sim!unix!@client,unix!@system" self._simulator_client = SimulatorClient(self.bootstrap_addr, world_size) for i in range(num_meshes): diff --git a/python/tests/test_sim_backend.py b/python/tests/test_sim_backend.py index 1ea89cd5..489bc736 100644 --- a/python/tests/test_sim_backend.py +++ b/python/tests/test_sim_backend.py @@ -24,11 +24,8 @@ def local_sim_mesh( # TODO: support multiple gpus in a mesh. gpu_per_host: int = 1, activate: bool = True, - proxy_addr: Optional[str] = None, ) -> Generator[DeviceMesh, None, None]: - dms = sim_mesh( - n_meshes=1, hosts=hosts, gpus_per_host=gpu_per_host, proxy_addr=proxy_addr - ) + dms = sim_mesh(n_meshes=1, hosts=hosts, gpus_per_host=gpu_per_host) dm = dms[0] try: if activate: