From 777f9654cdcfc0cf0fc29b1c9ecd2bae4f24617f Mon Sep 17 00:00:00 2001 From: Jun Li Date: Fri, 11 Jul 2025 18:53:26 -0700 Subject: [PATCH] Python actor mesh supervision support Summary: This diff exposes the Rust ActorMesh supervision API to Python ActorMesh. It also wires the supervision events to endpoint calls, including *call()/call_one()/choose()/stream()*. So when a supervision error happens, all inflight calls will get notified, new calls will be failed. The current diff fails the whole mesh when any actor fails in the mesh. A followup diff will provide more granular management here, such that we may only fail the calls that actually expect reply from the failed actors. Reviewed By: colin2328 Differential Revision: D77434080 --- hyperactor_mesh/src/actor_mesh.rs | 46 ++- hyperactor_mesh/src/proc_mesh.rs | 23 +- hyperactor_mesh/src/shared_cell.rs | 10 + monarch_extension/src/lib.rs | 5 + monarch_hyperactor/src/actor_mesh.rs | 325 +++++++++++++++++- monarch_hyperactor/src/lib.rs | 1 + monarch_hyperactor/src/mailbox.rs | 16 +- monarch_hyperactor/src/proc_mesh.rs | 113 +++--- monarch_hyperactor/src/supervision.rs | 25 ++ .../monarch_hyperactor/actor_mesh.pyi | 92 ++++- .../monarch_hyperactor/supervision.pyi | 15 + python/monarch/_src/actor/actor_mesh.py | 112 ++++-- python/monarch/common/remote.py | 8 +- python/monarch/mesh_controller.py | 8 +- python/tests/test_actor_error.py | 118 ++++++- 15 files changed, 797 insertions(+), 120 deletions(-) create mode 100644 monarch_hyperactor/src/supervision.rs create mode 100644 python/monarch/_rust_bindings/monarch_hyperactor/supervision.pyi diff --git a/hyperactor_mesh/src/actor_mesh.rs b/hyperactor_mesh/src/actor_mesh.rs index bcf4e8cb..81ca8bae 100644 --- a/hyperactor_mesh/src/actor_mesh.rs +++ b/hyperactor_mesh/src/actor_mesh.rs @@ -179,7 +179,9 @@ pub struct RootActorMesh<'a, A: RemoteActor> { proc_mesh: ProcMeshRef<'a>, name: String, pub(crate) ranks: Vec>, // temporary until we remove `ArcActorMesh`. - actor_supervision_rx: mpsc::UnboundedReceiver, + // The receiver of supervision events. It is None if it has been transferred to + // an actor event observer. + actor_supervision_rx: Option>, } impl<'a, A: RemoteActor> RootActorMesh<'a, A> { @@ -193,7 +195,7 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> { proc_mesh: ProcMeshRef::Borrowed(proc_mesh), name, ranks, - actor_supervision_rx, + actor_supervision_rx: Some(actor_supervision_rx), } } @@ -207,7 +209,7 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> { proc_mesh: ProcMeshRef::Shared(Box::new(proc_mesh)), name, ranks, - actor_supervision_rx, + actor_supervision_rx: Some(actor_supervision_rx), } } @@ -216,21 +218,35 @@ impl<'a, A: RemoteActor> RootActorMesh<'a, A> { self.proc_mesh.client().open_port() } - /// An event stream of proc events. Each ProcMesh can produce only one such + /// An event stream of actor events. Each RootActorMesh can produce only one such /// stream, returning None after the first call. + pub fn events(&mut self) -> Option { + self.actor_supervision_rx + .take() + .map(|actor_supervision_rx| ActorSupervisionEvents { + actor_supervision_rx, + mesh_id: self.id(), + }) + } +} + +/// Supervision event stream for actor mesh. It emits actor supervision events. +pub struct ActorSupervisionEvents { + // The receiver of supervision events from proc mesh. + actor_supervision_rx: mpsc::UnboundedReceiver, + // The name of the actor mesh. 
+ mesh_id: ActorMeshId, +} + +impl ActorSupervisionEvents { pub async fn next(&mut self) -> Option { let result = self.actor_supervision_rx.recv().await; - match result.as_ref() { - Some(event) => { - tracing::debug!("Received supervision event: {event:?}"); - } - None => { - tracing::info!( - "Supervision stream for actor mesh {} was closed!", - self.name - ); - } - }; + if result.is_none() { + tracing::info!( + "supervision stream for actor mesh {:?} was closed!", + self.mesh_id + ); + } result } } diff --git a/hyperactor_mesh/src/proc_mesh.rs b/hyperactor_mesh/src/proc_mesh.rs index a68215f6..de8ef148 100644 --- a/hyperactor_mesh/src/proc_mesh.rs +++ b/hyperactor_mesh/src/proc_mesh.rs @@ -490,7 +490,7 @@ impl ProcMesh { } /// Proc lifecycle events. -#[derive(Debug)] +#[derive(Debug, Clone)] pub enum ProcEvent { /// The proc of the given rank was stopped with the provided reason. Stopped(usize, ProcStopReason), @@ -560,17 +560,15 @@ impl ProcEvents { }; // transmit to the correct root actor mesh. { - let Some(tx) = self.actor_event_router.get(actor_id.name()) else { + if let Some(tx) = self.actor_event_router.get(actor_id.name()) { + if tx.send(event).is_err() { + tracing::warn!("unable to transmit supervision event to actor {}", actor_id); + } + } else { tracing::warn!("received supervision event for unregistered actor {}", actor_id); - continue; - }; - let Ok(_) = tx.send(event) else { - tracing::warn!("unable to transmit supervision event to actor {}", actor_id); - continue; - }; + } } - // TODO: Actor supervision events need to be wired to the frontend. - // TODO: This event should be handled by the proc mesh if unhandled by actor mesh. + // Send this event to Python proc mesh to keep its health status up to date. break Some(ProcEvent::Crashed(*rank, actor_status.to_string())) } } @@ -777,6 +775,7 @@ mod tests { let mut events = mesh.events().unwrap(); let mut actors = mesh.spawn::("failing", &()).await.unwrap(); + let mut actor_events = actors.events().unwrap(); actors .cast( @@ -790,7 +789,7 @@ mod tests { ProcEvent::Crashed(0, reason) if reason.contains("failmonkey") ); - let event = actors.next().await.unwrap(); + let event = actor_events.next().await.unwrap(); assert_matches!(event.actor_status(), ActorStatus::Failed(_)); assert_eq!(event.actor_id().1, "failing".to_string()); assert_eq!(event.actor_id().2, 0); @@ -806,6 +805,6 @@ mod tests { ); assert!(events.next().await.is_none()); - assert!(actors.next().await.is_none()); + assert!(actor_events.next().await.is_none()); } } diff --git a/hyperactor_mesh/src/shared_cell.rs b/hyperactor_mesh/src/shared_cell.rs index 6f180bfd..6cdabf9b 100644 --- a/hyperactor_mesh/src/shared_cell.rs +++ b/hyperactor_mesh/src/shared_cell.rs @@ -156,6 +156,16 @@ impl SharedCell { SharedCellRef::from(self.inner.clone().try_read_owned()?) } + /// Execute given closure with write access to the underlying data. If the cell is empty, returns an error. + pub async fn with_mut(&self, f: F) -> Result + where + F: FnOnce(&mut T) -> R, + { + let mut inner = self.inner.write(true).await; + let value = inner.value.as_mut().ok_or(EmptyCellError {})?; + Ok(f(value)) + } + /// Take the item out of the cell, leaving it in an unusable state. 
pub async fn take(&self) -> Result { let mut inner = self.inner.write(true).await; diff --git a/monarch_extension/src/lib.rs b/monarch_extension/src/lib.rs index 73bab55d..180c6b31 100644 --- a/monarch_extension/src/lib.rs +++ b/monarch_extension/src/lib.rs @@ -78,6 +78,11 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> { "monarch_hyperactor.selection", )?)?; + monarch_hyperactor::supervision::register_python_bindings(&get_or_add_new_module( + module, + "monarch_hyperactor.supervision", + )?)?; + #[cfg(feature = "tensor_engine")] { client::register_python_bindings(&get_or_add_new_module( diff --git a/monarch_hyperactor/src/actor_mesh.rs b/monarch_hyperactor/src/actor_mesh.rs index 0fcd0c6a..b2bef2b9 100644 --- a/monarch_hyperactor/src/actor_mesh.rs +++ b/monarch_hyperactor/src/actor_mesh.rs @@ -6,35 +6,99 @@ * LICENSE file in the root directory of this source tree. */ +use std::sync::Arc; + use hyperactor::ActorRef; +use hyperactor::mailbox::OncePortReceiver; +use hyperactor::mailbox::PortReceiver; +use hyperactor::supervision::ActorSupervisionEvent; use hyperactor_mesh::Mesh; use hyperactor_mesh::RootActorMesh; use hyperactor_mesh::actor_mesh::ActorMesh; +use hyperactor_mesh::actor_mesh::ActorSupervisionEvents; use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::shared_cell::SharedCellRef; +use pyo3::exceptions::PyEOFError; use pyo3::exceptions::PyException; use pyo3::exceptions::PyRuntimeError; +use pyo3::exceptions::PyValueError; use pyo3::prelude::*; +use pyo3::types::PyType; +use tokio::sync::Mutex; use crate::actor::PythonActor; use crate::actor::PythonMessage; use crate::mailbox::PyMailbox; +use crate::mailbox::PythonOncePortReceiver; +use crate::mailbox::PythonPortReceiver; use crate::proc::PyActorId; use crate::proc_mesh::Keepalive; +use crate::runtime::signal_safe_block_on; use crate::selection::PySelection; use crate::shape::PyShape; +use crate::supervision::SupervisionError; #[pyclass( name = "PythonActorMesh", module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" )] pub struct PythonActorMesh { - pub(super) inner: SharedCell>, - pub client: PyMailbox, - pub(super) _keepalive: Keepalive, + inner: SharedCell>, + client: PyMailbox, + _keepalive: Keepalive, + unhealthy_event: Arc>>, + user_monitor_sender: tokio::sync::broadcast::Sender, + monitor: tokio::task::JoinHandle<()>, } impl PythonActorMesh { + /// Create a new [`PythonActorMesh`] with a monitor that will observe supervision + /// errors for this mesh, and update its state properly. + pub(crate) fn monitored( + inner: SharedCell>, + client: PyMailbox, + keepalive: Keepalive, + events: ActorSupervisionEvents, + ) -> Self { + let (user_monitor_sender, _) = tokio::sync::broadcast::channel::(1); + let unhealthy_event = Arc::new(std::sync::Mutex::new(None)); + let monitor = tokio::spawn(Self::actor_mesh_monitor( + events, + user_monitor_sender.clone(), + unhealthy_event.clone(), + )); + Self { + inner, + client, + _keepalive: keepalive, + unhealthy_event, + user_monitor_sender, + monitor, + } + } + + /// Monitor of the actor mesh. It processes supervision errors for the mesh, and keeps mesh + /// health state up to date. 
+ async fn actor_mesh_monitor( + mut events: ActorSupervisionEvents, + user_sender: tokio::sync::broadcast::Sender, + unhealthy_event: Arc>>, + ) { + loop { + let event = events.next().await; + if let Some(event) = event { + let mut inner_unhealthy_event = unhealthy_event.lock().unwrap(); + *inner_unhealthy_event = Some(event.clone()); + + // Ignore the sender error when there is no receiver, which happens when there + // is no active requests to this mesh. + let _ = user_sender.send(event); + } else { + break; + } + } + } + fn try_inner(&self) -> PyResult>> { self.inner .borrow() @@ -45,12 +109,33 @@ impl PythonActorMesh { #[pymethods] impl PythonActorMesh { fn cast(&self, selection: &PySelection, message: &PythonMessage) -> PyResult<()> { + let unhealthy_event = self + .unhealthy_event + .lock() + .expect("failed to acquire unhealthy_event lock"); + if let Some(ref event) = *unhealthy_event { + return Err(PyRuntimeError::new_err(format!( + "actor mesh is unhealthy with reason: {:?}", + event + ))); + } + self.try_inner()? .cast(selection.inner().clone(), message.clone()) .map_err(|err| PyException::new_err(err.to_string()))?; Ok(()) } + fn get_supervision_event(&self) -> PyResult> { + let unhealthy_event = self + .unhealthy_event + .lock() + .expect("failed to acquire unhealthy_event lock"); + Ok(unhealthy_event + .as_ref() + .map(|event| PyActorSupervisionEvent::from(event.clone()))) + } + // Consider defining a "PythonActorRef", which carries specifically // a reference to python message actors. fn get(&self, rank: usize) -> PyResult> { @@ -61,6 +146,16 @@ impl PythonActorMesh { .map(PyActorId::from)) } + // Start monitoring the actor mesh by subscribing to its supervision events. For each supervision + // event, it is consumed by PythonActorMesh first, then gets sent to the monitor for user to consume. + fn monitor<'py>(&self, py: Python<'py>) -> PyResult { + let receiver = self.user_monitor_sender.subscribe(); + let monitor_instance = PyActorMeshMonitor { + receiver: SharedCell::from(Mutex::new(receiver)), + }; + Ok(monitor_instance.into_py(py)) + } + #[getter] pub fn client(&self) -> PyMailbox { self.client.clone() @@ -71,7 +166,231 @@ impl PythonActorMesh { Ok(PyShape::from(self.try_inner()?.shape().clone())) } } + +impl Drop for PythonActorMesh { + fn drop(&mut self) { + self.monitor.abort(); + } +} + +#[pyclass( + name = "ActorMeshMonitor", + module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" +)] +pub struct PyActorMeshMonitor { + receiver: SharedCell>>, +} + +#[pymethods] +impl PyActorMeshMonitor { + fn __aiter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> { + slf + } + + pub fn __anext__(&self, py: Python<'_>) -> PyResult { + let receiver = self.receiver.clone(); + Ok( + pyo3_async_runtimes::tokio::future_into_py( + py, + async move { get_next(receiver).await }, + )? 
+ .into(), + ) + } +} + +impl PyActorMeshMonitor { + pub async fn next(&self) -> PyResult { + get_next(self.receiver.clone()).await + } +} + +impl Clone for PyActorMeshMonitor { + fn clone(&self) -> Self { + Self { + receiver: self.receiver.clone(), + } + } +} + +async fn get_next( + receiver: SharedCell>>, +) -> PyResult { + let receiver = receiver.clone(); + + let receiver = receiver + .borrow() + .map_err(|_| PyRuntimeError::new_err("`Actor mesh receiver` is shutdown"))?; + let mut receiver = receiver.lock().await; + let event = receiver.recv().await.unwrap(); + + Ok(Python::with_gil(|py| { + PyActorSupervisionEvent::from(event).into_py(py) + })) +} + +#[pyclass( + name = "MonitoredPortReceiver", + module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" +)] +pub(super) struct MonitoredPythonPortReceiver { + inner: Arc>>, + monitor: PyActorMeshMonitor, +} + +#[pymethods] +impl MonitoredPythonPortReceiver { + #[classmethod] + fn new<'py>( + _cls: &Bound<'_, PyType>, + _py: Python<'py>, + receiver: &PythonPortReceiver, + monitor: &PyActorMeshMonitor, + ) -> PyResult { + let inner = receiver.inner(); + Ok(Python::with_gil(|py| { + MonitoredPythonPortReceiver { + inner, + monitor: monitor.clone(), + } + .into_py(py) + })) + } + + fn recv<'py>(&mut self, py: Python<'py>) -> PyResult> { + let receiver = self.inner.clone(); + let monitor = self.monitor.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + let mut receiver = receiver.lock().await; + tokio::select! { + result = receiver.recv() => { + result.map_err(|err| PyErr::new::(format!("port closed: {}", err))) + } + event = monitor.next() => { + Err(PyErr::new::(format!("supervision error: {:?}", event.unwrap()))) + } + } + }) + } + + fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult { + let receiver = self.inner.clone(); + let monitor = self.monitor.clone(); + signal_safe_block_on(py, async move { + let mut receiver = receiver.lock().await; + tokio::select! { + result = receiver.recv() => { + result.map_err(|err| PyErr::new::(format!("port closed: {}", err))) + } + event = monitor.next() => { + Err(PyErr::new::(format!("supervision error: {:?}", event.unwrap()))) + } + } + })? + } +} + +#[pyclass( + name = "MonitoredOncePortReceiver", + module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" +)] +pub(super) struct MonitoredPythonOncePortReceiver { + inner: Arc>>>, + monitor: PyActorMeshMonitor, +} + +#[pymethods] +impl MonitoredPythonOncePortReceiver { + #[classmethod] + fn new<'py>( + _cls: &Bound<'_, PyType>, + _py: Python<'py>, + receiver: &PythonOncePortReceiver, + monitor: &PyActorMeshMonitor, + ) -> PyResult { + let inner = receiver.inner(); + Ok(Python::with_gil(|py| { + MonitoredPythonOncePortReceiver { + inner, + monitor: monitor.clone(), + } + .into_py(py) + })) + } + + fn recv<'py>(&mut self, py: Python<'py>) -> PyResult> { + let Some(receiver) = self.inner.lock().unwrap().take() else { + return Err(PyErr::new::("OncePort is already used")); + }; + let monitor = self.monitor.clone(); + pyo3_async_runtimes::tokio::future_into_py(py, async move { + tokio::select! 
{ + result = receiver.recv() => { + result.map_err(|err| PyErr::new::(format!("port closed: {}", err))) + } + event = monitor.next() => { + Err(PyErr::new::(format!("supervision error: {:?}", event.unwrap()))) + } + } + }) + } + + fn blocking_recv<'py>(&mut self, py: Python<'py>) -> PyResult { + let Some(receiver) = self.inner.lock().unwrap().take() else { + return Err(PyErr::new::("OncePort is already used")); + }; + let monitor = self.monitor.clone(); + signal_safe_block_on(py, async move { + tokio::select! { + result = receiver.recv() => { + result.map_err(|err| PyErr::new::(format!("port closed: {}", err))) + } + event = monitor.next() => { + Err(PyErr::new::(format!("supervision error: {:?}", event.unwrap()))) + } + } + })? + } +} + +#[pyclass( + name = "ActorSupervisionEvent", + module = "monarch._rust_bindings.monarch_hyperactor.actor_mesh" +)] +pub struct PyActorSupervisionEvent { + /// Actor ID of the actor where supervision event originates from. + #[pyo3(get)] + actor_id: PyActorId, + /// String representation of the actor status. + /// TODO(T230628951): make it an enum or a struct for easier consumption. + #[pyo3(get)] + actor_status: String, +} + +#[pymethods] +impl PyActorSupervisionEvent { + fn __repr__(&self) -> PyResult { + Ok(format!( + "", + self.actor_id, self.actor_status + )) + } +} + +impl From for PyActorSupervisionEvent { + fn from(event: ActorSupervisionEvent) -> Self { + PyActorSupervisionEvent { + actor_id: event.actor_id().clone().into(), + actor_status: event.actor_status().to_string(), + } + } +} + pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; Ok(()) } diff --git a/monarch_hyperactor/src/lib.rs b/monarch_hyperactor/src/lib.rs index e5f8c6ef..175e10bd 100644 --- a/monarch_hyperactor/src/lib.rs +++ b/monarch_hyperactor/src/lib.rs @@ -21,6 +21,7 @@ pub mod proc_mesh; pub mod runtime; pub mod selection; pub mod shape; +pub mod supervision; pub mod telemetry; #[cfg(fbcode_build)] diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index f8321b1f..458ea936 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -81,7 +81,7 @@ impl PyMailbox { let receiver = Py::new( py, PythonOncePortReceiver { - inner: std::sync::Mutex::new(Some(receiver)), + inner: Arc::new(std::sync::Mutex::new(Some(receiver))), }, )?; PyTuple::new(py, vec![handle.into_any(), receiver.into_any()]) @@ -371,6 +371,12 @@ impl PythonPortReceiver { } } +impl PythonPortReceiver { + pub fn inner(&self) -> Arc>> { + Arc::clone(&self.inner) + } +} + #[derive(Debug)] #[pyclass( name = "UndeliverableMessageEnvelope", @@ -495,7 +501,7 @@ impl From> for PythonOncePortRef { module = "monarch._rust_bindings.monarch_hyperactor.mailbox" )] pub(super) struct PythonOncePortReceiver { - inner: std::sync::Mutex>>, + inner: Arc>>>, } #[pymethods] @@ -521,6 +527,12 @@ impl PythonOncePortReceiver { } } +impl PythonOncePortReceiver { + pub fn inner(&self) -> Arc>>> { + Arc::clone(&self.inner) + } +} + #[derive( Clone, Serialize, diff --git a/monarch_hyperactor/src/proc_mesh.rs b/monarch_hyperactor/src/proc_mesh.rs index 34bf6d4d..c6adca18 100644 --- a/monarch_hyperactor/src/proc_mesh.rs +++ b/monarch_hyperactor/src/proc_mesh.rs @@ -43,6 +43,7 @@ use crate::alloc::PyAlloc; use crate::mailbox::PyMailbox; use crate::runtime::signal_safe_block_on; use 
crate::shape::PyShape; +use crate::supervision::SupervisionError; // A wrapper around `ProcMesh` which keeps track of all `RootActorMesh`s that it spawns. pub struct TrackedProcMesh { @@ -114,8 +115,9 @@ pub struct PyProcMesh { pub inner: SharedCell, keepalive: Keepalive, proc_events: SharedCell>, - stop_monitor_sender: mpsc::Sender, - user_monitor_registered: AtomicBool, + user_monitor_receiver: SharedCell>>, + user_monitor_registered: Arc, + unhealthy_event: Arc>>, } fn allocate_proc_mesh<'py>(py: Python<'py>, alloc: &PyAlloc) -> PyResult> { @@ -155,24 +157,28 @@ fn allocate_proc_mesh_blocking<'py>(py: Python<'py>, alloc: &PyAlloc) -> PyResul } impl PyProcMesh { - /// Create a new [`PyProcMesh`] with a monitor that crashes the - /// process on any proc failure. + /// Create a new [`PyProcMesh`] with self health status monitoring. fn monitored(mut proc_mesh: ProcMesh, world_id: WorldId) -> Self { - let (sender, abort_receiver) = mpsc::channel::(1); let proc_events = SharedCell::from(Mutex::new(proc_mesh.events().unwrap())); + let (user_sender, user_receiver) = mpsc::unbounded_channel::(); + let user_monitor_registered = Arc::new(AtomicBool::new(false)); + let unhealthy_event = Arc::new(Mutex::new(None)); let monitor = tokio::spawn(Self::default_proc_mesh_monitor( proc_events .borrow() .expect("borrowing immediately after creation"), world_id, - abort_receiver, + user_sender, + user_monitor_registered.clone(), + unhealthy_event.clone(), )); Self { inner: SharedCell::from(TrackedProcMesh::from(proc_mesh)), keepalive: Keepalive::new(monitor), proc_events, - stop_monitor_sender: sender, - user_monitor_registered: AtomicBool::new(false), + user_monitor_receiver: SharedCell::from(Mutex::new(user_receiver)), + user_monitor_registered: user_monitor_registered.clone(), + unhealthy_event, } } @@ -181,30 +187,34 @@ impl PyProcMesh { async fn default_proc_mesh_monitor( events: SharedCellRef>, world_id: WorldId, - mut abort_receiver: mpsc::Receiver, + user_sender: mpsc::UnboundedSender, + user_monitor_registered: Arc, + unhealthy_event: Arc>>, ) { - let mut proc_events = events.lock().await; loop { + let mut proc_events = events.lock().await; tokio::select! { event = proc_events.next() => { if let Some(event) = event { + let mut inner_unhealthy_event = unhealthy_event.lock().await; + *inner_unhealthy_event = Some(event.clone()); + match event { // A graceful stop should not be cause for alarm, but // everything else should be considered a crash. ProcEvent::Stopped(_, ProcStopReason::Stopped) => continue, event => { eprintln!("ProcMesh {}: {}", world_id, event); - std::process::exit(1) + if user_monitor_registered.load(std::sync::atomic::Ordering::SeqCst) { + if user_sender.send(event).is_err() { + eprintln!("failed to deliver the supervision event to user"); + } + } } } } } - _ = async { - tokio::select! { - _ = events.preempted() => (), - _ = abort_receiver.recv() => (), - } - } => { + _ = events.preempted() => { // The default monitor is aborted, this happens when user takes over // the monitoring responsibility. 
eprintln!("stop default supervision monitor for ProcMesh {}", world_id); @@ -247,17 +257,28 @@ impl PyProcMesh { name: String, actor: &Bound<'py, PyType>, ) -> PyResult> { + let unhealthy_event: Arc>> = self.unhealthy_event.clone(); let pickled_type = PickledPyObject::pickle(actor.as_any())?; let proc_mesh = self.try_inner()?; let keepalive = self.keepalive.clone(); pyo3_async_runtimes::tokio::future_into_py(py, async move { + let unhealthy_event = unhealthy_event.lock().await; + if let Some(unhealthy_event) = unhealthy_event.clone() { + return Err(SupervisionError::new_err(format!( + "proc mesh is stopped with reason: {:?}", + unhealthy_event + ))); + } + let mailbox = proc_mesh.client().clone(); let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?; - let python_actor_mesh = PythonActorMesh { - inner: actor_mesh, - client: PyMailbox { inner: mailbox }, - _keepalive: keepalive, - }; + let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap(); + let python_actor_mesh = PythonActorMesh::monitored( + actor_mesh, + PyMailbox { inner: mailbox }, + keepalive, + actor_events, + ); Python::with_gil(|py| python_actor_mesh.into_py_any(py)) }) } @@ -268,17 +289,28 @@ impl PyProcMesh { name: String, actor: &Bound<'py, PyType>, ) -> PyResult { + let unhealthy_event = self.unhealthy_event.clone(); let pickled_type = PickledPyObject::pickle(actor.as_any())?; let proc_mesh = self.try_inner()?; let keepalive = self.keepalive.clone(); signal_safe_block_on(py, async move { + let unhealthy_event = unhealthy_event.lock().await; + if let Some(unhealthy_event) = unhealthy_event.clone() { + return Err(SupervisionError::new_err(format!( + "proc mesh is stopped with reason: {:?}", + unhealthy_event + ))); + } + let mailbox = proc_mesh.client().clone(); let actor_mesh = proc_mesh.spawn(&name, &pickled_type).await?; - let python_actor_mesh = PythonActorMesh { - inner: actor_mesh, - client: PyMailbox { inner: mailbox }, - _keepalive: keepalive, - }; + let actor_events = actor_mesh.with_mut(|a| a.events()).await.unwrap().unwrap(); + let python_actor_mesh = PythonActorMesh::monitored( + actor_mesh, + PyMailbox { inner: mailbox }, + keepalive, + actor_events, + ); Python::with_gil(|py| python_actor_mesh.into_py_any(py)) })? } @@ -295,16 +327,10 @@ impl PyProcMesh { "user already registered a monitor for this proc mesh".to_string(), )); } - - // Stop the default monitor - let monitor_abort = self.stop_monitor_sender.clone(); - let proc_events = self.proc_events.clone(); - + let receiver = self.user_monitor_receiver.clone(); Ok(pyo3_async_runtimes::tokio::future_into_py(py, async move { - monitor_abort.send(true).await.unwrap(); - // Create a new user monitor - Ok(PyProcMeshMonitor { proc_events }) + Ok(PyProcMeshMonitor { receiver }) })? .into()) } @@ -344,6 +370,7 @@ impl PyProcMesh { // Grab the alloc back from `ProcEvents` and use that to stop the mesh. 
let mut alloc = proc_events.take().await?.into_inner().into_alloc(); alloc.stop_and_wait().await?; + anyhow::Ok(()) } .await?; @@ -383,7 +410,7 @@ impl Drop for KeepaliveState { module = "monarch._rust_bindings.monarch_hyperactor.proc_mesh" )] pub struct PyProcMeshMonitor { - proc_events: SharedCell>, + receiver: SharedCell>>, } #[pymethods] @@ -393,17 +420,17 @@ impl PyProcMeshMonitor { } fn __anext__(&self, py: Python<'_>) -> PyResult { - let events = self.proc_events.clone(); + let receiver = self.receiver.clone(); Ok(pyo3_async_runtimes::tokio::future_into_py(py, async move { - let events = events + let receiver = receiver .borrow() - .map_err(|_| PyRuntimeError::new_err("`ProcEvents` is shutdown"))?; - let mut proc_events = events.lock().await; + .map_err(|_| PyRuntimeError::new_err("`ProcEvent receiver` is shutdown"))?; + let mut proc_event_receiver = receiver.lock().await; tokio::select! { - () = events.preempted() => { - Err(PyRuntimeError::new_err("shutting down `ProcEvents`")) + () = receiver.preempted() => { + Err(PyRuntimeError::new_err("shutting down `ProcEvents` receiver")) }, - event = proc_events.next() => { + event = proc_event_receiver.recv() => { match event { Some(event) => Ok(PyProcEvent::from(event)), None => Err(::pyo3::exceptions::PyStopAsyncIteration::new_err( diff --git a/monarch_hyperactor/src/supervision.rs b/monarch_hyperactor/src/supervision.rs new file mode 100644 index 00000000..f3323759 --- /dev/null +++ b/monarch_hyperactor/src/supervision.rs @@ -0,0 +1,25 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use pyo3::create_exception; +use pyo3::exceptions::PyRuntimeError; +use pyo3::prelude::*; + +create_exception!( + monarch._rust_bindings.monarch_hyperactor.supervision, + SupervisionError, + PyRuntimeError +); + +pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> { + // Get the Python interpreter instance from the module + let py = module.py(); + // Add the exception to the module using its type object + module.add("SupervisionError", py.get_type::())?; + Ok(()) +} diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi index 90371dab..4661798a 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor_mesh.pyi @@ -6,10 +6,14 @@ # pyre-strict -from typing import final +from typing import AsyncIterator, final from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage -from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox +from monarch._rust_bindings.monarch_hyperactor.mailbox import ( + Mailbox, + OncePortReceiver, + PortReceiver, +) from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.selection import Selection from monarch._rust_bindings.monarch_hyperactor.shape import Shape @@ -21,12 +25,24 @@ class PythonActorMesh: Cast a message to the selected actors in the mesh. """ + def get_supervision_event(self) -> ActorSupervisionEvent | None: + """ + Returns supervision event if there is any. + """ + ... + def get(self, rank: int) -> ActorId | None: """ Get the actor id for the actor at the given rank. """ ... 
+ def monitor(self) -> ActorMeshMonitor: + """ + Returns a supervision monitor for this mesh. + """ + ... + @property def client(self) -> Mailbox: """ @@ -44,3 +60,75 @@ class PythonActorMesh: mesh. """ ... + +@final +class ActorMeshMonitor: + def __aiter__(self) -> AsyncIterator["ActorSupervisionEvent"]: + """ + Returns an async iterator for this monitor. + """ + ... + + async def __anext__(self) -> "ActorSupervisionEvent": + """ + Returns the next proc event in the proc mesh. + """ + ... + +@final +class MonitoredPortReceiver: + """ + A monitored receiver to which PythonMessages are sent. + """ + @classmethod + def new( + self, receiver: PortReceiver, monitor: ActorMeshMonitor + ) -> MonitoredPortReceiver: + """ + Create a new monitored receiver from a PortReceiver. + """ + ... + + async def recv(self) -> PythonMessage: + """Receive a PythonMessage from the port's sender.""" + ... + def blocking_recv(self) -> PythonMessage: + """Receive a single PythonMessage from the port's sender.""" + ... + +@final +class MonitoredOncePortReceiver: + """ + A variant of monitored PortReceiver that can only receive a single message. + """ + @classmethod + def new( + self, receiver: OncePortReceiver, monitor: ActorMeshMonitor + ) -> MonitoredOncePortReceiver: + """ + Create a new monitored receiver from a PortReceiver. + """ + ... + + async def recv(self) -> PythonMessage: + """Receive a single PythonMessage from the port's sender.""" + ... + def blocking_recv(self) -> PythonMessage: + """Receive a single PythonMessage from the port's sender.""" + ... + +@final +class ActorSupervisionEvent: + @property + def actor_id(self) -> ActorId: + """ + The actor id of the actor. + """ + ... + + @property + def actor_status(self) -> str: + """ + Detailed actor status. + """ + ... diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/supervision.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/supervision.pyi new file mode 100644 index 00000000..1bfeec95 --- /dev/null +++ b/python/monarch/_rust_bindings/monarch_hyperactor/supervision.pyi @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import final + +@final +class SupervisionError(RuntimeError): + """ + Custom exception for supervision-related errors in monarch_hyperactor. + """ + + ... 
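
For reference, a minimal sketch (not part of the patch) of how the bindings declared in the stubs above might be consumed from Python. The `actor_mesh` handle is assumed to be a PythonActorMesh obtained from ProcMesh.spawn; the helper names are illustrative.

    # Minimal consumption sketch; only monitor(), get_supervision_event(),
    # ActorSupervisionEvent.actor_id / .actor_status and SupervisionError come
    # from this patch -- everything else here is illustrative glue.
    from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError


    def is_healthy(actor_mesh) -> bool:
        # get_supervision_event() returns None while no supervision error
        # has been observed for this mesh.
        return actor_mesh.get_supervision_event() is None


    async def wait_for_failure(actor_mesh) -> None:
        # monitor() hands back an ActorMeshMonitor; each __anext__ yields an
        # ActorSupervisionEvent carrying the failed actor's id and a string status.
        monitor = actor_mesh.monitor()
        event = await anext(monitor)
        print(f"supervision event from {event.actor_id}: {event.actor_status}")
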
diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 4cd85d69..b820fcb2 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -44,7 +44,12 @@ ) from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage -from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh +from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ( + ActorMeshMonitor, + MonitoredOncePortReceiver, + MonitoredPortReceiver, + PythonActorMesh, +) from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, OncePortReceiver, @@ -54,6 +59,7 @@ ) from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape +from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator @@ -181,7 +187,18 @@ def __setstate__( self._actor_mesh = None self._shape, self._please_replace_me_actor_ids, self._mailbox = state + def _check_state(self) -> None: + # This is temporary until we have real cast integration here. We need to actively check + # supervision error here is because all communication is done through direct mailbox sending + # and not through comm actor casting. + # TODO: remove this when casting integration is done. + if self._actor_mesh is not None: + event = self._actor_mesh.get_supervision_event() + if event is not None: + raise SupervisionError(f"actor mesh is not in a healthy state: {event}") + def send(self, rank: int, message: PythonMessage) -> None: + self._check_state() actor = self._please_replace_me_actor_ids[rank] self._mailbox.post(actor, message) @@ -190,6 +207,8 @@ def cast( message: PythonMessage, selection: Selection, ) -> None: + self._check_state() + # TODO: use the actual actor mesh when available. We cannot currently use it # directly because we risk bifurcating the message delivery paths from the same # client, since slicing the mesh will produce a reference, which calls actors @@ -265,7 +284,9 @@ def _send( pass @abstractmethod - def _port(self, once: bool = False) -> "PortTuple[R]": + def _port( + self, monitor: Optional[ActorMeshMonitor], once: bool = False + ) -> "PortTuple[R]": pass # the following are all 'adverbs' or different ways to handle the @@ -279,13 +300,23 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: Load balanced RPC-style entrypoint for request/response messaging. 
""" - p, r = port(self, once=True) + monitor = ( + None + if self._actor_mesh._actor_mesh is None + else self._actor_mesh._actor_mesh.monitor() + ) + p, r = port(self, monitor, once=True) # pyre-ignore self._send(args, kwargs, port=p, selection="choose") return r.recv() def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: - p, r = port(self, once=True) + monitor = ( + None + if self._actor_mesh._actor_mesh is None + else self._actor_mesh._actor_mesh.monitor() + ) + p, r = port(self, monitor, once=True) # pyre-ignore extent = self._send(args, kwargs, port=p, selection="choose") if extent.nelements != 1: @@ -295,9 +326,13 @@ def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: return r.recv() def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]": - p: Port[R] - r: RankedPortReceiver[R] - p, r = ranked_port(self) + monitor = ( + None + if self._actor_mesh._actor_mesh is None + else self._actor_mesh._actor_mesh.monitor() + ) + + p, r = ranked_port(self, monitor) # pyre-ignore extent = self._send(args, kwargs, port=p) @@ -332,7 +367,12 @@ async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R This enables processing results from multiple actors incrementally as they become available. Returns an async generator of response values. """ - p, r = port(self) + monitor = ( + None + if self._actor_mesh._actor_mesh is None + else self._actor_mesh._actor_mesh.monitor() + ) + p, r = port(self, monitor) # pyre-ignore extent = self._send(args, kwargs, port=p) for _ in range(extent.nelements): @@ -386,8 +426,10 @@ def _send( shape = self._actor_mesh._shape return Extent(shape.labels, shape.ndslice.sizes) - def _port(self, once: bool = False) -> "PortTuple[R]": - return PortTuple.create(self._mailbox, once) + def _port( + self, monitor: Optional[ActorMeshMonitor], once: bool = False + ) -> "PortTuple[R]": + return PortTuple.create(self._mailbox, monitor, once) class Accumulator(Generic[P, R, A]): @@ -526,11 +568,21 @@ class PortTuple(NamedTuple, Generic[R]): receiver: "PortReceiver[R]" @staticmethod - def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": + def create( + mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False + ) -> "PortTuple[Any]": handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() port_ref = handle.bind() + if monitor is not None: + receiver = ( + MonitoredOncePortReceiver.new(receiver, monitor) + if isinstance(receiver, OncePortReceiver) + else MonitoredPortReceiver.new(receiver, monitor) + ) + return PortTuple( - Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver) + Port(port_ref, mailbox, rank=None), + PortReceiver(mailbox, receiver), ) else: @@ -539,25 +591,37 @@ class PortTuple(NamedTuple): receiver: "PortReceiver[Any]" @staticmethod - def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": + def create( + mailbox: Mailbox, monitor: Optional[ActorMeshMonitor], once: bool = False + ) -> "PortTuple[Any]": handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() port_ref = handle.bind() + if monitor is not None: + receiver = ( + MonitoredOncePortReceiver.new(receiver, monitor) + if isinstance(receiver, OncePortReceiver) + else MonitoredPortReceiver.new(receiver, monitor) + ) + return PortTuple( - Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver) + Port(port_ref, mailbox, rank=None), + PortReceiver(mailbox, receiver), ) # advance lower-level API for sending messages. 
This is intentially # not part of the Endpoint API because they way it accepts arguments # and handles concerns is different. -def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]": - return endpoint._port(once) +def port( + endpoint: Endpoint[P, R], monitor: Optional[ActorMeshMonitor], once: bool = False +) -> "PortTuple[R]": + return endpoint._port(monitor, once) def ranked_port( - endpoint: Endpoint[P, R], once: bool = False + endpoint: Endpoint[P, R], monitor: Optional[ActorMeshMonitor], once: bool = False ) -> Tuple["Port[R]", "RankedPortReceiver[R]"]: - p, receiver = port(endpoint, once) + p, receiver = port(endpoint, monitor, once) return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver) @@ -565,10 +629,18 @@ class PortReceiver(Generic[R]): def __init__( self, mailbox: Mailbox, - receiver: HyPortReceiver | OncePortReceiver, + receiver: MonitoredPortReceiver + | MonitoredOncePortReceiver + | HyPortReceiver + | OncePortReceiver, ) -> None: self._mailbox: Mailbox = mailbox - self._receiver: HyPortReceiver | OncePortReceiver = receiver + self._receiver: ( + MonitoredPortReceiver + | MonitoredOncePortReceiver + | HyPortReceiver + | OncePortReceiver + ) = receiver async def _recv(self) -> R: return self._process(await self._receiver.recv()) diff --git a/python/monarch/common/remote.py b/python/monarch/common/remote.py index 1c19a360..206f7f09 100644 --- a/python/monarch/common/remote.py +++ b/python/monarch/common/remote.py @@ -28,6 +28,8 @@ import monarch.common.messages as messages import torch + +from monarch._rust_bindings.monarch_hyperactor.actor_mesh import ActorMeshMonitor from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.shape import Shape from monarch._src.actor.actor_mesh import Extent, Port, PortTuple, Selection @@ -132,7 +134,9 @@ def _send( client._request_status() return Extent(ambient_mesh._labels, ambient_mesh._ndslice.sizes) - def _port(self, once: bool = False) -> "PortTuple[R]": + def _port( + self, monitor: Optional[ActorMeshMonitor], once: bool = False + ) -> "PortTuple[R]": ambient_mesh = device_mesh._active if ambient_mesh is None: raise ValueError( @@ -144,7 +148,7 @@ def _port(self, once: bool = False) -> "PortTuple[R]": "Cannot create raw port objects with an old-style tensor engine controller." 
) mailbox: Mailbox = mesh_controller._mailbox - return PortTuple.create(mailbox, once) + return PortTuple.create(mailbox, monitor, once) @property def _resolvable(self): diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 2b7973d3..8d7b3e7b 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -144,7 +144,9 @@ def fetch( defs: Tuple["Tensor", ...], uses: Tuple["Tensor", ...], ) -> "OldFuture": # the OldFuture is a lie - sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True) + sender, receiver = PortTuple.create( + self._mesh_controller._mailbox, None, once=True + ) ident = self.new_node(defs, uses, cast("OldFuture", sender)) process = mesh._process(shard) @@ -180,7 +182,9 @@ def shutdown( atexit.unregister(self._atexit) self._shutdown = True - sender, receiver = PortTuple.create(self._mesh_controller._mailbox, once=True) + sender, receiver = PortTuple.create( + self._mesh_controller._mailbox, None, once=True + ) self._mesh_controller.sync_at_exit(sender._port_ref.port_id) receiver.recv().get(timeout=60) # we are not expecting anything more now, because we already diff --git a/python/tests/test_actor_error.py b/python/tests/test_actor_error.py index 0f5c523e..ec515334 100644 --- a/python/tests/test_actor_error.py +++ b/python/tests/test_actor_error.py @@ -4,12 +4,14 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -import asyncio + import importlib.resources import subprocess +import sys import pytest from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ProcEvent +from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError from monarch.actor import Actor, ActorError, endpoint, local_proc_mesh, proc_mesh @@ -103,6 +105,7 @@ def test_actor_exception_sync(actor_class, num_procs): exception_actor.raise_exception.call().get() +''' # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited @pytest.mark.oss_skip @pytest.mark.parametrize("num_procs", [1, 2]) @@ -143,6 +146,7 @@ def test_actor_supervision(num_procs, sync_endpoint, sync_test_impl, endpoint_na assert ( process.returncode != 0 ), f"Expected non-zero exit code, got {process.returncode}" +''' # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited @@ -212,6 +216,7 @@ async def test_broken_pickle_class(raise_on_getstate, raise_on_setstate, num_pro await exception_actor.print_value.call(broken_obj) +""" # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited @pytest.mark.oss_skip async def test_exception_after_wait_unmonitored(): @@ -237,6 +242,7 @@ async def test_exception_after_wait_unmonitored(): assert ( process.returncode != 0 ), f"Expected non-zero exit code, got {process.returncode}" +""" # oss_skip: importlib not pulling resource correctly in git CI, needs to be revisited @@ -326,12 +332,17 @@ def is_process_running(pid): class ErrorActor(Actor): - def __init__(self, message): - raise RuntimeError("fail on init") + @endpoint + async def fail_with_supervision_error(self) -> None: + sys.exit(1) + + @endpoint + async def check(self) -> str: + return "this is a healthy check" @endpoint - async def check(self) -> None: - pass + async def check_with_exception(self) -> None: + raise RuntimeError("failed the check with app error") async def test_proc_mesh_redundant_monitoring(): @@ -344,20 +355,6 @@ async def 
test_proc_mesh_redundant_monitoring(): await proc.monitor() -async def test_proc_mesh_monitoring(): - proc = await local_proc_mesh(hosts=1, gpus=1) - monitor = await proc.monitor() - - with pytest.raises(Exception): - e = await proc.spawn("error", ErrorActor, "failed to init the actor") - await asyncio.wait_for(e.check.call_one(), timeout=15) - - event = await anext(monitor) - assert isinstance(event, ProcEvent.Crashed) - assert event[0] == 0 # check rank - assert "fail on init" in event[1] # check error message - - class Worker(Actor): @endpoint def work(self): @@ -385,3 +382,86 @@ async def test_errors_propagated(): with pytest.raises(ActorError) as err_info: await mesh.route.call_one() assert "value error" in str(err_info.value) + + +async def test_proc_mesh_monitoring(): + proc = await local_proc_mesh(hosts=1, gpus=1) + monitor = await proc.monitor() + + e = await proc.spawn("error", ErrorActor) + + with pytest.raises(Exception): + await e.fail_with_supervision_error.call_one() + + event = await anext(monitor) + assert isinstance(event, ProcEvent.Crashed) + assert event[0] == 0 # check rank + assert "sys.exit(1)" in event[1] # check error message + assert "fail_with_supervision_error" in event[1] # check error message + + # should not be able to spawn actors anymore as proc mesh is unhealthy + with pytest.raises(SupervisionError, match="proc mesh is stopped with reason"): + await proc.spawn("ex", ExceptionActorSync) + + +async def test_actor_mesh_supervision_handling(): + proc = await local_proc_mesh(hosts=1, gpus=1) + + e = await proc.spawn("error", ErrorActor) + + # first check() call should succeed + await e.check.call() + + # throw an application error + with pytest.raises(ActorError, match="failed the check with app error"): + await e.check_with_exception.call() + + # actor mesh should still be healthy + await e.check.call() + + # existing call should fail with supervision error + with pytest.raises(SupervisionError, match="supervision error:"): + await e.fail_with_supervision_error.call_one() + + # new call should fail with check of health state of actor mesh + with pytest.raises(SupervisionError, match="actor mesh is not in a healthy state"): + await e.check.call() + + # should not be able to spawn actors anymore as proc mesh is unhealthy + with pytest.raises(SupervisionError, match="proc mesh is stopped with reason"): + await proc.spawn("ex", ExceptionActorSync) + + +class Intermediate(Actor): + @endpoint + async def init(self): + mesh = await proc_mesh(gpus=1) + self._error_actor = await mesh.spawn("error", ErrorActor) + + @endpoint + async def forward_success(self): + return await self._error_actor.check.call() + + @endpoint + async def forward_error(self): + return await self._error_actor.fail_with_supervision_error.call_one() + + +async def test_actor_mesh_supervision_handling_chained_error(): + proc = await local_proc_mesh(hosts=1, gpus=1) + + e = await proc.spawn("intermediate", Intermediate) + await e.init.call() + + # first forward() call should succeed + await e.forward_success.call() + + # in a chain of client -> Intermediate -> ErrorActor, a supervision error + # happening in ErrorActor will be captured by Intermediate and re-raised + # as an application error (ActorError). + with pytest.raises(ActorError, match="supervision error:"): + await e.forward_error.call() + + # calling success endpoint should fail with ActorError, but with supervision msg. + with pytest.raises(ActorError, match="actor mesh is not in a healthy state"): + await e.forward_success.call()
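
As a usage-pattern sketch building on the tests above: once a mesh is unhealthy, in-flight calls raise SupervisionError and new calls fail the pre-send health check, so a caller that wants to continue must discard the mesh and spawn a fresh one. The respawn-on-failure policy and the spawn_worker/do_work callables below are assumptions for illustration, not something this patch prescribes.

    from monarch._rust_bindings.monarch_hyperactor.supervision import SupervisionError


    async def call_with_respawn(spawn_worker, do_work, retries: int = 1):
        # Run an endpoint call; if the mesh turns unhealthy, SupervisionError is
        # raised either by the in-flight call or by the pre-send health check, so
        # abandon the old mesh, spawn a fresh one, and retry a bounded number of times.
        mesh = await spawn_worker()
        for attempt in range(retries + 1):
            try:
                return await do_work(mesh)
            except SupervisionError:
                if attempt == retries:
                    raise
                mesh = await spawn_worker()  # old mesh is unusable; start over
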