Skip to content

No more operational messages #473

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 10 additions & 156 deletions hyperactor/src/simnet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,9 @@ use serde::Serialize;
use serde::Serializer;
use serde_with::serde_as;
use tokio::sync::Mutex;
use tokio::sync::SetError;
use tokio::sync::mpsc;
use tokio::sync::mpsc::Sender;
use tokio::sync::mpsc::UnboundedReceiver;
use tokio::sync::mpsc::UnboundedSender;
use tokio::sync::mpsc::error::SendError;
use tokio::task::JoinError;
use tokio::task::JoinHandle;
use tokio::time::interval;
Expand All @@ -49,7 +46,6 @@ use crate::ActorId;
use crate::Mailbox;
use crate::Named;
use crate::OncePortRef;
use crate::WorldId;
use crate::channel;
use crate::channel::ChannelAddr;
use crate::channel::Rx;
Expand Down Expand Up @@ -313,18 +309,6 @@ pub enum SimNetError {
#[error("proxy not available: {0}")]
ProxyNotAvailable(String),

/// Unable to send message to the simulator.
#[error(transparent)]
OperationalMessageSendError(#[from] SendError<OperationalMessage>),

/// Setting the operational message sender which is already set.
#[error(transparent)]
OperationalMessageSenderSetError(#[from] SetError<Sender<OperationalMessage>>),

/// Missing OperationalMessageReceiver.
#[error("missing operational message receiver")]
MissingOperationalMessageReceiver,

/// Cannot deliver the message because destination address is missing.
#[error("missing destination address")]
MissingDestinationAddress,
Expand Down Expand Up @@ -361,8 +345,6 @@ pub struct SimNetHandle {
/// Handle to a running proxy server that forwards external messages
/// into the simnet.
proxy_handle: ProxyHandle,
/// A sender to forward simulator operational messages.
operational_message_tx: UnboundedSender<OperationalMessage>,
/// A `tokio::sync::watch` sender carrying the current `TrainingScriptState`.
/// NOTE(review): the previous comment here described an operational-message receiver, which does not match this field's type.
training_script_state_tx: tokio::sync::watch::Sender<TrainingScriptState>,
Expand Down Expand Up @@ -489,82 +471,6 @@ impl SimNetHandle {

pub(crate) type Topology = DashMap<SimNetEdge, SimNetEdgeInfo>;

/// The message to spawn a simulated mesh.
///
/// This is the payload of the `OperationalMessage::SpawnMesh` variant. It derives
/// `Serialize`/`Deserialize` so it can travel in serialized form, and
/// `PartialEq`/`Clone` so it can be compared and duplicated by value.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
pub struct SpawnMesh {
/// The system address.
pub system_addr: ChannelAddr,
/// The controller actor ID.
pub controller_actor_id: ActorId,
/// The worker world.
// NOTE(review): presumably the world the spawned workers belong to — confirm with callers.
pub worker_world: WorldId,
}

impl SpawnMesh {
/// Creates a new SpawnMesh.
pub fn new(
system_addr: ChannelAddr,
controller_actor_id: ActorId,
worker_world: WorldId,
) -> Self {
Self {
system_addr,
controller_actor_id,
worker_world,
}
}
}

/// An OperationalMessage is a message to control the simulator to do tasks such as
/// spawning or killing actors.
///
/// The proxy loop deserializes one of these from the `data` of a `ProxyMessage`
/// that carries no `dest_addr`, and wraps it in a `SimOperation` event.
#[derive(Debug, Serialize, Deserialize, PartialEq, Clone, Named)]
pub enum OperationalMessage {
/// Kill the world with the given world_id (carried as a `String`).
KillWorld(String),
/// Spawn actors in a mesh, as described by the contained `SpawnMesh`.
SpawnMesh(SpawnMesh),
/// Update the training script state to the contained value.
SetTrainingScriptState(TrainingScriptState),
}

/// Message Event that can be sent to the simulator.
///
/// Pairs an `OperationalMessage` with the channel it should be forwarded on;
/// the `Event` implementation sends a clone of the message when the event is handled.
#[derive(Debug)]
pub struct SimOperation {
/// Sender used to forward the OperationalMessage when this event is handled.
operational_message_tx: UnboundedSender<OperationalMessage>,
/// The message to forward; it is cloned on each `handle` call.
operational_message: OperationalMessage,
}

impl SimOperation {
/// Creates a new SimOperation.
pub fn new(
operational_message_tx: UnboundedSender<OperationalMessage>,
operational_message: OperationalMessage,
) -> Self {
Self {
operational_message_tx,
operational_message,
}
}
}

#[async_trait]
impl Event for SimOperation {
async fn handle(&self) -> Result<(), SimNetError> {
self.operational_message_tx
.send(self.operational_message.clone())?;
Ok(())
}

fn duration_ms(&self) -> u64 {
0
}

fn summary(&self) -> String {
format!("SimOperation: {:?}", self.operational_message)
}
}

/// A ProxyMessage is a message that SimNet proxy receives.
/// The message may request the SimNet to send the payload in the message field from
/// src to dst if addr field exists.
Expand Down Expand Up @@ -658,7 +564,6 @@ impl ProxyHandle {
proxy_addr: ChannelAddr,
event_tx: UnboundedSender<(Box<dyn Event>, bool, Option<SimulatorTimeInstant>)>,
pending_event_count: Arc<AtomicUsize>,
operational_message_tx: UnboundedSender<OperationalMessage>,
) -> anyhow::Result<Self> {
let (addr, mut rx) = channel::serve::<MessageEnvelope>(proxy_addr).await?;
tracing::info!("SimNet serving external traffic on {}", &addr);
Expand All @@ -672,26 +577,18 @@ impl ProxyHandle {
#[allow(clippy::disallowed_methods)]
if let Ok(Ok(msg)) = timeout(Duration::from_millis(100), rx.recv()).await {
let proxy_message: ProxyMessage = msg.deserialized().unwrap();
let event: Box<dyn Event> = match proxy_message.dest_addr {
Some(dest_addr) => Box::new(MessageDeliveryEvent::new(
if let Some(dest_addr) = proxy_message.dest_addr {
let event = Box::new(MessageDeliveryEvent::new(
proxy_message.sender_addr,
dest_addr,
proxy_message.data,
)),
None => {
let operational_message: OperationalMessage =
proxy_message.data.deserialized().unwrap();
Box::new(SimOperation::new(
operational_message_tx.clone(),
operational_message,
))
));
if let Err(e) = event_tx.send((event, true, None)) {
tracing::error!("error sending message to simnet: {:?}", e);
} else {
pending_event_count
.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
};

if let Err(e) = event_tx.send((event, true, None)) {
tracing::error!("error sending message to simnet: {:?}", e);
} else {
pending_event_count.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
}
}
if stop_signal.load(Ordering::SeqCst) {
Expand Down Expand Up @@ -729,7 +626,7 @@ pub fn start(
private_addr: ChannelAddr,
proxy_addr: ChannelAddr,
max_duration_ms: u64,
) -> anyhow::Result<UnboundedReceiver<OperationalMessage>> {
) -> anyhow::Result<()> {
// Construct a topology with one node: the default A.
let address_book: DashSet<ChannelAddr> = DashSet::new();
address_book.insert(private_addr.clone());
Expand Down Expand Up @@ -775,14 +672,11 @@ pub fn start(
.await
})
}));
let (operational_message_tx, operational_message_rx) =
mpsc::unbounded_channel::<OperationalMessage>();

let proxy_handle = block_on(ProxyHandle::start(
proxy_addr,
event_tx.clone(),
pending_event_count.clone(),
operational_message_tx.clone(),
))
.map_err(|err| SimNetError::ProxyNotAvailable(err.to_string()))?;

Expand All @@ -792,12 +686,11 @@ pub fn start(
config,
pending_event_count,
proxy_handle,
operational_message_tx,
training_script_state_tx,
stop_signal,
});

Ok(operational_message_rx)
Ok(())
}

impl SimNet {
Expand Down Expand Up @@ -1488,45 +1381,6 @@ edges:
assert_eq!(records.unwrap().first().unwrap(), &expected_record);
}

// Checks that an `OperationalMessage` posted to the simnet proxy address comes
// back out of the receiver returned by `start`.
//
// The message is serialized, wrapped in a `ProxyMessage` whose two optional
// arguments are both `None`, serialized again, and posted inside a
// `MessageEnvelope` over a channel dialed to the proxy address.
// NOTE(review): compiled on Linux only; presumably because of the Unix channel
// transport and the `unix!@system` address used below — confirm.
#[cfg(target_os = "linux")]
#[tokio::test]
async fn test_simnet_receive_operational_message() {
use tokio::sync::oneshot;

use crate::PortId;
use crate::channel::Tx;

// Start the simnet with a 1000 ms max duration and keep the operational-message receiver.
let proxy_addr = ChannelAddr::any(channel::ChannelTransport::Unix);
let mut operational_message_rx = start(
ChannelAddr::any(ChannelTransport::Unix),
proxy_addr.clone(),
1000,
)
.unwrap();
let tx = crate::channel::dial(proxy_addr.clone()).unwrap();
let port_id = PortId(id!(test[0].actor0), 0);
let spawn_mesh = SpawnMesh {
system_addr: "unix!@system".parse().unwrap(),
controller_actor_id: id!(controller_world[0].actor),
worker_world: id!(worker_world),
};
// Two layers of serialization: the operational message goes inside the
// proxy message, which in turn goes inside the envelope.
let operational_message = OperationalMessage::SpawnMesh(spawn_mesh.clone());
let serialized_operational_message = Serialized::serialize(&operational_message).unwrap();
let proxy_message = ProxyMessage::new(None, None, serialized_operational_message);
let serialized_proxy_message = Serialized::serialize(&proxy_message).unwrap();
let external_message = MessageEnvelope::new_unknown(port_id, serialized_proxy_message);

// Send the message to the simnet.
tx.try_post(external_message, oneshot::channel().0).unwrap();
// Flushing doesn't work here because the message is delivered through the real network.
// We have to wait for the message to enter simnet.
RealClock.sleep(Duration::from_millis(1000)).await;
let received_operational_message = operational_message_rx.recv().await.unwrap();

// Check the received message.
assert_eq!(received_operational_message, operational_message);
}

#[tokio::test]
async fn test_sim_sleep() {
start(
Expand Down
Loading
Loading