From e6ede0ca360caec525520229c1cedd6c99cad385 Mon Sep 17 00:00:00 2001 From: zdevito Date: Tue, 15 Jul 2025 10:52:54 -0700 Subject: [PATCH] [17/n] Pass monarch tensors to actor endpoints, part 2, actor messages sent in order When you send a monarch.Tensor to an actor, what actually happens is that you instruct the stream managing the tensor to send the tensor to the actor. In part 1, that stream directly told the actor what method to call, and where to send the result. The issue with that is ordering: from the perspective of the original caller it is as if we are sending a message to the actor directly. If the stream sends the message it is possible that it arrives out of order with respect to a message that goes directly to the actor (e.g. a message that contains no monarch.Tensors). This resolves ordering issue: now the client sends a message to the actor 'CallMethodIndirect' that provides the pickled args/kwargs and method to invoke. The local python values of the torch.Tensors still need to come from Stream actor, so CallMethodIndirect looks them up by messaging the (new) LocalStateBrokerActor which can transfer PyObjects between local actors without serialization. Ordering is guarenteed for both tensors and actors because a message is sent to both the receiving actor _and_ the stream actor at the same time. Differential Revision: [D78314012](https://our.internmc.facebook.com/intern/diff/D78314012/) [ghstack-poisoned] --- monarch_extension/src/convert.rs | 36 +---- monarch_extension/src/mesh_controller.rs | 16 +++ monarch_hyperactor/src/actor.rs | 125 +++++++++++------- monarch_hyperactor/src/lib.rs | 1 + monarch_hyperactor/src/local_state_broker.rs | 94 +++++++++++++ monarch_hyperactor/src/mailbox.rs | 6 + monarch_messages/src/worker.rs | 27 +--- monarch_simulator/src/worker.rs | 2 +- monarch_tensor_worker/src/stream.rs | 59 ++++----- .../monarch_extension/mesh_controller.pyi | 3 + .../monarch_hyperactor/actor.pyi | 21 ++- python/monarch/_src/actor/actor_mesh.py | 7 +- python/monarch/common/messages.py | 8 +- python/monarch/mesh_controller.py | 49 ++++--- python/tests/_monarch/test_hyperactor.py | 4 +- python/tests/_monarch/test_mailbox.py | 4 +- python/tests/test_tensor_engine.py | 26 ++++ 17 files changed, 320 insertions(+), 168 deletions(-) create mode 100644 monarch_hyperactor/src/local_state_broker.rs diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index 5af8b32d..53b069a7 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -176,35 +176,6 @@ impl<'a> MessageParser<'a> { fn parseWorkerMessageList(&self, name: &str) -> PyResult> { self.attr(name)?.try_iter()?.map(|x| convert(x?)).collect() } - #[allow(non_snake_case)] - fn parseLocalStateList(&self, name: &str) -> PyResult> { - let local_state_list = self.attr(name)?; - let mut result = Vec::new(); - - // Get the PyMailbox class for type checking - let mailbox_class = self - .current - .py() - .import("monarch._rust_bindings.monarch_hyperactor.mailbox")? - .getattr("Mailbox")?; - - for item in local_state_list.try_iter()? { - let item = item?; - // Check if it's a Ref by trying to extract it - if let Ok(ref_val) = create_ref(item.clone()) { - result.push(worker::LocalState::Ref(ref_val)); - } else if item.is_instance(&mailbox_class)? { - // It's a PyMailbox instance - result.push(worker::LocalState::Mailbox); - } else { - return Err(PyValueError::new_err(format!( - "Expected Ref or Mailbox in local_state, got: {}", - item.get_type().name()? - ))); - } - } - Ok(result) - } fn parse_error_reason(&self, name: &str) -> PyResult, String)>> { let err = self.attr(name)?; if err.is_none() { @@ -465,11 +436,8 @@ fn create_map(py: Python) -> HashMap { Ok(WorkerMessage::SendResultOfActorCall( worker::ActorCallParams { seq: p.parseSeq("seq")?, - actor: p.parse("actor")?, - index: p.parse("actor_index")?, - method: p.parse("method")?, - args_kwargs_tuple: p.parse("args_kwargs_tuple")?, - local_state: p.parseLocalStateList("local_state")?, + broker_id: p.parse("broker_id")?, + local_state: p.parseRefList("local_state")?, mutates: p.parseRefList("mutates")?, stream: p.parseStreamRef("stream")?, }, diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index d53aff97..3f3bbed2 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -38,6 +38,7 @@ use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::shared_cell::SharedCellRef; use monarch_hyperactor::actor::PythonMessage; use monarch_hyperactor::actor::PythonMessageKind; +use monarch_hyperactor::local_state_broker::LocalStateBrokerActor; use monarch_hyperactor::mailbox::PyPortId; use monarch_hyperactor::ndslice::PySlice; use monarch_hyperactor::proc_mesh::PyProcMesh; @@ -78,6 +79,7 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult struct _Controller { controller_handle: Arc>>, all_ranks: Slice, + broker_id: (String, usize), } static NEXT_ID: AtomicUsize = AtomicUsize::new(0); @@ -123,9 +125,17 @@ impl _Controller { Ok(Self { controller_handle, all_ranks, + // note that 0 is the _pid_ of the broker, which will be 0 for + // top-level spawned actors. + broker_id: (format!("tensor_engine_brokers_{}", id), 0), }) } + #[getter] + fn broker_id(&self) -> (String, usize) { + self.broker_id.clone() + } + #[pyo3(signature = (seq, defs, uses, response_port, tracebacks))] fn node<'py>( &mut self, @@ -626,6 +636,7 @@ enum ClientToControllerMessage { struct MeshControllerActor { proc_mesh: SharedCell, workers: Option>>, + brokers: Option>>, history: History, id: usize, debugger_active: Option>, @@ -730,6 +741,7 @@ impl Actor for MeshControllerActor { Ok(MeshControllerActor { proc_mesh: proc_mesh.clone(), workers: None, + brokers: None, history: History::new(world_size), id, debugger_active: None, @@ -758,6 +770,10 @@ impl Actor for MeshControllerActor { .cast(selection::dsl::true_(), AssignRankMessage::AssignRank())?; self.workers = Some(workers); + let brokers = proc_mesh + .spawn(&format!("tensor_engine_brokers_{}", self.id), &()) + .await?; + self.brokers = Some(brokers); Ok(()) } } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 84069e36..8d5192be 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -21,6 +21,7 @@ use hyperactor::Context; use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Named; +use hyperactor::cap::CanSend; use hyperactor::forward; use hyperactor::message::Bind; use hyperactor::message::Bindings; @@ -43,6 +44,8 @@ use tokio::sync::Mutex; use tokio::sync::oneshot; use crate::config::SHARED_ASYNCIO_RUNTIME; +use crate::local_state_broker::BrokerId; +use crate::local_state_broker::LocalStateBrokerMessage; use crate::mailbox::EitherPortRef; use crate::mailbox::PyMailbox; use crate::proc::InstanceWrapper; @@ -171,6 +174,13 @@ impl PickledMessageClientActor { } } +#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum UnflattenArg { + Mailbox, + PyObject, +} + #[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")] #[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq)] pub enum PythonMessageKind { @@ -185,6 +195,14 @@ pub enum PythonMessageKind { rank: Option, }, Uninit {}, + CallMethodIndirect { + name: String, + local_state_broker: (String, usize), + id: usize, + // specify whether the argument to unflatten the local mailbox, + // or the next argument of the local state. + unflatten_args: Vec, + }, } impl Default for PythonMessageKind { @@ -193,6 +211,11 @@ impl Default for PythonMessageKind { } } +fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, PyAny> { + let mailbox: PyMailbox = cx.mailbox_for_py().clone().into(); + mailbox.into_bound_py_any(py).unwrap() +} + #[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")] #[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)] pub struct PythonMessage { @@ -219,6 +242,56 @@ impl PythonMessage { _ => panic!("PythonMessage is not a response but {:?}", self), } } + + pub async fn resolve_indirect_call( + mut self, + cx: &Context<'_, T>, + ) -> anyhow::Result<(Self, PyObject)> { + let local_state: PyObject; + match self.kind { + PythonMessageKind::CallMethodIndirect { + name, + local_state_broker, + id, + unflatten_args, + } => { + let broker = BrokerId::new(local_state_broker).resolve(cx).unwrap(); + let (send, recv) = cx.open_once_port(); + broker.send(LocalStateBrokerMessage::Get(id, send))?; + let state = recv.recv().await?; + let mut state_it = state.state.into_iter(); + local_state = Python::with_gil(|py| { + let mailbox = mailbox(py, cx); + PyList::new( + py, + unflatten_args.into_iter().map(|x| -> Bound<'_, PyAny> { + match x { + UnflattenArg::Mailbox => mailbox.clone(), + UnflattenArg::PyObject => state_it.next().unwrap().into_bound(py), + } + }), + ) + .unwrap() + .into() + }); + self.kind = PythonMessageKind::CallMethod { + name, + response_port: Some(state.response_port), + } + } + _ => { + local_state = Python::with_gil(|py| { + let mailbox = mailbox(py, cx); + py.import("itertools") + .unwrap() + .call_method1("repeat", (mailbox.clone(),)) + .unwrap() + .unbind() + }); + } + }; + Ok((self, local_state)) + } } impl std::fmt::Debug for PythonMessage { @@ -408,30 +481,16 @@ impl PanicFlag { } #[async_trait] -impl Handler for PythonActor { - async fn handle( - &mut self, - cx: &Context, - message: LocalPythonMessage, - ) -> anyhow::Result<()> { - let mailbox: PyMailbox = PyMailbox { - inner: cx.mailbox_for_py().clone(), - }; +impl Handler for PythonActor { + async fn handle(&mut self, cx: &Context, message: PythonMessage) -> anyhow::Result<()> { + let (message, local_state) = message.resolve_indirect_call(cx).await?; + // Create a channel for signaling panics in async endpoints. // See [Panics in async endpoints]. let (sender, receiver) = oneshot::channel(); let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> { - let mailbox = mailbox.into_bound_py_any(py).unwrap(); - let local_state: Option>> = message.local_state.map(|state| { - state - .into_iter() - .map(|e| match e { - LocalState::Mailbox => mailbox.clone(), - LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(), - }) - .collect() - }); + let mailbox = mailbox(py, cx); let (rank, shape) = cx.cast_info(); let awaitable = self.actor.call_method( py, @@ -440,7 +499,7 @@ impl Handler for PythonActor { mailbox, rank, PyShape::from(shape), - message.message, + message, PanicFlag { sender: Some(sender), }, @@ -463,20 +522,6 @@ impl Handler for PythonActor { } } -#[async_trait] -impl Handler for PythonActor { - async fn handle(&mut self, cx: &Context, message: PythonMessage) -> anyhow::Result<()> { - self.handle( - cx, - LocalPythonMessage { - message, - local_state: None, - }, - ) - .await - } -} - /// Helper struct to make a Python future passable in an actor message. /// /// Also so that we don't have to write this massive type signature everywhere @@ -576,17 +621,6 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask { Ok(()) } } -#[derive(Debug)] -pub enum LocalState { - Mailbox, - PyObject(PyObject), -} - -#[derive(Debug)] -pub struct LocalPythonMessage { - pub message: PythonMessage, - pub local_state: Option>, -} pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; @@ -594,6 +628,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) } diff --git a/monarch_hyperactor/src/lib.rs b/monarch_hyperactor/src/lib.rs index 175e10bd..2d4971c8 100644 --- a/monarch_hyperactor/src/lib.rs +++ b/monarch_hyperactor/src/lib.rs @@ -14,6 +14,7 @@ pub mod alloc; pub mod bootstrap; pub mod channel; pub mod config; +pub mod local_state_broker; pub mod mailbox; pub mod ndslice; pub mod proc; diff --git a/monarch_hyperactor/src/local_state_broker.rs b/monarch_hyperactor/src/local_state_broker.rs new file mode 100644 index 00000000..32d2c500 --- /dev/null +++ b/monarch_hyperactor/src/local_state_broker.rs @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use std::collections::HashMap; + +use async_trait::async_trait; +use hyperactor::Actor; +use hyperactor::ActorHandle; +use hyperactor::ActorId; +use hyperactor::ActorRef; +use hyperactor::Context; +use hyperactor::Handler; +use hyperactor::OncePortHandle; +use pyo3::prelude::*; + +use crate::mailbox::EitherPortRef; + +#[derive(Debug)] +pub struct LocalState { + pub response_port: EitherPortRef, + pub state: Vec, +} + +#[derive(Debug)] +pub enum LocalStateBrokerMessage { + Set(usize, LocalState), + Get(usize, OncePortHandle), +} + +#[derive(Debug)] +#[hyperactor::export(spawn = true)] +pub struct LocalStateBrokerActor { + states: HashMap, + ports: HashMap>, +} + +#[async_trait] +impl Actor for LocalStateBrokerActor { + type Params = (); + async fn new(_params: Self::Params) -> anyhow::Result { + Ok(Self { + states: HashMap::new(), + ports: HashMap::new(), + }) + } +} + +#[async_trait] +impl Handler for LocalStateBrokerActor { + async fn handle( + &mut self, + cx: &Context, + message: LocalStateBrokerMessage, + ) -> anyhow::Result<()> { + match message { + LocalStateBrokerMessage::Set(id, state) => match self.ports.remove_entry(&id) { + Some((_, port)) => { + port.send(state)?; + } + None => { + self.states.insert(id, state); + } + }, + LocalStateBrokerMessage::Get(id, port) => match self.states.remove_entry(&id) { + Some((_, state)) => { + port.send(state)?; + } + None => { + self.ports.insert(id, port); + } + }, + } + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct BrokerId(String, usize); + +impl BrokerId { + pub fn new(broker_id: (String, usize)) -> Self { + BrokerId(broker_id.0, broker_id.1) + } + pub fn resolve(self, cx: &Context) -> Option> { + let actor_id = ActorId(cx.proc().proc_id().clone(), self.0, self.1); + let actor_ref: ActorRef = ActorRef::attest(actor_id); + actor_ref.downcast_handle(cx) + } +} diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 7e657b24..2442838d 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -208,6 +208,12 @@ impl From for PortId { } } +impl From for PyMailbox { + fn from(inner: Mailbox) -> Self { + PyMailbox { inner } + } +} + #[pymethods] impl PyPortId { #[new] diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index b60d2cfe..0192796f 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -402,32 +402,15 @@ pub struct CallFunctionParams { pub remote_process_groups: Vec, } -/// The local state that has to be restored into the python_message during -/// its unpickling. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub enum LocalState { - Ref(Ref), - Mailbox, -} - #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ActorCallParams { pub seq: Seq, - /// The actor to call is (proc_id of stream worker, 'actor', 'index') - pub actor: String, - pub index: usize, - - // method name to call - pub method: String, - // pickled arguments, that will need to be patched with local state - pub args_kwargs_tuple: Vec, - /// Referenceable objects to pass to the actor, + // The BrokerId but we do not depend on hyperactor in messages. + pub broker_id: (String, usize), + /// Referenceable objects to pass to the actor as LocalState, /// these will be put into the PythonMessage - /// during its unpickling. Because unpickling also needs to - /// restore mailboxes, we also have to keep track of which - /// members should just be the mailbox object. - pub local_state: Vec, - + /// during its unpickling. + pub local_state: Vec, /// Tensors that will be mutated by the call. pub mutates: Vec, pub stream: StreamRef, diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index d38e774e..e1c528ac 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -314,7 +314,7 @@ impl WorkerMessageHandler for WorkerActor { async fn send_result_of_actor_call( &mut self, cx: &hyperactor::Context, - params: ActorCallParams, + _params: ActorCallParams, ) -> Result<()> { bail!("unimplemented: send_result_of_actor_call"); } diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index c7031ae1..7a5305c5 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -36,10 +36,11 @@ use hyperactor::mailbox::Mailbox; use hyperactor::mailbox::OncePortHandle; use hyperactor::mailbox::PortReceiver; use hyperactor::proc::Proc; -use monarch_hyperactor::actor::LocalPythonMessage; -use monarch_hyperactor::actor::PythonActor; use monarch_hyperactor::actor::PythonMessage; use monarch_hyperactor::actor::PythonMessageKind; +use monarch_hyperactor::local_state_broker::BrokerId; +use monarch_hyperactor::local_state_broker::LocalState; +use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage; use monarch_hyperactor::mailbox::EitherPortRef; use monarch_messages::controller::ControllerMessageClient; use monarch_messages::controller::Seq; @@ -47,7 +48,6 @@ use monarch_messages::controller::WorkerError; use monarch_messages::worker::ActorCallParams; use monarch_messages::worker::CallFunctionError; use monarch_messages::worker::CallFunctionParams; -use monarch_messages::worker::LocalState; use monarch_messages::worker::StreamRef; use monarch_types::PyTree; use monarch_types::SerializablePyErr; @@ -1681,40 +1681,33 @@ impl StreamMessageHandler for StreamActor { params: ActorCallParams, ) -> anyhow::Result<()> { // TODO: handle mutates - let actor_id = ActorId(cx.proc().proc_id().clone(), params.actor, params.index); - let actor_ref: ActorRef = ActorRef::attest(actor_id); - let actor_handle = actor_ref.downcast_handle(cx).unwrap(); + let local_state: Result> = Python::with_gil(|py| { + params + .local_state + .into_iter() + .map(|elem| { + // SAFETY: python is gonna make unsafe copies of this stuff anyway + unsafe { + let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into(); + Ok(x) + } + }) + .collect() + }); + let (send, recv) = cx.open_once_port(); let send = send.bind(); let send = EitherPortRef::Once(send.into()); - let message = PythonMessage::new_from_buf( - PythonMessageKind::CallMethod { - name: params.method, - response_port: Some(send), - }, - params.args_kwargs_tuple.into(), - ); - let local_state: Result> = - Python::with_gil(|py| { - params - .local_state - .into_iter() - .map(|elem| match elem { - LocalState::Mailbox => Ok(monarch_hyperactor::actor::LocalState::Mailbox), - // SAFETY: python is gonna make unsafe copies of this stuff anyway - LocalState::Ref(r) => unsafe { - let x = self.ref_to_rvalue(&r)?.try_to_object_unsafe(py)?.into(); - Ok(monarch_hyperactor::actor::LocalState::PyObject(x)) - }, - }) - .collect() - }); - // including making the PyMailbox - let message = LocalPythonMessage { - message, - local_state: Some(local_state?), + + let state = LocalState { + response_port: send, + state: local_state?, }; - actor_handle.send(message)?; + let x: u64 = params.seq.into(); + let message = LocalStateBrokerMessage::Set(x as usize, state); + + let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap(); + broker.send(message)?; let result = recv.recv().await?; match result.kind { PythonMessageKind::Exception { .. } => { diff --git a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi index a256c556..e0ae3760 100644 --- a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi @@ -40,3 +40,6 @@ class _Controller: to any future. """ ... + + @property + def broker_id(self) -> Tuple[str, int]: ... diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi index c6cedd71..a306b9b4 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi @@ -7,8 +7,9 @@ # pyre-strict import abc +from enum import Enum -from typing import Any, final, List, Optional, Protocol, Type +from typing import Any, final, Iterable, List, Optional, Protocol, Tuple, Type from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, @@ -109,6 +110,9 @@ class PythonMessageKind: @classmethod @property def Uninit(cls) -> "Type[Uninit]": ... + @classmethod + @property + def CallMethodIndirect(cls) -> "Type[CallMethodIndirect]": ... class Result(PythonMessageKind): def __init__(self, rank: Optional[int]) -> None: ... @@ -129,6 +133,19 @@ class CallMethod(PythonMessageKind): @property def response_port(self) -> PortRef | OncePortRef | None: ... +class UnflattenArg(Enum): + Mailbox = 0 + PyObject = 1 + +class CallMethodIndirect(PythonMessageKind): + def __init__( + self, + name: str, + broker_id: Tuple[str, int], + id: int, + unflatten_args: List[UnflattenArg], + ) -> None: ... + class Init(PythonMessageKind): def __init__(self, response_port: PortRef | OncePortRef | None) -> None: ... @property @@ -210,5 +227,5 @@ class Actor(Protocol): shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None, + local_state: Iterable[Any], ) -> None: ... diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 198dfc95..35ef6a4f 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -424,7 +424,7 @@ def _send( self._actor_mesh.cast(message, selection) else: importlib.import_module("monarch." + "mesh_controller").actor_send( - self, self._name, bytes, refs, port + self, bytes, refs, port, selection ) shape = self._actor_mesh._shape return Extent(shape.labels, shape.ndslice.sizes) @@ -720,11 +720,8 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None, + local_state: Iterable[Any], ) -> None: - if local_state is None: - local_state = itertools.repeat(mailbox) - match message.kind: case PythonMessageKind.CallMethod(response_port=response_port): pass diff --git a/python/monarch/common/messages.py b/python/monarch/common/messages.py index ca68a9f2..5ebfbebc 100644 --- a/python/monarch/common/messages.py +++ b/python/monarch/common/messages.py @@ -17,6 +17,7 @@ NamedTuple, Optional, Protocol, + Sequence, Tuple, TYPE_CHECKING, ) @@ -428,11 +429,8 @@ def to_rust_message(self) -> tensor_worker.WorkerMessage: class SendResultOfActorCall(NamedTuple): seq: int - actor: str - actor_index: int - method: str - args_kwargs_tuple: bytes - local_state: List[Referenceable | Mailbox] + broker_id: Tuple[str, int] + local_state: Sequence[Tensor | tensor_worker.Ref] mutates: List[tensor_worker.Ref] stream: tensor_worker.StreamRef diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index c69322f6..6545a926 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -30,11 +30,16 @@ WorldState, ) from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller +from monarch._rust_bindings.monarch_hyperactor.actor import ( + PythonMessage, + PythonMessageKind, + UnflattenArg, +) from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) -from monarch._src.actor.actor_mesh import ActorMeshRef, Port, PortTuple +from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple, Selection from monarch._src.actor.shape import NDSlice from monarch.common import device_mesh, messages, stream from monarch.common.controller_api import TController @@ -264,12 +269,16 @@ def __str__(self): def actor_send( - actor_mesh: ActorMeshRef, - method: str, + endpoint: ActorEndpoint, args_kwargs_tuple: bytes, refs: Sequence[Any], port: Optional[Port[Any]], + selection: Selection, ): + unflatten_args = [ + UnflattenArg.PyObject if isinstance(ref, Tensor) else UnflattenArg.Mailbox + for ref in refs + ] tensors = [ref for ref in refs if isinstance(ref, Tensor)] # we have some monarch references, we need to ensure their # proc_mesh matches that of the tensors we sent to it @@ -284,8 +293,7 @@ def actor_send( # mutates checker.check_permission(()) selected_device_mesh = ( - actor_mesh._actor_mesh._proc_mesh - and actor_mesh._actor_mesh._proc_mesh._device_mesh + endpoint._actor_mesh._proc_mesh and endpoint._actor_mesh._proc_mesh._device_mesh ) if selected_device_mesh is not checker.mesh: raise ValueError( @@ -293,26 +301,33 @@ def actor_send( "NYI: better serialization of mesh names to make the mismatch more clear." ) - client = checker.mesh.client + client = cast(MeshClient, checker.mesh.client) + + broker_id: Tuple[str, int] = client._mesh_controller.broker_id + stream_ref = chosen_stream._to_ref(client) fut = (port, checker.mesh._ndslice) if port is not None else None ident = client.new_node([], tensors, cast("OldFuture", fut)) - actor_name, actor_index = actor_mesh._actor_mesh._name_pid - msg = SendResultOfActorCall( - ident, - actor_name, - actor_index, - method, + # To ensure that both the actor and the stream execute in order, we send a message + # to each at this point. The message to the worker will be handled on the stream actor where + # it will send the 'tensor's to the broker actor locally, along with a response port with the + # computed value. + + # The message to the generic actor tells it to first wait on the broker to get the local arguments + # from the stream, then it will run the actor method, and send the result to response port. + + actor_msg = PythonMessage( + PythonMessageKind.CallMethodIndirect( + endpoint._name, broker_id, ident, unflatten_args + ), args_kwargs_tuple, - cast("List[Mailbox | Referenceable]", refs), - [], - stream_ref, ) - - client.send(checker.mesh._ndslice, msg) + endpoint._actor_mesh.cast(actor_msg, selection) + worker_msg = SendResultOfActorCall(ident, broker_id, tensors, [], stream_ref) + client.send(checker.mesh._ndslice, worker_msg) # we have to ask for status updates # from workers to be sure they have finished # enough work to count this future as finished, diff --git a/python/tests/_monarch/test_hyperactor.py b/python/tests/_monarch/test_hyperactor.py index 74eb262e..f8e2448e 100644 --- a/python/tests/_monarch/test_hyperactor.py +++ b/python/tests/_monarch/test_hyperactor.py @@ -10,7 +10,7 @@ import os import signal import time -from typing import Any, List +from typing import Any, Iterable import monarch @@ -34,7 +34,7 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None, + local_state: Iterable[Any], ) -> None: raise NotImplementedError() diff --git a/python/tests/_monarch/test_mailbox.py b/python/tests/_monarch/test_mailbox.py index cd1468f5..867d99b3 100644 --- a/python/tests/_monarch/test_mailbox.py +++ b/python/tests/_monarch/test_mailbox.py @@ -8,7 +8,7 @@ import asyncio import pickle -from typing import Any, Callable, cast, final, Generic, List, TYPE_CHECKING, TypeVar +from typing import Any, Callable, cast, final, Generic, Iterable, TYPE_CHECKING, TypeVar import monarch @@ -137,7 +137,7 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None, + local_state: Iterable[Any], ) -> None: call_method = cast("CallMethod", message.kind) assert call_method.response_port is not None diff --git a/python/tests/test_tensor_engine.py b/python/tests/test_tensor_engine.py index a79d5e8d..fe3dcc14 100644 --- a/python/tests/test_tensor_engine.py +++ b/python/tests/test_tensor_engine.py @@ -78,3 +78,29 @@ def test_actor_with_tensors() -> None: x = pm.spawn("adder", AddWithState, torch.ones(())).get() y = torch.ones(()) assert x.forward.call(y).get(timeout=5).item(hosts=0, gpus=0).item() == 2 + + +class Counter(Actor): + def __init__(self): + super().__init__() + self.c = 0 + + @endpoint + def incr(self, x) -> int: + self.c += 1 + return self.c - 1 + + +@two_gpu +def test_actor_tensor_ordering() -> None: + pm = proc_mesh(gpus=1).get() + with pm.activate(): + counter = pm.spawn("a", Counter).get() + results = [] + for _ in range(0, 10, 2): + # tensor engine call + results.append(counter.incr.call(torch.ones(()))) + # non-tensor engine call + results.append(counter.incr.call(1)) + + assert list(range(10)) == [r.get().item(hosts=0, gpus=0) for r in results]