diff --git a/monarch_extension/src/convert.rs b/monarch_extension/src/convert.rs index afda10ad..ecce7dc3 100644 --- a/monarch_extension/src/convert.rs +++ b/monarch_extension/src/convert.rs @@ -176,6 +176,35 @@ impl<'a> MessageParser<'a> { fn parseWorkerMessageList(&self, name: &str) -> PyResult> { self.attr(name)?.try_iter()?.map(|x| convert(x?)).collect() } + #[allow(non_snake_case)] + fn parseLocalStateList(&self, name: &str) -> PyResult> { + let local_state_list = self.attr(name)?; + let mut result = Vec::new(); + + // Get the PyMailbox class for type checking + let mailbox_class = self + .current + .py() + .import("monarch._rust_bindings.monarch_hyperactor.mailbox")? + .getattr("Mailbox")?; + + for item in local_state_list.try_iter()? { + let item = item?; + // Check if it's a Ref by trying to extract it + if let Ok(ref_val) = create_ref(item.clone()) { + result.push(worker::LocalState::Ref(ref_val)); + } else if item.is_instance(&mailbox_class)? { + // It's a PyMailbox instance + result.push(worker::LocalState::Mailbox); + } else { + return Err(PyValueError::new_err(format!( + "Expected Ref or Mailbox in local_state, got: {}", + item.get_type().name()? + ))); + } + } + Ok(result) + } fn parse_error_reason(&self, name: &str) -> PyResult, String)>> { let err = self.attr(name)?; if err.is_none() { @@ -432,6 +461,19 @@ fn create_map(py: Python) -> HashMap { p.parseWorkerMessageList("commands")?, )) }); + m.insert(key("SendResultOfActorCall"), |p| { + Ok(WorkerMessage::SendResultOfActorCall( + worker::ActorCallParams { + seq: p.parseSeq("seq")?, + actor: p.parse("actor")?, + index: p.parse("actor_index")?, + python_message: p.parse("python_message")?, + local_state: p.parseLocalStateList("local_state")?, + mutates: p.parseRefList("mutates")?, + stream: p.parseStreamRef("stream")?, + }, + )) + }); m } diff --git a/monarch_extension/src/tensor_worker.rs b/monarch_extension/src/tensor_worker.rs index 31a9ac4d..b11a3c67 100644 --- a/monarch_extension/src/tensor_worker.rs +++ b/monarch_extension/src/tensor_worker.rs @@ -1372,6 +1372,7 @@ pub(crate) fn worker_message_to_py(py: Python<'_>, message: &WorkerMessage) -> P } WorkerMessage::SetRefUnitTestsOnly { .. } => unimplemented!(), WorkerMessage::GetRefUnitTestsOnly { .. } => unimplemented!(), + WorkerMessage::SendResultOfActorCall { .. } => unimplemented!(), } } diff --git a/monarch_hyperactor/src/actor.rs b/monarch_hyperactor/src/actor.rs index 3a50f95a..1548d25b 100644 --- a/monarch_hyperactor/src/actor.rs +++ b/monarch_hyperactor/src/actor.rs @@ -175,10 +175,10 @@ impl PickledMessageClientActor { #[pyclass(frozen, module = "monarch._rust_bindings.monarch_hyperactor.actor")] #[derive(Default, Clone, Serialize, Deserialize, Named, PartialEq)] pub struct PythonMessage { - pub(crate) method: String, - pub(crate) message: ByteBuf, - response_port: Option, - rank: Option, + pub method: String, + pub message: ByteBuf, + pub response_port: Option, + pub rank: Option, } impl PythonMessage { @@ -289,7 +289,7 @@ impl PythonActorHandle { PythonMessage { cast = true }, ], )] -pub(super) struct PythonActor { +pub struct PythonActor { /// The Python object that we delegate message handling to. An instance of /// `monarch.actor_mesh._Actor`. pub(super) actor: PyObject, @@ -399,9 +399,13 @@ impl PanicFlag { } #[async_trait] -impl Handler for PythonActor { - async fn handle(&mut self, cx: &Context, message: PythonMessage) -> anyhow::Result<()> { - let mailbox = PyMailbox { +impl Handler for PythonActor { + async fn handle( + &mut self, + cx: &Context, + message: LocalPythonMessage, + ) -> anyhow::Result<()> { + let mailbox: PyMailbox = PyMailbox { inner: cx.mailbox_for_py().clone(), }; // Create a channel for signaling panics in async endpoints. @@ -409,6 +413,16 @@ impl Handler for PythonActor { let (sender, receiver) = oneshot::channel(); let future = Python::with_gil(|py| -> Result<_, SerializablePyErr> { + let mailbox = mailbox.into_bound_py_any(py).unwrap(); + let local_state: Option>> = message.local_state.map(|state| { + state + .into_iter() + .map(|e| match e { + LocalState::Mailbox => mailbox.clone(), + LocalState::PyObject(obj) => obj.into_bound_py_any(py).unwrap(), + }) + .collect() + }); let (rank, shape) = cx.cast_info(); let awaitable = self.actor.call_method( py, @@ -417,10 +431,11 @@ impl Handler for PythonActor { mailbox, rank, PyShape::from(shape), - message, + message.message, PanicFlag { sender: Some(sender), }, + local_state, ), None, )?; @@ -439,6 +454,20 @@ impl Handler for PythonActor { } } +#[async_trait] +impl Handler for PythonActor { + async fn handle(&mut self, cx: &Context, message: PythonMessage) -> anyhow::Result<()> { + self.handle( + cx, + LocalPythonMessage { + message, + local_state: None, + }, + ) + .await + } +} + /// Helper struct to make a Python future passable in an actor message. /// /// Also so that we don't have to write this massive type signature everywhere @@ -538,6 +567,17 @@ impl AsyncEndpointInvocationHandler for AsyncEndpointTask { Ok(()) } } +#[derive(Debug)] +pub enum LocalState { + Mailbox, + PyObject(PyObject), +} + +#[derive(Debug)] +pub struct LocalPythonMessage { + pub message: PythonMessage, + pub local_state: Option>, +} pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResult<()> { hyperactor_mod.add_class::()?; diff --git a/monarch_hyperactor/src/mailbox.rs b/monarch_hyperactor/src/mailbox.rs index 38bc5997..f8321b1f 100644 --- a/monarch_hyperactor/src/mailbox.rs +++ b/monarch_hyperactor/src/mailbox.rs @@ -313,7 +313,7 @@ impl PythonUndeliverablePortHandle { name = "PortRef", module = "monarch._rust_bindings.monarch_hyperactor.mailbox" )] -pub(super) struct PythonPortRef { +pub struct PythonPortRef { pub(crate) inner: PortRef, } @@ -453,7 +453,7 @@ impl PythonOncePortHandle { name = "OncePortRef", module = "monarch._rust_bindings.monarch_hyperactor.mailbox" )] -pub(crate) struct PythonOncePortRef { +pub struct PythonOncePortRef { pub(crate) inner: Option>, } diff --git a/monarch_messages/Cargo.toml b/monarch_messages/Cargo.toml index 13e7d57e..bbd8ab5a 100644 --- a/monarch_messages/Cargo.toml +++ b/monarch_messages/Cargo.toml @@ -12,6 +12,7 @@ anyhow = "1.0.98" derive_more = { version = "1.0.0", features = ["full"] } enum-as-inner = "0.6.0" hyperactor = { version = "0.0.0", path = "../hyperactor" } +monarch_hyperactor = { version = "0.0.0", path = "../monarch_hyperactor" } monarch_types = { version = "0.0.0", path = "../monarch_types" } ndslice = { version = "0.0.0", path = "../ndslice" } pyo3 = { version = "0.24", features = ["anyhow", "multiple-pymethods"] } diff --git a/monarch_messages/src/worker.rs b/monarch_messages/src/worker.rs index 81e5737e..07dfcd2c 100644 --- a/monarch_messages/src/worker.rs +++ b/monarch_messages/src/worker.rs @@ -27,6 +27,7 @@ use hyperactor::Named; use hyperactor::RefClient; use hyperactor::Unbind; use hyperactor::reference::ActorId; +use monarch_hyperactor::actor::PythonMessage; use monarch_types::SerializablePyErr; use ndslice::Slice; use pyo3::exceptions::PyValueError; @@ -402,6 +403,33 @@ pub struct CallFunctionParams { pub remote_process_groups: Vec, } +/// The local state that has to be restored into the python_message during +/// its unpickling. +#[derive(Serialize, Deserialize, Debug, Clone)] +pub enum LocalState { + Ref(Ref), + Mailbox, +} + +#[derive(Serialize, Deserialize, Debug, Clone)] +pub struct ActorCallParams { + pub seq: Seq, + /// The actor to call is (proc_id of stream worker, 'actor', 'index') + pub actor: String, + pub index: usize, + /// The PythonMessage object to send to the actor. + pub python_message: PythonMessage, + /// Referenceable objects to pass to the actor, + /// these will be put into the PythonMessage + /// during its unpickling. Because unpickling also needs to + /// restore mailboxes, we also have to keep track of which + /// members should just be the mailbox object. + pub local_state: Vec, + + /// Tensors that will be mutated by the call. + pub mutates: Vec, + pub stream: StreamRef, +} /// Type of reduction for [`WorkerMessage::Reduce`]. #[derive(Debug, Clone, Serialize, Deserialize)] pub enum Reduction { @@ -642,21 +670,30 @@ pub enum WorkerMessage { /// First use of the borrow on the receiving stream. This is a marker for /// synchronization. - BorrowFirstUse { borrow: u64 }, + BorrowFirstUse { + borrow: u64, + }, /// Last use of the borrow on the receiving stream. This is a marker for /// synchronization. - BorrowLastUse { borrow: u64 }, + BorrowLastUse { + borrow: u64, + }, /// Drop the borrow and free the resources associated with it. - BorrowDrop { borrow: u64 }, + BorrowDrop { + borrow: u64, + }, /// Delete these refs from the worker state. DeleteRefs(Vec), /// A [`ControllerMessage::Status`] will be send to the controller /// when all streams have processed all the message sent before this one. - RequestStatus { seq: Seq, controller: bool }, + RequestStatus { + seq: Seq, + controller: bool, + }, /// Perform a reduction operation, using an efficient communication backend. /// Only NCCL is supported for now. @@ -758,6 +795,7 @@ pub enum WorkerMessage { stream: StreamRef, }, + SendResultOfActorCall(ActorCallParams), PipeRecv { seq: Seq, /// Result refs. diff --git a/monarch_simulator/src/worker.rs b/monarch_simulator/src/worker.rs index 4d662e2d..14eb0cdc 100644 --- a/monarch_simulator/src/worker.rs +++ b/monarch_simulator/src/worker.rs @@ -312,6 +312,14 @@ impl WorkerMessageHandler for WorkerActor { Ok(()) } + async fn send_result_of_actor_call( + &mut self, + cx: &hyperactor::Context, + params: ActorCallParams, + ) -> Result<()> { + bail!("unimplemented: send_result_of_actor_call"); + } + async fn command_group( &mut self, cx: &hyperactor::Context, diff --git a/monarch_tensor_worker/src/lib.rs b/monarch_tensor_worker/src/lib.rs index 1ae778e3..7233f680 100644 --- a/monarch_tensor_worker/src/lib.rs +++ b/monarch_tensor_worker/src/lib.rs @@ -72,6 +72,7 @@ use monarch_messages::controller::ControllerActor; use monarch_messages::controller::ControllerMessageClient; use monarch_messages::controller::Seq; use monarch_messages::wire_value::WireValue; +use monarch_messages::worker::ActorCallParams; use monarch_messages::worker::CallFunctionError; use monarch_messages::worker::CallFunctionParams; use monarch_messages::worker::Factory; @@ -860,6 +861,17 @@ impl WorkerMessageHandler for WorkerActor { .await } + async fn send_result_of_actor_call( + &mut self, + cx: &hyperactor::Context, + params: ActorCallParams, + ) -> Result<()> { + let stream = self.try_get_stream(params.stream)?; + stream + .send_result_of_actor_call(cx, cx.self_id().clone(), params) + .await?; + Ok(()) + } async fn split_comm( &mut self, cx: &hyperactor::Context, diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index 1803035f..47c5e382 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -36,12 +36,18 @@ use hyperactor::mailbox::Mailbox; use hyperactor::mailbox::OncePortHandle; use hyperactor::mailbox::PortReceiver; use hyperactor::proc::Proc; +use monarch_hyperactor::actor::LocalPythonMessage; +use monarch_hyperactor::actor::PythonActor; use monarch_hyperactor::actor::PythonMessage; +use monarch_hyperactor::mailbox::EitherPortRef; +use monarch_hyperactor::mailbox::PythonPortRef; use monarch_messages::controller::ControllerMessageClient; use monarch_messages::controller::Seq; use monarch_messages::controller::WorkerError; +use monarch_messages::worker::ActorCallParams; use monarch_messages::worker::CallFunctionError; use monarch_messages::worker::CallFunctionParams; +use monarch_messages::worker::LocalState; use monarch_messages::worker::StreamRef; use monarch_types::PyTree; use monarch_types::SerializablePyErr; @@ -233,6 +239,8 @@ pub enum StreamMessage { ), GetTensorRefUnitTestsOnly(Ref, #[reply] OncePortHandle>), + + SendResultOfActorCall(ActorId, ActorCallParams), } impl StreamMessage { @@ -1666,6 +1674,73 @@ impl StreamMessageHandler for StreamActor { Ok(()) } + async fn send_result_of_actor_call( + &mut self, + cx: &Context, + worker_actor_id: ActorId, + params: ActorCallParams, + ) -> anyhow::Result<()> { + // TODO: handle mutates + let actor_id = ActorId(cx.proc().proc_id().clone(), params.actor, params.index); + let actor_ref: ActorRef = ActorRef::attest(actor_id); + let actor_handle = actor_ref.downcast_handle(cx).unwrap(); + let (send, recv) = cx.open_once_port(); + let send = send.bind(); + let send = EitherPortRef::Once(send.into()); + let message = PythonMessage { + response_port: Some(send), + ..params.python_message + }; + let local_state: Result> = + Python::with_gil(|py| { + params + .local_state + .into_iter() + .map(|elem| match elem { + LocalState::Mailbox => Ok(monarch_hyperactor::actor::LocalState::Mailbox), + // SAFETY: python is gonna make unsafe copies of this stuff anyway + LocalState::Ref(r) => unsafe { + let x = self.ref_to_rvalue(&r)?.try_to_object_unsafe(py)?.into(); + Ok(monarch_hyperactor::actor::LocalState::PyObject(x)) + }, + }) + .collect() + }); + // including making the PyMailbox + let message = LocalPythonMessage { + message, + local_state: Some(local_state?), + }; + actor_handle.send(message)?; + let result = recv.recv().await?.with_rank(worker_actor_id.rank()); + if result.method == "exception" { + // If result has "exception" as its kind, then + // we need to unpickle and turn it into a WorkerError + // and call remote_function_failed otherwise the + // controller assumes the object is correct and doesn't handle + // dependency tracking correctly. + let err = Python::with_gil(|py| -> Result { + let err = py + .import("pickle") + .unwrap() + .call_method1("loads", (result.message.into_vec(),))?; + Ok(WorkerError { + worker_actor_id, + backtrace: err.to_string(), + }) + })?; + self.controller_actor + .remote_function_failed(cx, params.seq, err) + .await?; + } else { + let result = Serialized::serialize(&result).unwrap(); + self.controller_actor + .fetch_result(cx, params.seq, Ok(result)) + .await?; + } + Ok(()) + } + async fn set_value( &mut self, cx: &Context, diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index e16261be..5f155db1 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -9,7 +9,9 @@ import collections import contextvars import functools +import importlib import inspect +import itertools import logging import random import traceback @@ -60,10 +62,13 @@ from monarch._src.actor.future import Future from monarch._src.actor.pdb_wrapper import PdbWrapper -from monarch._src.actor.pickle import flatten, unpickle +from monarch._src.actor.pickle import flatten, unflatten from monarch._src.actor.shape import MeshTrait, NDSlice +if TYPE_CHECKING: + from monarch._src.actor.proc_mesh import ProcMesh + logger: logging.Logger = logging.getLogger(__name__) Allocator = ProcessAllocator | LocalAllocator @@ -137,36 +142,41 @@ def __init__( self, mailbox: Mailbox, hy_actor_mesh: Optional[PythonActorMesh], + proc_mesh: "Optional[ProcMesh]", shape: Shape, actor_ids: List[ActorId], ) -> None: self._mailbox = mailbox self._actor_mesh = hy_actor_mesh + # actor meshes do not have a way to look this up at the moment, + # so we fake it here + self._proc_mesh = proc_mesh self._shape = shape self._please_replace_me_actor_ids = actor_ids @staticmethod def from_hyperactor_mesh( - mailbox: Mailbox, hy_actor_mesh: PythonActorMesh + mailbox: Mailbox, hy_actor_mesh: PythonActorMesh, proc_mesh: "ProcMesh" ) -> "_ActorMeshRefImpl": shape: Shape = hy_actor_mesh.shape return _ActorMeshRefImpl( mailbox, hy_actor_mesh, + proc_mesh, hy_actor_mesh.shape, [cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))], ) @staticmethod def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl": - return _ActorMeshRefImpl(mailbox, None, singleton_shape, [actor_id]) + return _ActorMeshRefImpl(mailbox, None, None, singleton_shape, [actor_id]) @staticmethod def from_actor_ref_with_shape( ref: "_ActorMeshRefImpl", shape: Shape ) -> "_ActorMeshRefImpl": return _ActorMeshRefImpl( - ref._mailbox, None, shape, ref._please_replace_me_actor_ids + ref._mailbox, None, None, shape, ref._please_replace_me_actor_ids ) def __getstate__( @@ -222,6 +232,11 @@ def cast( def __len__(self) -> int: return len(self._shape) + @property + def _name_pid(self): + actor_id0 = self._please_replace_me_actor_ids[0] + return actor_id0.actor_name, actor_id0.pid + class Extent(NamedTuple): labels: Sequence[str] @@ -367,13 +382,21 @@ def _send( This sends the message to all actors but does not wait for any result. """ self._signature.bind(None, *args, **kwargs) + + objects, bytes = flatten((args, kwargs), _is_ref_or_mailbox) message = PythonMessage( self._name, - _pickle((args, kwargs)), + bytes, None if port is None else port._port_ref, None, ) - self._actor_mesh.cast(message, selection) + refs = [obj for obj in objects if hasattr(obj, "__monarch_ref__")] + if not refs: + self._actor_mesh.cast(message, selection) + else: + importlib.import_module("monarch." + "mesh_controller").actor_send( + self, message, refs, port + ) shape = self._actor_mesh._shape return Extent(shape.labels, shape.ndslice.sizes) @@ -569,7 +592,7 @@ def _blocking_recv(self) -> R: def _process(self, msg: PythonMessage) -> R: # TODO: Try to do something more structured than a cast here - payload = cast(R, unpickle(msg.message, self._mailbox)) + payload = cast(R, unflatten(msg.message, itertools.repeat(self._mailbox))) if msg.method == "result": return payload else: @@ -616,7 +639,10 @@ async def handle( shape: Shape, message: PythonMessage, panic_flag: PanicFlag, + local_state: List[Any] | None, ) -> None: + if local_state is None: + local_state = itertools.repeat(mailbox) port = ( Port(message.response_port, mailbox, rank) if message.response_port @@ -630,11 +656,13 @@ async def handle( DebugContext.set(DebugContext()) - args, kwargs = unpickle(message.message, mailbox) + args, kwargs = unflatten(message.message, local_state) if message.method == "__init__": Class, *args = args self.instance = Class(*args, **kwargs) + if port is not None: + port.send("result", None) return None if self.instance is None: @@ -735,9 +763,17 @@ def _post_mortem_debug(self, exc_tb) -> None: def _is_mailbox(x: object) -> bool: + if hasattr(x, "__monarch_ref__"): + raise NotImplementedError( + "Sending monarch tensor references directly to a port." + ) return isinstance(x, Mailbox) +def _is_ref_or_mailbox(x: object) -> bool: + return hasattr(x, "__monarch_ref__") or isinstance(x, Mailbox) + + def _pickle(obj: object) -> bytes: _, msg = flatten(obj, _is_mailbox) return msg diff --git a/python/monarch/_src/actor/pickle.py b/python/monarch/_src/actor/pickle.py index 5b2167df..62268d07 100644 --- a/python/monarch/_src/actor/pickle.py +++ b/python/monarch/_src/actor/pickle.py @@ -5,9 +5,8 @@ # LICENSE file in the root directory of this source tree. import io -import itertools import pickle -from contextlib import contextmanager, nullcontext +from contextlib import contextmanager, ExitStack from typing import Any, Callable, Iterable, List, Tuple import cloudpickle @@ -51,12 +50,10 @@ def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]: def unflatten(data: bytes, values: Iterable[Any]) -> Any: - if torch is not None: - context_manager = torch.utils._python_dispatch._disable_current_modes - else: - context_manager = nullcontext - - with context_manager(): + with ExitStack() as stack: + if torch is not None: + stack.enter_context(load_tensors_on_cpu()) + stack.enter_context(torch.utils._python_dispatch._disable_current_modes()) up = _Unpickler(data, values) return up.load() @@ -73,20 +70,3 @@ def load_tensors_on_cpu(): yield finally: torch.storage._load_from_bytes = old - - -def unpickle(data: bytes, mailbox) -> Any: - if torch is not None: - context_manager = load_tensors_on_cpu - else: - context_manager = nullcontext - - with context_manager(): - # regardless of the mailboxes of the remote objects - # they all become the local mailbox. - return unflatten(data, itertools.repeat(mailbox)) - - -def pickle_(obj: object, filter: Callable[[Any], bool]) -> bytes: - _, msg = flatten(obj, filter) - return msg diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 3195dc67..81465216 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -172,7 +172,7 @@ def _spawn_blocking( actor_mesh = self._proc_mesh.spawn_blocking(name, _Actor) service = ActorMeshRef( Class, - _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh), + _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh, self), self._mailbox, ) # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by @@ -197,7 +197,7 @@ async def _spawn_nonblocking( actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor) service = ActorMeshRef( Class, - _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh), + _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh, self), self._mailbox, ) # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by diff --git a/python/monarch/common/messages.py b/python/monarch/common/messages.py index 5e600912..13062fb6 100644 --- a/python/monarch/common/messages.py +++ b/python/monarch/common/messages.py @@ -19,9 +19,12 @@ Protocol, Tuple, TYPE_CHECKING, + Union, ) from monarch._rust_bindings.monarch_extension import tensor_worker +from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage +from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._src.actor.shape import NDSlice from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction @@ -425,6 +428,16 @@ def to_rust_message(self) -> tensor_worker.WorkerMessage: ) +class SendResultOfActorCall(NamedTuple): + seq: int + actor: str + actor_index: int + python_message: PythonMessage + local_state: List[Referenceable | Mailbox] + mutates: List[tensor_worker.Ref] + stream: tensor_worker.StreamRef + + class SplitComm(NamedTuple): dims: Dims device_mesh: DeviceMesh diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index 2b7973d3..900f0c44 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -30,18 +30,19 @@ WorldState, ) from monarch._rust_bindings.monarch_extension.mesh_controller import _Controller +from monarch._rust_bindings.monarch_hyperactor.actor import PythonMessage from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) -from monarch._src.actor.actor_mesh import Port, PortTuple +from monarch._src.actor.actor_mesh import ActorMeshRef, Port, PortTuple from monarch._src.actor.shape import NDSlice -from monarch.common import messages +from monarch.common import device_mesh, messages, stream from monarch.common.controller_api import TController from monarch.common.invocation import Seq +from monarch.common.messages import Referenceable, SendResultOfActorCall from monarch.common.stream import StreamRef -from monarch.common.tensor import Tensor - +from monarch.common.tensor import InputChecker, Tensor from monarch.tensor_worker_main import _set_trace if TYPE_CHECKING: @@ -255,3 +256,59 @@ def __str__(self): except Exception: traceback.print_exc() return "" + + +def actor_send( + actor_mesh: ActorMeshRef, + message: PythonMessage, + refs: Sequence[Any], + port: Optional[Port[Any]], +): + tensors = [ref for ref in refs if isinstance(ref, Tensor)] + # we have some monarch references, we need to ensure their + # proc_mesh matches that of the tensors we sent to it + chosen_stream = stream._active + for t in tensors: + if hasattr(t, "stream"): + chosen_stream = t.stream + break + with InputChecker(refs, lambda x: f"actor_call({x})") as checker: + checker.check_mesh_stream_local(device_mesh._active, chosen_stream) + # TODO: move propagators into Endpoint abstraction and run the propagator to get the + # mutates + checker.check_permission(()) + selected_device_mesh = ( + actor_mesh._actor_mesh._proc_mesh + and actor_mesh._actor_mesh._proc_mesh._device_mesh + ) + if selected_device_mesh is not checker.mesh: + raise ValueError( + f"monarch Tensors sent to an actor must be located on the same process as the actor. However {checker.mesh} is not {selected_device_mesh}." + "NYI: better serialization of mesh names to make the mismatch more clear." + ) + + client = checker.mesh.client + stream_ref = chosen_stream._to_ref(client) + + fut = (port, checker.mesh._ndslice) if port is not None else None + + ident = client.new_node([], tensors, cast("OldFuture", fut)) + + actor_name, actor_index = actor_mesh._actor_mesh._name_pid + msg = SendResultOfActorCall( + ident, + actor_name, + actor_index, + message, + cast("List[Mailbox | Referenceable]", refs), + [], + stream_ref, + ) + + print("SENDING MESSAGE", msg) + client.send(checker.mesh._ndslice, msg) + # we have to ask for status updates + # from workers to be sure they have finished + # enough work to count this future as finished, + # and all potential errors have been reported + client._request_status() diff --git a/python/tests/test_tensor_engine.py b/python/tests/test_tensor_engine.py index a7e5122b..a79d5e8d 100644 --- a/python/tests/test_tensor_engine.py +++ b/python/tests/test_tensor_engine.py @@ -8,7 +8,7 @@ import pytest import torch from monarch import remote -from monarch.actor import proc_mesh +from monarch.actor import Actor, endpoint, proc_mesh from monarch.mesh_controller import spawn_tensor_engine @@ -59,3 +59,22 @@ def test_proc_mesh_tensor_engine() -> None: assert a == 0 assert b == 10 assert c == 100 + + +class AddWithState(Actor): + def __init__(self, state: torch.Tensor): + super().__init__() + self.state = state + + @endpoint + def forward(self, x) -> torch.Tensor: + return x + self.state + + +@two_gpu +def test_actor_with_tensors() -> None: + pm = proc_mesh(gpus=1).get() + with pm.activate(): + x = pm.spawn("adder", AddWithState, torch.ones(())).get() + y = torch.ones(()) + assert x.forward.call(y).get(timeout=5).item(hosts=0, gpus=0).item() == 2