From b0223655f0e44321c06fb3a4b459be3c7d5737c7 Mon Sep 17 00:00:00 2001 From: zdevito Date: Fri, 11 Jul 2025 16:06:00 -0700 Subject: [PATCH 1/2] [15/n] Pass monarch tensors to actor endpoints, part 1 This makes it possible to send a monarch tensor to an actor endpoint that is defiend over the same proc mesh as the tensor. The send is done locally so the actor can do work with the tensors. The stream actor in the tensor engine is sent a SendResultOfActorCall message, which it will forward to the actor, binding to the message the real local tensors that were passed as arguments. The stream actor that owns the tensor waits for called the actor to finish since the actor 'owns' the tensor through the duration of the call. Known limitation: this message to the actor can go out of order w.r.t to other messages sent from the owner of the tensor engine because the real message is being sent from the stream actor. The next PR will fix this limitation by sending _both_ the tensor engine and the actor a message at the same time. The actor will get a 'wait for SendResultOfActorCall' message, at which point it will stop processing any messages except for the SentResultOfActorCall message it is suppose to be waiting for. This way the correct order is preserved from the perspective of the tensor engine stream and the actor. Differential Revision: [D78196701](https://our.internmc.facebook.com/intern/diff/D78196701/) [ghstack-poisoned] --- monarch_extension/src/convert.rs | 44 +++++++++++++++ monarch_extension/src/tensor_worker.rs | 1 + monarch_hyperactor/src/actor.rs | 58 ++++++++++++++++--- monarch_hyperactor/src/mailbox.rs | 4 +- monarch_messages/Cargo.toml | 1 + monarch_messages/src/worker.rs | 46 +++++++++++++-- monarch_simulator/src/worker.rs | 8 +++ monarch_tensor_worker/src/lib.rs | 12 ++++ monarch_tensor_worker/src/stream.rs | 75 +++++++++++++++++++++++++ python/monarch/_src/actor/actor_mesh.py | 52 ++++++++++++++--- python/monarch/_src/actor/pickle.py | 30 ++-------- python/monarch/_src/actor/proc_mesh.py | 4 +- python/monarch/common/messages.py | 13 +++++ python/monarch/mesh_controller.py | 65 +++++++++++++++++++-- python/tests/test_tensor_engine.py | 21 ++++++- 15 files changed, 379 insertions(+), 55 deletions(-) diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index afda10ad..5aeb7075 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -176,6 +176,35 @@ impl<'a> MessageParser<'a> { fn parseWorkerMessageList(&self, name: &str) -> PyResult> { self.attr(name)?.try_iter()?.map(|x| convert(x?)).collect() } + #[allow(non_snake_case)] + fn parseLocalStateList(&self, name: &str) -> PyResult> { + let local_state_list = self.attr(name)?; + let mut result = Vec::new(); + + // Get the PyMailbox class for type checking + let mailbox_class = self + .current + .py() + .import("monarch._rust_bindings.monarch_hyperactor.mailbox")? + .getattr("Mailbox")?; + + for item in local_state_list.try_iter()? { + let item = item?; + // Check if it's a Ref by trying to extract it + if let Ok(ref_val) = create_ref(item.clone()) { + result.push(worker::LocalState::Ref(ref_val)); + } else if item.is_instance(&mailbox_class)? { + // It's a PyMailbox instance + result.push(worker::LocalState::Mailbox); + } else { + return Err(PyValueError::new_err(format!( + "Expected Ref or Mailbox in local_state, got: {}", + item.get_type().name()? + ))); + } + } + Ok(result) + } fn parse_error_reason(&self, name: &str) -> PyResult, String)>> { let err = self.attr(name)?; if err.is_none() { @@ -432,6 +461,21 @@ fn create_map(py: Python) -> HashMap { p.parseWorkerMessageList("commands")?, )) }); + // HELP DEVMATE! + // Add a cast for SendResultOfActorCall + m.insert(key("SendResultOfActorCall"), |p| { + Ok(WorkerMessage::SendResultOfActorCall( + worker::ActorCallParams { + seq: p.parseSeq("seq")?, + actor: p.parse("actor")?, + index: p.parse("actor_index")?, + python_message: p.parse("python_message")?, + local_state: p.parseLocalStateList("local_state")?, + mutates: p.parseRefList("mutates")?, + stream: p.parseStreamRef("stream")?, + }, + )) + }); m } diff --git a/monarch_extension/src/tensor_worker.rs b/monarch_extension/src/tensor_worker.rs index 31a9ac4d..b11a3c67 100644 --- a/monarch_extension/src/tensor_worker.rs +++ b/monarch_extension/src/tensor_worker.rs @@ -1372,6 +1372,7 @@ pub(crate) fn worker_message_to_py(py: Python<'_>, message: &WorkerMessage) -> P } WorkerMessage::SetRefUnitTestsOnly { .. } => unimplemented!(), WorkerMessage::GetRefUnitTestsOnly { .. } => unimplemented!(), + WorkerMessage::SendResultOfActorCall { .. } => unimplemented!(), } } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 3a50f95a..1548d25b 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -175,10 +175,10 @@ impl PickledMessageClientActor { #[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")] #[derive(Default, Clone, Serialize, Deserialize, Named, PartialEq)] pub struct PythonMessage { - pub(crate) method: String, - pub(crate) message: ByteBuf, - response_port: Option, - rank: Option, + pub method: String, + pub message: ByteBuf, + pub response_port: Option, + pub rank: Option, } impl PythonMessage { @@ -289,7 +289,7 @@ impl PythonActorHandle { PythonMessage { cast = true }, ], )] -pub(super) struct PythonActor { +pub struct PythonActor { /// The Python object that we delegate message handling to. An instance of /// `monarch.actor_mesh._Actor`. pub(super) actor: PyObject, @@ -399,9 +399,13 @@ impl PanicFlag { } #[async_trait] -impl Handler for PythonActor { - async fn handle(&mut self, cx: &Context, message: PythonMessage) -> anyhow::Result<()> { - let mailbox = PyMailbox { +impl Handler for PythonActor { + async fn handle( + &mut self, + cx: &Context, + message: LocalPythonMessage, + ) -> anyhow::Result<()> { + let mailbox: PyMailbox = PyMailbox { inner: cx.mailbox_for_py().clone(), }; // Create a channel for signaling panics in async endpoints. @@ -409,6 +413,16 @@ impl Handler for PythonActor { let (sender, receiver) = oneshot::channel(); let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> { + let mailbox = mailbox.into_bound_py_any(py).unwrap(); + let local_state: Option>> = message.local_state.map(|state| { + state + .into_iter() + .map(|e| match e { + LocalState::Mailbox => mailbox.clone(), + LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(), + }) + .collect() + }); let (rank, shape) = cx.cast_info(); let awaitable = self.actor.call_method( py, @@ -417,10 +431,11 @@ impl Handler for PythonActor { mailbox, rank, PyShape::from(shape), - message, + message.message, PanicFlag { sender: Some(sender), }, + local_state, ), None, )?; @@ -439,6 +454,20 @@ impl Handler for PythonActor { } } +#[async_trait] +impl Handler for PythonActor { + async fn handle(&mut self, cx: &Context, message: PythonMessage) -> anyhow::Result<()> { + self.handle( + cx, + LocalPythonMessage { + message, + local_state: None, + }, + ) + .await + } +} + /// Helper struct to make a Python future passable in an actor message. /// /// Also so that we don't have to write this massive type signature everywhere @@ -538,6 +567,17 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask { Ok(()) } } +#[derive(Debug)] +pub enum LocalState { + Mailbox, + PyObject(PyObject), +} + +#[derive(Debug)] +pub struct LocalPythonMessage { + pub message: PythonMessage, + pub local_state: Option>, +} pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 38bc5997..f8321b1f 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -313,7 +313,7 @@ impl PythonUndeliverablePortHandle { name = "PortRef", module = "monarch._rust_bindings.monarch_hyperactor.mailbox" )] -pub(super) struct PythonPortRef { +pub struct PythonPortRef { pub(crate) inner: PortRef, } @@ -453,7 +453,7 @@ impl PythonOncePortHandle { name = "OncePortRef", module = "monarch._rust_bindings.monarch_hyperactor.mailbox" )] -pub(crate) struct PythonOncePortRef { +pub struct PythonOncePortRef { pub(crate) inner: Option>, } diff --git a/monarch_messages/Cargo.toml b/monarch_messages/Cargo.toml index 13e7d57e..bbd8ab5a 100644 --- a/monarch_messages/Cargo.toml +++ b/monarch_messages/Cargo.toml @@ -12,6 +12,7 @@ anyhow = "1.0.98" derive_more = { version = "1.0.0", features = ["full"] } enum-as-inner = "0.6.0" hyperactor = { version = "0.0.0", path = "../hyperactor" } +monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" } monarch_types = { version = "0.0.0", path = "../monarch_types" } ndslice = { version = "0.0.0", path = "../ndslice" } pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] } diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index 81e5737e..07dfcd2c 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -27,6 +27,7 @@ use hyperactor::Named; use hyperactor::RefClient; use hyperactor::Unbind; use hyperactor::reference::ActorId; +use monarch_hyperactor::actor::PythonMessage; use monarch_types::SerializablePyErr; use ndslice::Slice; use pyo3::exceptions::PyValueError; @@ -402,6 +403,33 @@ pub struct CallFunctionParams { pub remote_process_groups: Vec, } +/// The local state that has to be restored into the python_message during +/// its unpickling. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum LocalState { + Ref(Ref), + Mailbox, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ActorCallParams { + pub seq: Seq, + /// The actor to call is (proc_id of stream worker, 'actor', 'index') + pub actor: String, + pub index: usize, + /// The PythonMessage object to send to the actor. + pub python_message: PythonMessage, + /// Referenceable objects to pass to the actor, + /// these will be put into the PythonMessage + /// during its unpickling. Because unpickling also needs to + /// restore mailboxes, we also have to keep track of which + /// members should just be the mailbox object. + pub local_state: Vec, + + /// Tensors that will be mutated by the call. + pub mutates: Vec, + pub stream: StreamRef, +} /// Type of reduction for [`WorkerMessage::Reduce`]. #[derive(Debug, Clone, Serialize, Deserialize)] pub enum Reduction { @@ -642,21 +670,30 @@ pub enum WorkerMessage { /// First use of the borrow on the receiving stream. This is a marker for /// synchronization. - BorrowFirstUse { borrow: u64 }, + BorrowFirstUse { + borrow: u64, + }, /// Last use of the borrow on the receiving stream. This is a marker for /// synchronization. - BorrowLastUse { borrow: u64 }, + BorrowLastUse { + borrow: u64, + }, /// Drop the borrow and free the resources associated with it. - BorrowDrop { borrow: u64 }, + BorrowDrop { + borrow: u64, + }, /// Delete these refs from the worker state. DeleteRefs(Vec), /// A [`ControllerMessage::Status`] will be send to the controller /// when all streams have processed all the message sent before this one. - RequestStatus { seq: Seq, controller: bool }, + RequestStatus { + seq: Seq, + controller: bool, + }, /// Perform a reduction operation, using an efficient communication backend. /// Only NCCL is supported for now. @@ -758,6 +795,7 @@ pub enum WorkerMessage { stream: StreamRef, }, + SendResultOfActorCall(ActorCallParams), PipeRecv { seq: Seq, /// Result refs. diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index 4d662e2d..14eb0cdc 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -312,6 +312,14 @@ impl WorkerMessageHandler for WorkerActor { Ok(()) } + async fn send_result_of_actor_call( + &mut self, + cx: &hyperactor::Context, + params: ActorCallParams, + ) -> Result<()> { + bail!("unimplemented: send_result_of_actor_call"); + } + async fn command_group( &mut self, cx: &hyperactor::Context, diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 1ae778e3..7233f680 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -72,6 +72,7 @@ use monarch_messages::controller::ControllerActor; use monarch_messages::controller::ControllerMessageClient; use monarch_messages::controller::Seq; use monarch_messages::wire_value::WireValue; +use monarch_messages::worker::ActorCallParams; use monarch_messages::worker::CallFunctionError; use monarch_messages::worker::CallFunctionParams; use monarch_messages::worker::Factory; @@ -860,6 +861,17 @@ impl WorkerMessageHandler for WorkerActor { .await } + async fn send_result_of_actor_call( + &mut self, + cx: &hyperactor::Context, + params: ActorCallParams, + ) -> Result<()> { + let stream = self.try_get_stream(params.stream)?; + stream + .send_result_of_actor_call(cx, cx.self_id().clone(), params) + .await?; + Ok(()) + } async fn split_comm( &mut self, cx: &hyperactor::Context, diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 1803035f..47c5e382 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -36,12 +36,18 @@ use hyperactor::mailbox::Mailbox; use hyperactor::mailbox::OncePortHandle; use hyperactor::mailbox::PortReceiver; use hyperactor::proc::Proc; +use monarch_hyperactor::actor::LocalPythonMessage; +use monarch_hyperactor::actor::PythonActor; use monarch_hyperactor::actor::PythonMessage; +use monarch_hyperactor::mailbox::EitherPortRef; +use monarch_hyperactor::mailbox::PythonPortRef; use monarch_messages::controller::ControllerMessageClient; use monarch_messages::controller::Seq; use monarch_messages::controller::WorkerError; +use monarch_messages::worker::ActorCallParams; use monarch_messages::worker::CallFunctionError; use monarch_messages::worker::CallFunctionParams; +use monarch_messages::worker::LocalState; use monarch_messages::worker::StreamRef; use monarch_types::PyTree; use monarch_types::SerializablePyErr; @@ -233,6 +239,8 @@ pub enum StreamMessage { ), GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle>), + + SendResultOfActorCall(ActorId, ActorCallParams), } impl StreamMessage { @@ -1666,6 +1674,73 @@ impl StreamMessageHandler for StreamActor { Ok(()) } + async fn send_result_of_actor_call( + &mut self, + cx: &Context, + worker_actor_id: ActorId, + params: ActorCallParams, + ) -> anyhow::Result<()> { + // TODO: handle mutates + let actor_id = ActorId(cx.proc().proc_id().clone(), params.actor, params.index); + let actor_ref: ActorRef = ActorRef::attest(actor_id); + let actor_handle = actor_ref.downcast_handle(cx).unwrap(); + let (send, recv) = cx.open_once_port(); + let send = send.bind(); + let send = EitherPortRef::Once(send.into()); + let message = PythonMessage { + response_port: Some(send), + ..params.python_message + }; + let local_state: Result> = + Python::with_gil(|py| { + params + .local_state + .into_iter() + .map(|elem| match elem { + LocalState::Mailbox => Ok(monarch_hyperactor::actor::LocalState::Mailbox), + // SAFETY: python is gonna make unsafe copies of this stuff anyway + LocalState::Ref(r) => unsafe { + let x = self.ref_to_rvalue(&r)?.try_to_object_unsafe(py)?.into(); + Ok(monarch_hyperactor::actor::LocalState::PyObject(x)) + }, + }) + .collect() + }); + // including making the PyMailbox + let message = LocalPythonMessage { + message, + local_state: Some(local_state?), + }; + actor_handle.send(message)?; + let result = recv.recv().await?.with_rank(worker_actor_id.rank()); + if result.method == "exception" { + // If result has "exception" as its kind, then + // we need to unpickle and turn it into a WorkerError + // and call remote_function_failed otherwise the + // controller assumes the object is correct and doesn't handle + // dependency tracking correctly. + let err = Python::with_gil(|py| -> Result { + let err = py + .import("pickle") + .unwrap() + .call_method1("loads", (result.message.into_vec(),))?; + Ok(WorkerError { + worker_actor_id, + backtrace: err.to_string(), + }) + })?; + self.controller_actor + .remote_function_failed(cx, params.seq, err) + .await?; + } else { + let result = Serialized::serialize(&result).unwrap(); + self.controller_actor + .fetch_result(cx, params.seq, Ok(result)) + .await?; + } + Ok(()) + } + async fn set_value( &mut self, cx: &Context, diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index e16261be..5f155db1 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -9,7 +9,9 @@ import collections import contextvars import functools +import importlib import inspect +import itertools import logging import random import traceback @@ -60,10 +62,13 @@ from monarch._src.actor.future import Future from monarch._src.actor.pdb_wrapper import PdbWrapper -from monarch._src.actor.pickle import flatten, unpickle +from monarch._src.actor.pickle import flatten, unflatten from monarch._src.actor.shape import MeshTrait, NDSlice +if TYPE_CHECKING: + from monarch._src.actor.proc_mesh import ProcMesh + logger: logging.Logger = logging.getLogger(__name__) Allocator = ProcessAllocator | LocalAllocator @@ -137,36 +142,41 @@ def __init__( self, mailbox: Mailbox, hy_actor_mesh: Optional[PythonActorMesh], + proc_mesh: "Optional[ProcMesh]", shape: Shape, actor_ids: List[ActorId], ) -> None: self._mailbox = mailbox self._actor_mesh = hy_actor_mesh + # actor meshes do not have a way to look this up at the moment, + # so we fake it here + self._proc_mesh = proc_mesh self._shape = shape self._please_replace_me_actor_ids = actor_ids @staticmethod def from_hyperactor_mesh( - mailbox: Mailbox, hy_actor_mesh: PythonActorMesh + mailbox: Mailbox, hy_actor_mesh: PythonActorMesh, proc_mesh: "ProcMesh" ) -> "_ActorMeshRefImpl": shape: Shape = hy_actor_mesh.shape return _ActorMeshRefImpl( mailbox, hy_actor_mesh, + proc_mesh, hy_actor_mesh.shape, [cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))], ) @staticmethod def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl": - return _ActorMeshRefImpl(mailbox, None, singleton_shape, [actor_id]) + return _ActorMeshRefImpl(mailbox, None, None, singleton_shape, [actor_id]) @staticmethod def from_actor_ref_with_shape( ref: "_ActorMeshRefImpl", shape: Shape ) -> "_ActorMeshRefImpl": return _ActorMeshRefImpl( - ref._mailbox, None, shape, ref._please_replace_me_actor_ids + ref._mailbox, None, None, shape, ref._please_replace_me_actor_ids ) def __getstate__( @@ -222,6 +232,11 @@ def cast( def __len__(self) -> int: return len(self._shape) + @property + def _name_pid(self): + actor_id0 = self._please_replace_me_actor_ids[0] + return actor_id0.actor_name, actor_id0.pid + class Extent(NamedTuple): labels: Sequence[str] @@ -367,13 +382,21 @@ def _send( This sends the message to all actors but does not wait for any result. """ self._signature.bind(None, *args, **kwargs) + + objects, bytes = flatten((args, kwargs), _is_ref_or_mailbox) message = PythonMessage( self._name, - _pickle((args, kwargs)), + bytes, None if port is None else port._port_ref, None, ) - self._actor_mesh.cast(message, selection) + refs = [obj for obj in objects if hasattr(obj, "__monarch_ref__")] + if not refs: + self._actor_mesh.cast(message, selection) + else: + importlib.import_module("monarch." + "mesh_controller").actor_send( + self, message, refs, port + ) shape = self._actor_mesh._shape return Extent(shape.labels, shape.ndslice.sizes) @@ -569,7 +592,7 @@ def _blocking_recv(self) -> R: def _process(self, msg: PythonMessage) -> R: # TODO: Try to do something more structured than a cast here - payload = cast(R, unpickle(msg.message, self._mailbox)) + payload = cast(R, unflatten(msg.message, itertools.repeat(self._mailbox))) if msg.method == "result": return payload else: @@ -616,7 +639,10 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, + local_state: List[Any] | None, ) -> None: + if local_state is None: + local_state = itertools.repeat(mailbox) port = ( Port(message.response_port, mailbox, rank) if message.response_port @@ -630,11 +656,13 @@ async def handle( DebugContext.set(DebugContext()) - args, kwargs = unpickle(message.message, mailbox) + args, kwargs = unflatten(message.message, local_state) if message.method == "__init__": Class, *args = args self.instance = Class(*args, **kwargs) + if port is not None: + port.send("result", None) return None if self.instance is None: @@ -735,9 +763,17 @@ def _post_mortem_debug(self, exc_tb) -> None: def _is_mailbox(x: object) -> bool: + if hasattr(x, "__monarch_ref__"): + raise NotImplementedError( + "Sending monarch tensor references directly to a port." + ) return isinstance(x, Mailbox) +def _is_ref_or_mailbox(x: object) -> bool: + return hasattr(x, "__monarch_ref__") or isinstance(x, Mailbox) + + def _pickle(obj: object) -> bytes: _, msg = flatten(obj, _is_mailbox) return msg diff --git a/python/monarch/_src/actor/pickle.py b/python/monarch/_src/actor/pickle.py index 5b2167df..62268d07 100644 --- a/python/monarch/_src/actor/pickle.py +++ b/python/monarch/_src/actor/pickle.py @@ -5,9 +5,8 @@ # LICENSE file in the root directory of this source tree. import io -import itertools import pickle -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager, ExitStack from typing import Any, Callable, Iterable, List, Tuple import cloudpickle @@ -51,12 +50,10 @@ def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]: def unflatten(data: bytes, values: Iterable[Any]) -> Any: - if torch is not None: - context_manager = torch.utils._python_dispatch._disable_current_modes - else: - context_manager = nullcontext - - with context_manager(): + with ExitStack() as stack: + if torch is not None: + stack.enter_context(load_tensors_on_cpu()) + stack.enter_context(torch.utils._python_dispatch._disable_current_modes()) up = _Unpickler(data, values) return up.load() @@ -73,20 +70,3 @@ def load_tensors_on_cpu(): yield finally: torch.storage._load_from_bytes = old - - -def unpickle(data: bytes, mailbox) -> Any: - if torch is not None: - context_manager = load_tensors_on_cpu - else: - context_manager = nullcontext - - with context_manager(): - # regardless of the mailboxes of the remote objects - # they all become the local mailbox. - return unflatten(data, itertools.repeat(mailbox)) - - -def pickle_(obj: object, filter: Callable[[Any], bool]) -> bytes: - _, msg = flatten(obj, filter) - return msg diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 3195dc67..81465216 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -172,7 +172,7 @@ def _spawn_blocking( actor_mesh = self._proc_mesh.spawn_blocking(name, _Actor) service = ActorMeshRef( Class, - _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh), + _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh, self), self._mailbox, ) # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by @@ -197,7 +197,7 @@ async def _spawn_nonblocking( actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor) service = ActorMeshRef( Class, - _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh), + _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh, self), self._mailbox, ) # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by diff --git a/python/monarch/common/messages.py b/python/monarch/common/messages.py index 5e600912..13062fb6 100644 --- a/python/monarch/common/messages.py +++ b/python/monarch/common/messages.py @@ -19,9 +19,12 @@ Protocol, Tuple, TYPE_CHECKING, + Union, ) from monarch._rust_bindings.monarch_extension import tensor_worker +from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage +from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._src.actor.shape import NDSlice from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction @@ -425,6 +428,16 @@ def to_rust_message(self) -> tensor_worker.WorkerMessage: ) +class SendResultOfActorCall(NamedTuple): + seq: int + actor: str + actor_index: int + python_message: PythonMessage + local_state: List[Referenceable | Mailbox] + mutates: List[tensor_worker.Ref] + stream: tensor_worker.StreamRef + + class SplitComm(NamedTuple): dims: Dims device_mesh: DeviceMesh diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 2b7973d3..900f0c44 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -30,18 +30,19 @@ WorldState, ) from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller +from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) -from monarch._src.actor.actor_mesh import Port, PortTuple +from monarch._src.actor.actor_mesh import ActorMeshRef, Port, PortTuple from monarch._src.actor.shape import NDSlice -from monarch.common import messages +from monarch.common import device_mesh, messages, stream from monarch.common.controller_api import TController from monarch.common.invocation import Seq +from monarch.common.messages import Referenceable, SendResultOfActorCall from monarch.common.stream import StreamRef -from monarch.common.tensor import Tensor - +from monarch.common.tensor import InputChecker, Tensor from monarch.tensor_worker_main import _set_trace if TYPE_CHECKING: @@ -255,3 +256,59 @@ def __str__(self): except Exception: traceback.print_exc() return "" + + +def actor_send( + actor_mesh: ActorMeshRef, + message: PythonMessage, + refs: Sequence[Any], + port: Optional[Port[Any]], +): + tensors = [ref for ref in refs if isinstance(ref, Tensor)] + # we have some monarch references, we need to ensure their + # proc_mesh matches that of the tensors we sent to it + chosen_stream = stream._active + for t in tensors: + if hasattr(t, "stream"): + chosen_stream = t.stream + break + with InputChecker(refs, lambda x: f"actor_call({x})") as checker: + checker.check_mesh_stream_local(device_mesh._active, chosen_stream) + # TODO: move propagators into Endpoint abstraction and run the propagator to get the + # mutates + checker.check_permission(()) + selected_device_mesh = ( + actor_mesh._actor_mesh._proc_mesh + and actor_mesh._actor_mesh._proc_mesh._device_mesh + ) + if selected_device_mesh is not checker.mesh: + raise ValueError( + f"monarch Tensors sent to an actor must be located on the same process as the actor. However {checker.mesh} is not {selected_device_mesh}." + "NYI: better serialization of mesh names to make the mismatch more clear." + ) + + client = checker.mesh.client + stream_ref = chosen_stream._to_ref(client) + + fut = (port, checker.mesh._ndslice) if port is not None else None + + ident = client.new_node([], tensors, cast("OldFuture", fut)) + + actor_name, actor_index = actor_mesh._actor_mesh._name_pid + msg = SendResultOfActorCall( + ident, + actor_name, + actor_index, + message, + cast("List[Mailbox | Referenceable]", refs), + [], + stream_ref, + ) + + print("SENDING MESSAGE", msg) + client.send(checker.mesh._ndslice, msg) + # we have to ask for status updates + # from workers to be sure they have finished + # enough work to count this future as finished, + # and all potential errors have been reported + client._request_status() diff --git a/python/tests/test_tensor_engine.py b/python/tests/test_tensor_engine.py index a7e5122b..a79d5e8d 100644 --- a/python/tests/test_tensor_engine.py +++ b/python/tests/test_tensor_engine.py @@ -8,7 +8,7 @@ import pytest import torch from monarch import remote -from monarch.actor import proc_mesh +from monarch.actor import Actor, endpoint, proc_mesh from monarch.mesh_controller import spawn_tensor_engine @@ -59,3 +59,22 @@ def test_proc_mesh_tensor_engine() -> None: assert a == 0 assert b == 10 assert c == 100 + + +class AddWithState(Actor): + def __init__(self, state: torch.Tensor): + super().__init__() + self.state = state + + @endpoint + def forward(self, x) -> torch.Tensor: + return x + self.state + + +@two_gpu +def test_actor_with_tensors() -> None: + pm = proc_mesh(gpus=1).get() + with pm.activate(): + x = pm.spawn("adder", AddWithState, torch.ones(())).get() + y = torch.ones(()) + assert x.forward.call(y).get(timeout=5).item(hosts=0, gpus=0).item() == 2 From 3aed803272c8a2d2acfc132486c4fb5f391a3efc Mon Sep 17 00:00:00 2001 From: zdevito Date: Fri, 11 Jul 2025 16:08:08 -0700 Subject: [PATCH 2/2] Update on "[15/n] Pass monarch tensors to actor endpoints, part 1" This makes it possible to send a monarch tensor to an actor endpoint that is defiend over the same proc mesh as the tensor. The send is done locally so the actor can do work with the tensors. The stream actor in the tensor engine is sent a SendResultOfActorCall message, which it will forward to the actor, binding to the message the real local tensors that were passed as arguments. The stream actor that owns the tensor waits for called the actor to finish since the actor 'owns' the tensor through the duration of the call. Known limitation: this message to the actor can go out of order w.r.t to other messages sent from the owner of the tensor engine because the real message is being sent from the stream actor. The next PR will fix this limitation by sending _both_ the tensor engine and the actor a message at the same time. The actor will get a 'wait for SendResultOfActorCall' message, at which point it will stop processing any messages except for the SentResultOfActorCall message it is suppose to be waiting for. This way the correct order is preserved from the perspective of the tensor engine stream and the actor. Differential Revision: [D78196701](https://our.internmc.facebook.com/intern/diff/D78196701/) [ghstack-poisoned] --- monarch_extension/src/convert.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index 5aeb7075..ecce7dc3 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -461,8 +461,6 @@ fn create_map(py: Python) -> HashMap { p.parseWorkerMessageList("commands")?, )) }); - // HELP DEVMATE! - // Add a cast for SendResultOfActorCall m.insert(key("SendResultOfActorCall"), |p| { Ok(WorkerMessage::SendResultOfActorCall( worker::ActorCallParams {