diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index 5af8b32d..53b069a7 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -176,35 +176,6 @@ impl<'a> MessageParser<'a> { fn parseWorkerMessageList(&self, name: &str) -> PyResult> { self.attr(name)?.try_iter()?.map(|x| convert(x?)).collect() } - #[allow(non_snake_case)] - fn parseLocalStateList(&self, name: &str) -> PyResult> { - let local_state_list = self.attr(name)?; - let mut result = Vec::new(); - - // Get the PyMailbox class for type checking - let mailbox_class = self - .current - .py() - .import("monarch._rust_bindings.monarch_hyperactor.mailbox")? - .getattr("Mailbox")?; - - for item in local_state_list.try_iter()? { - let item = item?; - // Check if it's a Ref by trying to extract it - if let Ok(ref_val) = create_ref(item.clone()) { - result.push(worker::LocalState::Ref(ref_val)); - } else if item.is_instance(&mailbox_class)? { - // It's a PyMailbox instance - result.push(worker::LocalState::Mailbox); - } else { - return Err(PyValueError::new_err(format!( - "Expected Ref or Mailbox in local_state, got: {}", - item.get_type().name()? - ))); - } - } - Ok(result) - } fn parse_error_reason(&self, name: &str) -> PyResult, String)>> { let err = self.attr(name)?; if err.is_none() { @@ -465,11 +436,8 @@ fn create_map(py: Python) -> HashMap { Ok(WorkerMessage::SendResultOfActorCall( worker::ActorCallParams { seq: p.parseSeq("seq")?, - actor: p.parse("actor")?, - index: p.parse("actor_index")?, - method: p.parse("method")?, - args_kwargs_tuple: p.parse("args_kwargs_tuple")?, - local_state: p.parseLocalStateList("local_state")?, + broker_id: p.parse("broker_id")?, + local_state: p.parseRefList("local_state")?, mutates: p.parseRefList("mutates")?, stream: p.parseStreamRef("stream")?, }, diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index d53aff97..3f3bbed2 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -38,6 +38,7 @@ use hyperactor_mesh::shared_cell::SharedCell; use hyperactor_mesh::shared_cell::SharedCellRef; use monarch_hyperactor::actor::PythonMessage; use monarch_hyperactor::actor::PythonMessageKind; +use monarch_hyperactor::local_state_broker::LocalStateBrokerActor; use monarch_hyperactor::mailbox::PyPortId; use monarch_hyperactor::ndslice::PySlice; use monarch_hyperactor::proc_mesh::PyProcMesh; @@ -78,6 +79,7 @@ pub(crate) fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult struct _Controller { controller_handle: Arc>>, all_ranks: Slice, + broker_id: (String, usize), } static NEXT_ID: AtomicUsize = AtomicUsize::new(0); @@ -123,9 +125,17 @@ impl _Controller { Ok(Self { controller_handle, all_ranks, + // note that 0 is the _pid_ of the broker, which will be 0 for + // top-level spawned actors. + broker_id: (format!("tensor_engine_brokers_{}", id), 0), }) } + #[getter] + fn broker_id(&self) -> (String, usize) { + self.broker_id.clone() + } + #[pyo3(signature = (seq, defs, uses, response_port, tracebacks))] fn node<'py>( &mut self, @@ -626,6 +636,7 @@ enum ClientToControllerMessage { struct MeshControllerActor { proc_mesh: SharedCell, workers: Option>>, + brokers: Option>>, history: History, id: usize, debugger_active: Option>, @@ -730,6 +741,7 @@ impl Actor for MeshControllerActor { Ok(MeshControllerActor { proc_mesh: proc_mesh.clone(), workers: None, + brokers: None, history: History::new(world_size), id, debugger_active: None, @@ -758,6 +770,10 @@ impl Actor for MeshControllerActor { .cast(selection::dsl::true_(), AssignRankMessage::AssignRank())?; self.workers = Some(workers); + let brokers = proc_mesh + .spawn(&format!("tensor_engine_brokers_{}", self.id), &()) + .await?; + self.brokers = Some(brokers); Ok(()) } } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 84069e36..8d5192be 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -21,6 +21,7 @@ use hyperactor::Context; use hyperactor::HandleClient; use hyperactor::Handler; use hyperactor::Named; +use hyperactor::cap::CanSend; use hyperactor::forward; use hyperactor::message::Bind; use hyperactor::message::Bindings; @@ -43,6 +44,8 @@ use tokio::sync::Mutex; use tokio::sync::oneshot; use crate::config::SHARED_ASYNCIO_RUNTIME; +use crate::local_state_broker::BrokerId; +use crate::local_state_broker::LocalStateBrokerMessage; use crate::mailbox::EitherPortRef; use crate::mailbox::PyMailbox; use crate::proc::InstanceWrapper; @@ -171,6 +174,13 @@ impl PickledMessageClientActor { } } +#[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")] +#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] +pub enum UnflattenArg { + Mailbox, + PyObject, +} + #[pyclass(module = "monarch._rust_bindings.monarch_hyperactor.actor")] #[derive(Clone, Debug, Serialize, Deserialize, Named, PartialEq)] pub enum PythonMessageKind { @@ -185,6 +195,14 @@ pub enum PythonMessageKind { rank: Option, }, Uninit {}, + CallMethodIndirect { + name: String, + local_state_broker: (String, usize), + id: usize, + // specify whether the argument to unflatten the local mailbox, + // or the next argument of the local state. + unflatten_args: Vec, + }, } impl Default for PythonMessageKind { @@ -193,6 +211,11 @@ impl Default for PythonMessageKind { } } +fn mailbox<'py, T: Actor>(py: Python<'py>, cx: &Context<'_, T>) -> Bound<'py, PyAny> { + let mailbox: PyMailbox = cx.mailbox_for_py().clone().into(); + mailbox.into_bound_py_any(py).unwrap() +} + #[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")] #[derive(Clone, Serialize, Deserialize, Named, PartialEq, Default)] pub struct PythonMessage { @@ -219,6 +242,56 @@ impl PythonMessage { _ => panic!("PythonMessage is not a response but {:?}", self), } } + + pub async fn resolve_indirect_call( + mut self, + cx: &Context<'_, T>, + ) -> anyhow::Result<(Self, PyObject)> { + let local_state: PyObject; + match self.kind { + PythonMessageKind::CallMethodIndirect { + name, + local_state_broker, + id, + unflatten_args, + } => { + let broker = BrokerId::new(local_state_broker).resolve(cx).unwrap(); + let (send, recv) = cx.open_once_port(); + broker.send(LocalStateBrokerMessage::Get(id, send))?; + let state = recv.recv().await?; + let mut state_it = state.state.into_iter(); + local_state = Python::with_gil(|py| { + let mailbox = mailbox(py, cx); + PyList::new( + py, + unflatten_args.into_iter().map(|x| -> Bound<'_, PyAny> { + match x { + UnflattenArg::Mailbox => mailbox.clone(), + UnflattenArg::PyObject => state_it.next().unwrap().into_bound(py), + } + }), + ) + .unwrap() + .into() + }); + self.kind = PythonMessageKind::CallMethod { + name, + response_port: Some(state.response_port), + } + } + _ => { + local_state = Python::with_gil(|py| { + let mailbox = mailbox(py, cx); + py.import("itertools") + .unwrap() + .call_method1("repeat", (mailbox.clone(),)) + .unwrap() + .unbind() + }); + } + }; + Ok((self, local_state)) + } } impl std::fmt::Debug for PythonMessage { @@ -408,30 +481,16 @@ impl PanicFlag { } #[async_trait] -impl Handler for PythonActor { - async fn handle( - &mut self, - cx: &Context, - message: LocalPythonMessage, - ) -> anyhow::Result<()> { - let mailbox: PyMailbox = PyMailbox { - inner: cx.mailbox_for_py().clone(), - }; +impl Handler for PythonActor { + async fn handle(&mut self, cx: &Context, message: PythonMessage) -> anyhow::Result<()> { + let (message, local_state) = message.resolve_indirect_call(cx).await?; + // Create a channel for signaling panics in async endpoints. // See [Panics in async endpoints]. let (sender, receiver) = oneshot::channel(); let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> { - let mailbox = mailbox.into_bound_py_any(py).unwrap(); - let local_state: Option>> = message.local_state.map(|state| { - state - .into_iter() - .map(|e| match e { - LocalState::Mailbox => mailbox.clone(), - LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(), - }) - .collect() - }); + let mailbox = mailbox(py, cx); let (rank, shape) = cx.cast_info(); let awaitable = self.actor.call_method( py, @@ -440,7 +499,7 @@ impl Handler for PythonActor { mailbox, rank, PyShape::from(shape), - message.message, + message, PanicFlag { sender: Some(sender), }, @@ -463,20 +522,6 @@ impl Handler for PythonActor { } } -#[async_trait] -impl Handler for PythonActor { - async fn handle(&mut self, cx: &Context, message: PythonMessage) -> anyhow::Result<()> { - self.handle( - cx, - LocalPythonMessage { - message, - local_state: None, - }, - ) - .await - } -} - /// Helper struct to make a Python future passable in an actor message. /// /// Also so that we don't have to write this massive type signature everywhere @@ -576,17 +621,6 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask { Ok(()) } } -#[derive(Debug)] -pub enum LocalState { - Mailbox, - PyObject(PyObject), -} - -#[derive(Debug)] -pub struct LocalPythonMessage { - pub message: PythonMessage, - pub local_state: Option>, -} pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; @@ -594,6 +628,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; + hyperactor_mod.add_class::()?; hyperactor_mod.add_class::()?; Ok(()) } diff --git a/monarch_hyperactor/src/lib.rs b/monarch_hyperactor/src/lib.rs index 175e10bd..2d4971c8 100644 --- a/monarch_hyperactor/src/lib.rs +++ b/monarch_hyperactor/src/lib.rs @@ -14,6 +14,7 @@ pub mod alloc; pub mod bootstrap; pub mod channel; pub mod config; +pub mod local_state_broker; pub mod mailbox; pub mod ndslice; pub mod proc; diff --git a/monarch_hyperactor/src/local_state_broker.rs b/monarch_hyperactor/src/local_state_broker.rs new file mode 100644 index 00000000..32d2c500 --- /dev/null +++ b/monarch_hyperactor/src/local_state_broker.rs @@ -0,0 +1,94 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use std::collections::HashMap; + +use async_trait::async_trait; +use hyperactor::Actor; +use hyperactor::ActorHandle; +use hyperactor::ActorId; +use hyperactor::ActorRef; +use hyperactor::Context; +use hyperactor::Handler; +use hyperactor::OncePortHandle; +use pyo3::prelude::*; + +use crate::mailbox::EitherPortRef; + +#[derive(Debug)] +pub struct LocalState { + pub response_port: EitherPortRef, + pub state: Vec, +} + +#[derive(Debug)] +pub enum LocalStateBrokerMessage { + Set(usize, LocalState), + Get(usize, OncePortHandle), +} + +#[derive(Debug)] +#[hyperactor::export(spawn = true)] +pub struct LocalStateBrokerActor { + states: HashMap, + ports: HashMap>, +} + +#[async_trait] +impl Actor for LocalStateBrokerActor { + type Params = (); + async fn new(_params: Self::Params) -> anyhow::Result { + Ok(Self { + states: HashMap::new(), + ports: HashMap::new(), + }) + } +} + +#[async_trait] +impl Handler for LocalStateBrokerActor { + async fn handle( + &mut self, + cx: &Context, + message: LocalStateBrokerMessage, + ) -> anyhow::Result<()> { + match message { + LocalStateBrokerMessage::Set(id, state) => match self.ports.remove_entry(&id) { + Some((_, port)) => { + port.send(state)?; + } + None => { + self.states.insert(id, state); + } + }, + LocalStateBrokerMessage::Get(id, port) => match self.states.remove_entry(&id) { + Some((_, state)) => { + port.send(state)?; + } + None => { + self.ports.insert(id, port); + } + }, + } + Ok(()) + } +} + +#[derive(Debug, Clone)] +pub struct BrokerId(String, usize); + +impl BrokerId { + pub fn new(broker_id: (String, usize)) -> Self { + BrokerId(broker_id.0, broker_id.1) + } + pub fn resolve(self, cx: &Context) -> Option> { + let actor_id = ActorId(cx.proc().proc_id().clone(), self.0, self.1); + let actor_ref: ActorRef = ActorRef::attest(actor_id); + actor_ref.downcast_handle(cx) + } +} diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 6e765403..425301e8 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -214,6 +214,12 @@ impl From for PortId { } } +impl From for PyMailbox { + fn from(inner: Mailbox) -> Self { + PyMailbox { inner } + } +} + #[pymethods] impl PyPortId { #[new] diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index b60d2cfe..0192796f 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -402,32 +402,15 @@ pub struct CallFunctionParams { pub remote_process_groups: Vec, } -/// The local state that has to be restored into the python_message during -/// its unpickling. -#[derive(Serialize, Deserialize, Debug, Clone)] -pub enum LocalState { - Ref(Ref), - Mailbox, -} - #[derive(Serialize, Deserialize, Debug, Clone)] pub struct ActorCallParams { pub seq: Seq, - /// The actor to call is (proc_id of stream worker, 'actor', 'index') - pub actor: String, - pub index: usize, - - // method name to call - pub method: String, - // pickled arguments, that will need to be patched with local state - pub args_kwargs_tuple: Vec, - /// Referenceable objects to pass to the actor, + // The BrokerId but we do not depend on hyperactor in messages. + pub broker_id: (String, usize), + /// Referenceable objects to pass to the actor as LocalState, /// these will be put into the PythonMessage - /// during its unpickling. Because unpickling also needs to - /// restore mailboxes, we also have to keep track of which - /// members should just be the mailbox object. - pub local_state: Vec, - + /// during its unpickling. + pub local_state: Vec, /// Tensors that will be mutated by the call. pub mutates: Vec, pub stream: StreamRef, diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index d38e774e..e1c528ac 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -314,7 +314,7 @@ impl WorkerMessageHandler for WorkerActor { async fn send_result_of_actor_call( &mut self, cx: &hyperactor::Context, - params: ActorCallParams, + _params: ActorCallParams, ) -> Result<()> { bail!("unimplemented: send_result_of_actor_call"); } diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index c7031ae1..7a5305c5 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -36,10 +36,11 @@ use hyperactor::mailbox::Mailbox; use hyperactor::mailbox::OncePortHandle; use hyperactor::mailbox::PortReceiver; use hyperactor::proc::Proc; -use monarch_hyperactor::actor::LocalPythonMessage; -use monarch_hyperactor::actor::PythonActor; use monarch_hyperactor::actor::PythonMessage; use monarch_hyperactor::actor::PythonMessageKind; +use monarch_hyperactor::local_state_broker::BrokerId; +use monarch_hyperactor::local_state_broker::LocalState; +use monarch_hyperactor::local_state_broker::LocalStateBrokerMessage; use monarch_hyperactor::mailbox::EitherPortRef; use monarch_messages::controller::ControllerMessageClient; use monarch_messages::controller::Seq; @@ -47,7 +48,6 @@ use monarch_messages::controller::WorkerError; use monarch_messages::worker::ActorCallParams; use monarch_messages::worker::CallFunctionError; use monarch_messages::worker::CallFunctionParams; -use monarch_messages::worker::LocalState; use monarch_messages::worker::StreamRef; use monarch_types::PyTree; use monarch_types::SerializablePyErr; @@ -1681,40 +1681,33 @@ impl StreamMessageHandler for StreamActor { params: ActorCallParams, ) -> anyhow::Result<()> { // TODO: handle mutates - let actor_id = ActorId(cx.proc().proc_id().clone(), params.actor, params.index); - let actor_ref: ActorRef = ActorRef::attest(actor_id); - let actor_handle = actor_ref.downcast_handle(cx).unwrap(); + let local_state: Result> = Python::with_gil(|py| { + params + .local_state + .into_iter() + .map(|elem| { + // SAFETY: python is gonna make unsafe copies of this stuff anyway + unsafe { + let x = self.ref_to_rvalue(&elem)?.try_to_object_unsafe(py)?.into(); + Ok(x) + } + }) + .collect() + }); + let (send, recv) = cx.open_once_port(); let send = send.bind(); let send = EitherPortRef::Once(send.into()); - let message = PythonMessage::new_from_buf( - PythonMessageKind::CallMethod { - name: params.method, - response_port: Some(send), - }, - params.args_kwargs_tuple.into(), - ); - let local_state: Result> = - Python::with_gil(|py| { - params - .local_state - .into_iter() - .map(|elem| match elem { - LocalState::Mailbox => Ok(monarch_hyperactor::actor::LocalState::Mailbox), - // SAFETY: python is gonna make unsafe copies of this stuff anyway - LocalState::Ref(r) => unsafe { - let x = self.ref_to_rvalue(&r)?.try_to_object_unsafe(py)?.into(); - Ok(monarch_hyperactor::actor::LocalState::PyObject(x)) - }, - }) - .collect() - }); - // including making the PyMailbox - let message = LocalPythonMessage { - message, - local_state: Some(local_state?), + + let state = LocalState { + response_port: send, + state: local_state?, }; - actor_handle.send(message)?; + let x: u64 = params.seq.into(); + let message = LocalStateBrokerMessage::Set(x as usize, state); + + let broker = BrokerId::new(params.broker_id).resolve(cx).unwrap(); + broker.send(message)?; let result = recv.recv().await?; match result.kind { PythonMessageKind::Exception { .. } => { diff --git a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi index a256c556..e0ae3760 100644 --- a/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi +++ b/python/monarch/_rust_bindings/monarch_extension/mesh_controller.pyi @@ -40,3 +40,6 @@ class _Controller: to any future. """ ... + + @property + def broker_id(self) -> Tuple[str, int]: ... diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi index c6cedd71..a306b9b4 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/actor.pyi @@ -7,8 +7,9 @@ # pyre-strict import abc +from enum import Enum -from typing import Any, final, List, Optional, Protocol, Type +from typing import Any, final, Iterable, List, Optional, Protocol, Tuple, Type from monarch._rust_bindings.monarch_hyperactor.mailbox import ( Mailbox, @@ -109,6 +110,9 @@ class PythonMessageKind: @classmethod @property def Uninit(cls) -> "Type[Uninit]": ... + @classmethod + @property + def CallMethodIndirect(cls) -> "Type[CallMethodIndirect]": ... class Result(PythonMessageKind): def __init__(self, rank: Optional[int]) -> None: ... @@ -129,6 +133,19 @@ class CallMethod(PythonMessageKind): @property def response_port(self) -> PortRef | OncePortRef | None: ... +class UnflattenArg(Enum): + Mailbox = 0 + PyObject = 1 + +class CallMethodIndirect(PythonMessageKind): + def __init__( + self, + name: str, + broker_id: Tuple[str, int], + id: int, + unflatten_args: List[UnflattenArg], + ) -> None: ... + class Init(PythonMessageKind): def __init__(self, response_port: PortRef | OncePortRef | None) -> None: ... @property @@ -210,5 +227,5 @@ class Actor(Protocol): shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None, + local_state: Iterable[Any], ) -> None: ... diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index 198dfc95..35ef6a4f 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -424,7 +424,7 @@ def _send( self._actor_mesh.cast(message, selection) else: importlib.import_module("monarch." + "mesh_controller").actor_send( - self, self._name, bytes, refs, port + self, bytes, refs, port, selection ) shape = self._actor_mesh._shape return Extent(shape.labels, shape.ndslice.sizes) @@ -720,11 +720,8 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None, + local_state: Iterable[Any], ) -> None: - if local_state is None: - local_state = itertools.repeat(mailbox) - match message.kind: case PythonMessageKind.CallMethod(response_port=response_port): pass diff --git a/python/monarch/common/messages.py b/python/monarch/common/messages.py index ca68a9f2..5ebfbebc 100644 --- a/python/monarch/common/messages.py +++ b/python/monarch/common/messages.py @@ -17,6 +17,7 @@ NamedTuple, Optional, Protocol, + Sequence, Tuple, TYPE_CHECKING, ) @@ -428,11 +429,8 @@ def to_rust_message(self) -> tensor_worker.WorkerMessage: class SendResultOfActorCall(NamedTuple): seq: int - actor: str - actor_index: int - method: str - args_kwargs_tuple: bytes - local_state: List[Referenceable | Mailbox] + broker_id: Tuple[str, int] + local_state: Sequence[Tensor | tensor_worker.Ref] mutates: List[tensor_worker.Ref] stream: tensor_worker.StreamRef diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index c69322f6..6545a926 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -30,11 +30,16 @@ WorldState, ) from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller +from monarch._rust_bindings.monarch_hyperactor.actor import ( + PythonMessage, + PythonMessageKind, + UnflattenArg, +) from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) -from monarch._src.actor.actor_mesh import ActorMeshRef, Port, PortTuple +from monarch._src.actor.actor_mesh import ActorEndpoint, Port, PortTuple, Selection from monarch._src.actor.shape import NDSlice from monarch.common import device_mesh, messages, stream from monarch.common.controller_api import TController @@ -264,12 +269,16 @@ def __str__(self): def actor_send( - actor_mesh: ActorMeshRef, - method: str, + endpoint: ActorEndpoint, args_kwargs_tuple: bytes, refs: Sequence[Any], port: Optional[Port[Any]], + selection: Selection, ): + unflatten_args = [ + UnflattenArg.PyObject if isinstance(ref, Tensor) else UnflattenArg.Mailbox + for ref in refs + ] tensors = [ref for ref in refs if isinstance(ref, Tensor)] # we have some monarch references, we need to ensure their # proc_mesh matches that of the tensors we sent to it @@ -284,8 +293,7 @@ def actor_send( # mutates checker.check_permission(()) selected_device_mesh = ( - actor_mesh._actor_mesh._proc_mesh - and actor_mesh._actor_mesh._proc_mesh._device_mesh + endpoint._actor_mesh._proc_mesh and endpoint._actor_mesh._proc_mesh._device_mesh ) if selected_device_mesh is not checker.mesh: raise ValueError( @@ -293,26 +301,33 @@ def actor_send( "NYI: better serialization of mesh names to make the mismatch more clear." ) - client = checker.mesh.client + client = cast(MeshClient, checker.mesh.client) + + broker_id: Tuple[str, int] = client._mesh_controller.broker_id + stream_ref = chosen_stream._to_ref(client) fut = (port, checker.mesh._ndslice) if port is not None else None ident = client.new_node([], tensors, cast("OldFuture", fut)) - actor_name, actor_index = actor_mesh._actor_mesh._name_pid - msg = SendResultOfActorCall( - ident, - actor_name, - actor_index, - method, + # To ensure that both the actor and the stream execute in order, we send a message + # to each at this point. The message to the worker will be handled on the stream actor where + # it will send the 'tensor's to the broker actor locally, along with a response port with the + # computed value. + + # The message to the generic actor tells it to first wait on the broker to get the local arguments + # from the stream, then it will run the actor method, and send the result to response port. + + actor_msg = PythonMessage( + PythonMessageKind.CallMethodIndirect( + endpoint._name, broker_id, ident, unflatten_args + ), args_kwargs_tuple, - cast("List[Mailbox | Referenceable]", refs), - [], - stream_ref, ) - - client.send(checker.mesh._ndslice, msg) + endpoint._actor_mesh.cast(actor_msg, selection) + worker_msg = SendResultOfActorCall(ident, broker_id, tensors, [], stream_ref) + client.send(checker.mesh._ndslice, worker_msg) # we have to ask for status updates # from workers to be sure they have finished # enough work to count this future as finished, diff --git a/python/tests/_monarch/test_actor_mesh.py b/python/tests/_monarch/test_actor_mesh.py index 874cf3c8..2ff1ce4e 100644 --- a/python/tests/_monarch/test_actor_mesh.py +++ b/python/tests/_monarch/test_actor_mesh.py @@ -7,7 +7,7 @@ # pyre-unsafe import pickle -from typing import Any, List +from typing import Any, Iterable, List import monarch import pytest @@ -49,7 +49,7 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None = None, + local_state: Iterable[Any] | None = None, ) -> None: assert rank is not None diff --git a/python/tests/_monarch/test_hyperactor.py b/python/tests/_monarch/test_hyperactor.py index 74eb262e..f8e2448e 100644 --- a/python/tests/_monarch/test_hyperactor.py +++ b/python/tests/_monarch/test_hyperactor.py @@ -10,7 +10,7 @@ import os import signal import time -from typing import Any, List +from typing import Any, Iterable import monarch @@ -34,7 +34,7 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None, + local_state: Iterable[Any], ) -> None: raise NotImplementedError() diff --git a/python/tests/_monarch/test_mailbox.py b/python/tests/_monarch/test_mailbox.py index cd1468f5..867d99b3 100644 --- a/python/tests/_monarch/test_mailbox.py +++ b/python/tests/_monarch/test_mailbox.py @@ -8,7 +8,7 @@ import asyncio import pickle -from typing import Any, Callable, cast, final, Generic, List, TYPE_CHECKING, TypeVar +from typing import Any, Callable, cast, final, Generic, Iterable, TYPE_CHECKING, TypeVar import monarch @@ -137,7 +137,7 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, - local_state: List[Any] | None, + local_state: Iterable[Any], ) -> None: call_method = cast("CallMethod", message.kind) assert call_method.response_port is not None diff --git a/python/tests/test_tensor_engine.py b/python/tests/test_tensor_engine.py index a79d5e8d..fe3dcc14 100644 --- a/python/tests/test_tensor_engine.py +++ b/python/tests/test_tensor_engine.py @@ -78,3 +78,29 @@ def test_actor_with_tensors() -> None: x = pm.spawn("adder", AddWithState, torch.ones(())).get() y = torch.ones(()) assert x.forward.call(y).get(timeout=5).item(hosts=0, gpus=0).item() == 2 + + +class Counter(Actor): + def __init__(self): + super().__init__() + self.c = 0 + + @endpoint + def incr(self, x) -> int: + self.c += 1 + return self.c - 1 + + +@two_gpu +def test_actor_tensor_ordering() -> None: + pm = proc_mesh(gpus=1).get() + with pm.activate(): + counter = pm.spawn("a", Counter).get() + results = [] + for _ in range(0, 10, 2): + # tensor engine call + results.append(counter.incr.call(torch.ones(()))) + # non-tensor engine call + results.append(counter.incr.call(1)) + + assert list(range(10)) == [r.get().item(hosts=0, gpus=0) for r in results]