diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi
index eaa8d37a..b13de2f6 100644
--- a/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi
+++ b/python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.
 
 import collections.abc
-from typing import Dict, final, Iterator, List, overload
+from typing import Dict, final, Iterator, List, overload, Sequence
 
 @final
 class Slice:
@@ -90,7 +90,7 @@ class Slice:
         ...
 
     @staticmethod
-    def new_row_major(ranks: list[int]) -> "Slice":
+    def new_row_major(ranks: Sequence[int]) -> "Slice":
         """Returns a contiguous slice composed of ranks"""
         ...
@@ -106,7 +106,7 @@ class Shape:
     - `labels`: A list of strings representing the labels for each dimension.
    - `slice`: An Slice object representing the shape.
     """
-    def __new__(cls, labels: List[str], slice: Slice) -> "Shape": ...
+    def __new__(cls, labels: Sequence[str], slice: Slice) -> "Shape": ...
     @property
     def ndslice(self) -> Slice: ...
     @property
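Reviewer note (not part of the patch): the stub change above widens `List`-typed parameters to `Sequence` so that callers can pass tuples, such as the `sizes` carried by the new `Extent` below, which feed `NDSlice.new_row_major`. A minimal typing-only sketch with stand-in functions rather than the real bindings:

```python
# Sketch: List-typed parameters reject tuples at type-check time, while
# Sequence-typed parameters accept any immutable sequence.
from typing import List, Sequence

def old_stub(ranks: List[int]) -> None: ...      # annotation before this diff
def new_stub(ranks: Sequence[int]) -> None: ...  # annotation after this diff

sizes = (2, 4)      # sizes often travel as tuples, e.g. Extent.sizes below
new_stub(sizes)     # OK: a tuple is a Sequence[int]
old_stub(sizes)     # mypy/pyre error under the old stub: a tuple is not a List[int]
```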
""" - p: Port[R] - r: PortReceiver[R] p, r = port(self, once=True) # pyre-ignore - send(self, args, kwargs, port=p, selection="choose") + self._send(args, kwargs, port=p, selection="choose") return r.recv() def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: - if len(self._actor_mesh) != 1: + p, r = port(self, once=True) + # pyre-ignore + extent = self._send(args, kwargs, port=p, selection="choose") + if extent.nelements != 1: raise ValueError( - f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}" + f"Can only use 'call_one' on a single Actor but this actor has shape {extent}" ) - return self.choose(*args, **kwargs) + return r.recv() def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]": p: Port[R] r: RankedPortReceiver[R] p, r = ranked_port(self) # pyre-ignore - send(self, args, kwargs, port=p) + extent = self._send(args, kwargs, port=p) async def process() -> ValueMesh[R]: - results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9] - for _ in range(len(self._actor_mesh)): + results: List[R] = [None] * extent.nelements # pyre-fixme[9] + for _ in range(extent.nelements): rank, value = await r.recv() results[rank] = value call_shape = Shape( - self._actor_mesh._shape.labels, - NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes), + extent.labels, + NDSlice.new_row_major(extent.sizes), ) return ValueMesh(call_shape, results) def process_blocking() -> ValueMesh[R]: - results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9] - for _ in range(len(self._actor_mesh)): + results: List[R] = [None] * extent.nelements # pyre-fixme[9] + for _ in range(extent.nelements): rank, value = r.recv().get() results[rank] = value call_shape = Shape( - self._actor_mesh._shape.labels, - NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes), + extent.labels, + NDSlice.new_row_major(extent.sizes), ) return ValueMesh(call_shape, results) @@ -282,8 +311,8 @@ async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R """ p, r = port(self) # pyre-ignore - send(self, args, kwargs, port=p) - for _ in range(len(self._actor_mesh)): + extent = self._send(args, kwargs, port=p) + for _ in range(extent.nelements): yield await r.recv() def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None: @@ -298,6 +327,46 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None: send(self, args, kwargs) +class ActorEndpoint(Endpoint[P, R]): + def __init__( + self, + actor_mesh_ref: _ActorMeshRefImpl, + name: str, + impl: Callable[Concatenate[Any, P], Awaitable[R]], + mailbox: Mailbox, + ) -> None: + self._actor_mesh = actor_mesh_ref + self._name = name + self._signature: inspect.Signature = inspect.signature(impl) + self._mailbox = mailbox + + def _send( + self, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + port: "Optional[Port]" = None, + selection: Selection = "all", + ) -> Extent: + """ + Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh. + + This sends the message to all actors but does not wait for any result. 
+ """ + self._signature.bind(None, *args, **kwargs) + message = PythonMessage( + self._name, + _pickle((args, kwargs)), + None if port is None else port._port_ref, + None, + ) + self._actor_mesh.cast(message, selection) + shape = self._actor_mesh._shape + return Extent(shape.labels, shape.ndslice.sizes) + + def _port(self, once: bool = False) -> "PortTuple[R]": + return PortTuple.create(self._mailbox, once) + + class Accumulator(Generic[P, R, A]): def __init__( self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A] @@ -337,10 +406,13 @@ def item(self, **kwargs) -> R: return self._values[self._ndslice.nditem(coordinates)] - def __iter__(self): + def items(self) -> Iterable[Tuple[Point, R]]: for rank in self._shape.ranks(): yield Point(rank, self._shape), self._values[rank] + def __iter__(self) -> Iterator[Tuple[Point, R]]: + return iter(self.items()) + def __len__(self) -> int: return len(self._shape) @@ -368,14 +440,7 @@ def send( This sends the message to all actors but does not wait for any result. """ - endpoint._signature.bind(None, *args, **kwargs) - message = PythonMessage( - endpoint._name, - _pickle((args, kwargs)), - None if port is None else port._port_ref, - None, - ) - endpoint._actor_mesh.cast(message, selection) + endpoint._send(args, kwargs, port, selection) class EndpointProperty(Generic[P, R]): @@ -447,7 +512,7 @@ def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": # not part of the Endpoint API because they way it accepts arguments # and handles concerns is different. def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]": - return PortTuple.create(endpoint._mailbox, once) + return endpoint._port(once) def ranked_port( @@ -676,7 +741,7 @@ def __init__( setattr( self, attr_name, - Endpoint( + ActorEndpoint( self._actor_mesh_ref, attr_name, attr_value._method, @@ -695,7 +760,7 @@ def __getattr__(self, name: str) -> Any: attr = getattr(self._class, name) if isinstance(attr, EndpointProperty): # Dynamically create the endpoint - endpoint = Endpoint( + endpoint = ActorEndpoint( self._actor_mesh_ref, name, attr._method, @@ -718,7 +783,7 @@ def _create( async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None: return None - ep = Endpoint( + ep = ActorEndpoint( self._actor_mesh_ref, "__init__", null_func, diff --git a/python/monarch/common/device_mesh.py b/python/monarch/common/device_mesh.py index 86efc9e2..4704dcd9 100644 --- a/python/monarch/common/device_mesh.py +++ b/python/monarch/common/device_mesh.py @@ -32,6 +32,7 @@ from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map +from torch.utils.weak import weakref from ._tensor_to_table import tensor_to_table from .context_manager import activate_first_context_manager @@ -170,6 +171,7 @@ def __init__( self.exit = lambda: None self.ref = None self._active_mesh_context = None + self._subset_of: Optional[weakref.ReferenceType["DeviceMesh"]] = None def define_remotely(self): if self.ref is None: @@ -227,8 +229,17 @@ def _labels(self) -> Tuple[str, ...]: def _new_with_shape(self, shape: Shape) -> "DeviceMesh": mesh = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels)) mesh.exit = self.exit + mesh._subset_of = weakref.ref(self) return mesh + def _is_subset_of(self, other: "DeviceMesh") -> bool: + p = self + while p is not None: + if p is other: + return True + p = None if p._subset_of is None else p._subset_of() + return False + def __call__(self, **kwargs) -> "DeviceMesh": """ device_mesh(batch=3) or 
diff --git a/python/monarch/common/device_mesh.py b/python/monarch/common/device_mesh.py
index 86efc9e2..4704dcd9 100644
--- a/python/monarch/common/device_mesh.py
+++ b/python/monarch/common/device_mesh.py
@@ -32,6 +32,7 @@
 from torch.utils._python_dispatch import TorchDispatchMode
 from torch.utils._pytree import tree_map
+from torch.utils.weak import weakref
 
 from ._tensor_to_table import tensor_to_table
 from .context_manager import activate_first_context_manager
@@ -170,6 +171,7 @@ def __init__(
         self.exit = lambda: None
         self.ref = None
         self._active_mesh_context = None
+        self._subset_of: Optional[weakref.ReferenceType["DeviceMesh"]] = None
 
     def define_remotely(self):
         if self.ref is None:
@@ -227,8 +229,17 @@ def _labels(self) -> Tuple[str, ...]:
     def _new_with_shape(self, shape: Shape) -> "DeviceMesh":
         mesh = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels))
         mesh.exit = self.exit
+        mesh._subset_of = weakref.ref(self)
         return mesh
 
+    def _is_subset_of(self, other: "DeviceMesh") -> bool:
+        p = self
+        while p is not None:
+            if p is other:
+                return True
+            p = None if p._subset_of is None else p._subset_of()
+        return False
+
     def __call__(self, **kwargs) -> "DeviceMesh":
         """
         device_mesh(batch=3) or device_mesh(batch=slice(3, None))
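Reviewer note (not part of the patch): `_is_subset_of` walks the `_subset_of` weak-reference chain that `_new_with_shape` now records, so a sliced mesh knows which mesh it was derived from without keeping the parent alive. A runnable sketch with a stand-in class:

```python
# Sketch of the parent-chain walk (a stand-in Node instead of DeviceMesh; the
# stdlib weakref semantics are identical). If a parent has been garbage
# collected, p._subset_of() returns None and the walk simply ends.
import weakref

class Node:
    def __init__(self, parent: "Node | None" = None) -> None:
        self._subset_of = None if parent is None else weakref.ref(parent)

    def _is_subset_of(self, other: "Node") -> bool:
        p = self
        while p is not None:
            if p is other:
                return True
            p = None if p._subset_of is None else p._subset_of()
        return False

root = Node()
child = Node(root)            # analogue of root._new_with_shape(...)
grandchild = Node(child)
assert grandchild._is_subset_of(root)       # follows two weak links upward
assert not root._is_subset_of(grandchild)   # the chain only points upward
```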
diff --git a/python/monarch/common/remote.py b/python/monarch/common/remote.py
index fa1a9919..1c19a360 100644
--- a/python/monarch/common/remote.py
+++ b/python/monarch/common/remote.py
@@ -13,6 +13,7 @@
 from typing import (
     Any,
     Callable,
+    cast,
     Dict,
     Generic,
     Literal,
@@ -27,12 +28,17 @@
 import monarch.common.messages as messages
 import torch
+from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox
+from monarch._rust_bindings.monarch_hyperactor.shape import Shape
+from monarch._src.actor.actor_mesh import Extent, Port, PortTuple, Selection
 from monarch.common import _coalescing, device_mesh, stream
+from monarch.common.future import Future as OldFuture
 
 if TYPE_CHECKING:
     from monarch.common.client import Client
 
+from monarch._src.actor.actor_mesh import Endpoint
 from monarch.common.device_mesh import RemoteProcessGroup
 from monarch.common.fake import fake_call
 
@@ -48,9 +54,9 @@
     TensorGroup,
     TensorPlaceholder,
 )
-from monarch.common.future import Future
 from monarch.common.messages import Dims
-from monarch.common.tensor import dtensor_check, dtensor_dispatch
+
+from monarch.common.tensor import dtensor_check, dtensor_dispatch, InputChecker
 from monarch.common.tree import flatten, tree_map
 from torch import autograd, distributed as dist
 from typing_extensions import ParamSpec
@@ -62,23 +68,91 @@
 T = TypeVar("T")
 
 
-class Remote(Generic[P, R]):
+class Remote(Generic[P, R], Endpoint[P, R]):
     def __init__(self, impl: Any, propagator_arg: Propagator):
         self._remote_impl = impl
         self._propagator_arg = propagator_arg
         self._cache: Optional[dict] = None
 
+    def _send(
+        self,
+        args: Tuple[Any, ...],
+        kwargs: Dict[str, Any],
+        port: "Optional[Port]" = None,
+        selection: Selection = "all",
+    ) -> Extent:
+        ambient_mesh = device_mesh._active
+        propagator = self._fetch_propagate
+        rfunction = self._maybe_resolvable
+        # a None rfunction is an optimization for the identity function (lambda x: x)
+        if rfunction is None:
+            preprocess_message = None
+            rfunction = ResolvableFunctionFromPath("ident")
+        else:
+            preprocess_message = rfunction
+        _, dtensors, mutates, tensor_mesh = dtensor_check(
+            propagator, rfunction, args, kwargs, ambient_mesh, stream._active
+        )
+
+        if ambient_mesh is None:
+            raise ValueError(
+                "Calling a 'remote' monarch function requires an active proc_mesh (`with proc_mesh.activate():`)"
+            )
+
+        if not ambient_mesh._is_subset_of(tensor_mesh):
+            raise ValueError(
+                f"The current mesh {ambient_mesh} is not a subset of the mesh on which the tensors being used are defined {tensor_mesh}"
+            )
+
+        client: "Client" = ambient_mesh.client
+        if _coalescing.is_active(client):
+            raise NotImplementedError("NYI: fetching results during a coalescing block")
+        stream_ref = stream._active._to_ref(client)
+
+        fut = (port, ambient_mesh._ndslice)
+
+        ident = client.new_node(mutates, dtensors, cast("OldFuture", fut))
+
+        client.send(
+            ambient_mesh._ndslice,
+            messages.SendValue(
+                ident,
+                None,
+                mutates,
+                preprocess_message,
+                args,
+                kwargs,
+                stream_ref,
+            ),
+        )
+        # we have to ask for status updates
+        # from workers to be sure they have finished
+        # enough work to count this future as finished,
+        # and all potential errors have been reported
+        client._request_status()
+        return Extent(ambient_mesh._labels, ambient_mesh._ndslice.sizes)
+
+    def _port(self, once: bool = False) -> "PortTuple[R]":
+        ambient_mesh = device_mesh._active
+        if ambient_mesh is None:
+            raise ValueError(
+                "FIXME - cannot create a port without an active proc_mesh, because there is no way to create a port without a mailbox"
+            )
+        mesh_controller = getattr(ambient_mesh.client, "_mesh_controller", None)
+        if mesh_controller is None:
+            raise ValueError(
+                "Cannot create raw port objects with an old-style tensor engine controller."
+            )
+        mailbox: Mailbox = mesh_controller._mailbox
+        return PortTuple.create(mailbox, once)
+
     @property
     def _resolvable(self):
         return resolvable_function(self._remote_impl)
 
     @property
     def _maybe_resolvable(self):
-        return (
-            None
-            if self._remote_impl is None
-            else resolvable_function(self._remote_impl)
-        )
+        return None if self._remote_impl is None else self._resolvable
 
     def _propagate(self, args, kwargs, fake_args, fake_kwargs):
         if self._propagator_arg is None or self._propagator_arg == "cached":
@@ -102,7 +176,7 @@ def _pipe_propagate(self, args, kwargs, fake_args, fake_kwargs):
             raise ValueError("Must specify explicit callable for pipe")
         return self._propagate(args, kwargs, fake_args, fake_kwargs)
 
-    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+    def rref(self, *args: P.args, **kwargs: P.kwargs) -> R:
         return dtensor_dispatch(
             self._resolvable,
             self._propagate,
@@ -112,6 +186,9 @@
             stream._active,
         )
 
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
+        return self.rref(*args, **kwargs)
+
 
 # This can't just be Callable because otherwise we are not
 # allowed to use type arguments in the return value.
@@ -153,12 +230,39 @@ def remote(function: Any = None, *, propagate: Propagator = None) -> Any:
 
 
 def call_on_shard_and_fetch(
+    remote: Remote[P, R], *args, shard: Dict[str, int] | None = None, **kwargs
+) -> OldFuture[R]:
+    # We have to flatten the tensors twice: first to discover
+    # which mesh we are working on to shard it, and then again when doing the
+    # dtensor_check in send. This complexity is a consequence of doing
+    # implicit inference of the mesh from the tensors.
+    dtensors, unflatten = flatten((args, kwargs), lambda x: isinstance(x, torch.Tensor))
+    with InputChecker.from_flat_args(
+        remote._remote_impl, dtensors, unflatten
+    ) as checker:
+        checker.check_mesh_stream_local(device_mesh._active, stream._active)
+
+    if not hasattr(checker.mesh.client, "_mesh_controller"):
+        return _old_call_on_shard_and_fetch(
+            remote,
+            *args,
+            shard=shard,
+            **kwargs,
+        )
+
+    selected_slice = checker.mesh._process(shard)
+    shard_mesh = checker.mesh._new_with_shape(Shape(["_"], selected_slice))
+    with shard_mesh.activate():
+        return cast("OldFuture[R]", remote.call_one(*args, **kwargs))
+
+
+def _old_call_on_shard_and_fetch(
     remote_obj: Remote[P, R],
     /,
     *args: object,
     shard: dict[str, int] | None = None,
     **kwargs: object,
-) -> Future[R]:
+) -> OldFuture[R]:
     """
     Call `function` at the coordinates `shard` of the current device mesh, and retrieve the result as a Future.
     function - the remote function to call
diff --git a/python/monarch/fetch.py b/python/monarch/fetch.py
index 266e96af..d09a75b8 100644
--- a/python/monarch/fetch.py
+++ b/python/monarch/fetch.py
@@ -9,7 +9,7 @@
 This is a utility file for fetching a shard of a tensor from remote.
 """
 
-from typing import TypeVar
+from typing import cast, TypeVar
 
 from monarch.common.device_mesh import no_mesh
 
@@ -37,7 +37,7 @@ def fetch_shard(
             shard = {}
         shard.update(kwargs)
 
-    return call_on_shard_and_fetch(remote_identity, obj, shard=shard)
+    return cast("Future[T]", call_on_shard_and_fetch(remote_identity, obj, shard=shard))
 
 
 def show(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> object:
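Reviewer note (not part of the patch): with `Remote` now implementing `Endpoint`, the fetch path funnels through `Remote.call_one` on a single-rank sub-mesh whenever a mesh controller is present, while `fetch_shard` keeps its old interface. A usage sketch, with mesh construction elided and `mesh` a hypothetical active `DeviceMesh` with "host" and "gpu" dimensions:

```python
# Hypothetical usage; assumes an already-constructed, controller-backed mesh.
# fetch_shard folds its keyword dims into `shard` and now routes through
# call_on_shard_and_fetch -> Remote.call_one on a single-rank sub-mesh.
import torch
from monarch.fetch import fetch_shard

with mesh.activate():
    t = torch.rand(3, 4)                            # tensor defined on the mesh
    local = fetch_shard(t, host=0, gpu=0).result()  # one shard, as a local tensor
```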
""" -from typing import TypeVar +from typing import cast, TypeVar from monarch.common.device_mesh import no_mesh @@ -37,7 +37,7 @@ def fetch_shard( shard = {} shard.update(kwargs) - return call_on_shard_and_fetch(remote_identity, obj, shard=shard) + return cast("Future[T]", call_on_shard_and_fetch(remote_identity, obj, shard=shard)) def show(obj: T, shard: dict[str, int] | None = None, **kwargs: int) -> object: diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index b36e37a9..2b7973d3 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -204,8 +204,7 @@ def new_node_nocoalesce( response_port = None if future is not None: # method annotation is a lie to make Client happy - port = cast("Port[Any]", future) - slice = NDSlice.new_row_major([]) + port, slice = cast("Tuple[Port[Any], NDSlice]", future) response_port = (port._port_ref.port_id, slice) self._mesh_controller.node(seq, defs, uses, response_port, tracebacks) return seq diff --git a/python/tests/test_remote_functions.py b/python/tests/test_remote_functions.py index 3f40d0dd..42adfd82 100644 --- a/python/tests/test_remote_functions.py +++ b/python/tests/test_remote_functions.py @@ -9,7 +9,6 @@ import math import sys import traceback -from enum import Enum from typing import Callable, ContextManager, Tuple from unittest.mock import patch @@ -1277,3 +1276,24 @@ def a_function_called_by_a_live_function(x): def a_live_function_call_by_a_live_function(x): return 3 * x + + +@remote +def return_them(x: torch.Tensor, y: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return (x, y) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +class TestMeshSpecific(RemoteFunctionsTestBase): + def test_value_mesh(self): + with self.local_device_mesh(2, 2, "mesh") as device_mesh: + x = device_mesh.rank("host") + y = device_mesh.rank("gpu") + r = return_them.call(x, y).get() + + for p, (h, g) in r: + assert p["host"] == h.item() + assert p["gpu"] == g.item()