Skip to content

[13/n] tensor engine, free functions implement call/broadcast/stream/etc., temp #478

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: gh/zdevito/28/base
Choose a base branch
from
6 changes: 3 additions & 3 deletions python/monarch/_rust_bindings/monarch_hyperactor/shape.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
# LICENSE file in the root directory of this source tree.

import collections.abc
from typing import Dict, final, Iterator, List, overload
from typing import Dict, final, Iterator, List, overload, Sequence

@final
class Slice:
Expand Down Expand Up @@ -90,7 +90,7 @@ class Slice:
...

@staticmethod
def new_row_major(ranks: list[int]) -> "Slice":
def new_row_major(ranks: Sequence[int]) -> "Slice":
"""Returns a contiguous slice composed of ranks"""
...

Expand All @@ -106,7 +106,7 @@ class Shape:
- `labels`: A list of strings representing the labels for each dimension.
- `slice`: An Slice object representing the shape.
"""
def __new__(cls, labels: List[str], slice: Slice) -> "Shape": ...
def __new__(cls, labels: Sequence[str], slice: Slice) -> "Shape": ...
@property
def ndslice(self) -> Slice: ...
@property
Expand Down
147 changes: 106 additions & 41 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import sys
import traceback

from abc import ABC, abstractmethod

from dataclasses import dataclass
from operator import mul
from traceback import extract_tb, StackSummary
from typing import (
Any,
Expand All @@ -27,11 +30,13 @@
Dict,
Generic,
Iterable,
Iterator,
List,
Literal,
NamedTuple,
Optional,
ParamSpec,
Sequence,
Tuple,
Type,
TYPE_CHECKING,
Expand Down Expand Up @@ -204,18 +209,41 @@ def __len__(self) -> int:
return len(self._shape)


class Endpoint(Generic[P, R]):
def __init__(
class Extent(NamedTuple):
    """The labeled, multi-dimensional size of a group of actors.

    Pairs each dimension label with its size. The total number of
    elements is the product of all dimension sizes.
    """

    labels: Sequence[str]
    sizes: Sequence[int]

    @property
    def nelements(self) -> int:
        """Product of all dimension sizes; 1 for a zero-dimensional extent."""
        total = 1
        for dim_size in self.sizes:
            total *= dim_size
        return total

    def __str__(self) -> str:
        """Render as a ``{label: size}`` mapping, e.g. ``{'gpus': 8}``."""
        return str({label: size for label, size in zip(self.labels, self.sizes)})


class Endpoint(ABC, Generic[P, R]):
@abstractmethod
def _send(
self,
actor_mesh_ref: _ActorMeshRefImpl,
name: str,
impl: Callable[Concatenate[Any, P], Awaitable[R]],
mailbox: Mailbox,
) -> None:
self._actor_mesh = actor_mesh_ref
self._name = name
self._signature: inspect.Signature = inspect.signature(impl)
self._mailbox = mailbox
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
port: "Optional[Port]" = None,
selection: Selection = "all",
) -> Extent:
"""
Implements sending a message to the endpoint. The return value of the endpoint will
be sent to port if provided. If port is not provided, the return will be dropped,
and any exception will cause the actor to fail.

The return value is the (multi-dimension) size of the actors that were sent a message.
For ActorEndpoints this will be the actor_meshes size. For free-function endpoints,
this will be the size of the currently active proc_mesh.
"""
pass

@abstractmethod
def _port(self, once: bool = False) -> "PortTuple[R]":
pass

# the following are all 'adverbs' or different ways to handle the
# return values of this endpoint. Adverbs should only ever take *args, **kwargs
Expand All @@ -228,46 +256,47 @@ def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:

Load balanced RPC-style entrypoint for request/response messaging.
"""
p: Port[R]
r: PortReceiver[R]
p, r = port(self, once=True)
# pyre-ignore
send(self, args, kwargs, port=p, selection="choose")
self._send(args, kwargs, port=p, selection="choose")
return r.recv()

def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
if len(self._actor_mesh) != 1:
p, r = port(self, once=True)
# pyre-ignore
extent = self._send(args, kwargs, port=p, selection="choose")
if extent.nelements != 1:
raise ValueError(
f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}"
f"Can only use 'call_one' on a single Actor but this actor has shape {extent}"
)
return self.choose(*args, **kwargs)
return r.recv()

def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]":
p: Port[R]
r: RankedPortReceiver[R]
p, r = ranked_port(self)
# pyre-ignore
send(self, args, kwargs, port=p)
extent = self._send(args, kwargs, port=p)

async def process() -> ValueMesh[R]:
results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
for _ in range(len(self._actor_mesh)):
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
for _ in range(extent.nelements):
rank, value = await r.recv()
results[rank] = value
call_shape = Shape(
self._actor_mesh._shape.labels,
NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
extent.labels,
NDSlice.new_row_major(extent.sizes),
)
return ValueMesh(call_shape, results)

def process_blocking() -> ValueMesh[R]:
results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9]
for _ in range(len(self._actor_mesh)):
results: List[R] = [None] * extent.nelements # pyre-fixme[9]
for _ in range(extent.nelements):
rank, value = r.recv().get()
results[rank] = value
call_shape = Shape(
self._actor_mesh._shape.labels,
NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes),
extent.labels,
NDSlice.new_row_major(extent.sizes),
)
return ValueMesh(call_shape, results)

Expand All @@ -282,8 +311,8 @@ async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R
"""
p, r = port(self)
# pyre-ignore
send(self, args, kwargs, port=p)
for _ in range(len(self._actor_mesh)):
extent = self._send(args, kwargs, port=p)
for _ in range(extent.nelements):
yield await r.recv()

def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
Expand All @@ -298,6 +327,46 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
send(self, args, kwargs)


class ActorEndpoint(Endpoint[P, R]):
    """Endpoint bound to a concrete actor mesh: messages are cast to the
    mesh's actors via the mesh's mailbox."""

    def __init__(
        self,
        actor_mesh_ref: _ActorMeshRefImpl,
        name: str,
        impl: Callable[Concatenate[Any, P], Awaitable[R]],
        mailbox: Mailbox,
    ) -> None:
        # Mesh of actors this endpoint casts to.
        self._actor_mesh = actor_mesh_ref
        # Method name used to route the message on the receiving actor.
        self._name = name
        # Signature of the Python implementation; used only to validate
        # arguments client-side before pickling/sending.
        self._signature: inspect.Signature = inspect.signature(impl)
        self._mailbox = mailbox

    def _send(
        self,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        port: "Optional[Port]" = None,
        selection: Selection = "all",
    ) -> Extent:
        """
        Cast an invocation of this endpoint to the actors chosen by
        ``selection`` without waiting for results.

        If ``port`` is provided, each actor's return value is delivered to
        it; otherwise returns are dropped. Returns the Extent (dimension
        labels and sizes) of the actor mesh that was messaged.
        """
        # Validate arguments against the endpoint's signature before
        # pickling; None stands in for the receiving actor's `self`.
        self._signature.bind(None, *args, **kwargs)
        message = PythonMessage(
            self._name,
            _pickle((args, kwargs)),
            None if port is None else port._port_ref,
            None,
        )
        self._actor_mesh.cast(message, selection)
        shape = self._actor_mesh._shape
        return Extent(shape.labels, shape.ndslice.sizes)

    def _port(self, once: bool = False) -> "PortTuple[R]":
        # Replies are received on a port created from this endpoint's mailbox.
        return PortTuple.create(self._mailbox, once)


class Accumulator(Generic[P, R, A]):
def __init__(
self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
Expand Down Expand Up @@ -337,10 +406,13 @@ def item(self, **kwargs) -> R:

return self._values[self._ndslice.nditem(coordinates)]

def __iter__(self):
    def items(self) -> Iterable[Tuple[Point, R]]:
        """Yield ``(Point, value)`` pairs for every rank in the mesh, in rank order."""
        for rank in self._shape.ranks():
            yield Point(rank, self._shape), self._values[rank]

    def __iter__(self) -> Iterator[Tuple[Point, R]]:
        """Iterate ``(Point, value)`` pairs; equivalent to ``iter(self.items())``."""
        return iter(self.items())

def __len__(self) -> int:
return len(self._shape)

Expand Down Expand Up @@ -368,14 +440,7 @@ def send(

This sends the message to all actors but does not wait for any result.
"""
endpoint._signature.bind(None, *args, **kwargs)
message = PythonMessage(
endpoint._name,
_pickle((args, kwargs)),
None if port is None else port._port_ref,
None,
)
endpoint._actor_mesh.cast(message, selection)
endpoint._send(args, kwargs, port, selection)


class EndpointProperty(Generic[P, R]):
Expand Down Expand Up @@ -447,7 +512,7 @@ def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]":
# not part of the Endpoint API because they way it accepts arguments
# and handles concerns is different.
def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]":
return PortTuple.create(endpoint._mailbox, once)
return endpoint._port(once)


def ranked_port(
Expand Down Expand Up @@ -676,7 +741,7 @@ def __init__(
setattr(
self,
attr_name,
Endpoint(
ActorEndpoint(
self._actor_mesh_ref,
attr_name,
attr_value._method,
Expand All @@ -695,7 +760,7 @@ def __getattr__(self, name: str) -> Any:
attr = getattr(self._class, name)
if isinstance(attr, EndpointProperty):
# Dynamically create the endpoint
endpoint = Endpoint(
endpoint = ActorEndpoint(
self._actor_mesh_ref,
name,
attr._method,
Expand All @@ -718,7 +783,7 @@ def _create(
async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None:
return None

ep = Endpoint(
ep = ActorEndpoint(
self._actor_mesh_ref,
"__init__",
null_func,
Expand Down
11 changes: 11 additions & 0 deletions python/monarch/common/device_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map
from torch.utils.weak import weakref

from ._tensor_to_table import tensor_to_table
from .context_manager import activate_first_context_manager
Expand Down Expand Up @@ -170,6 +171,7 @@ def __init__(
self.exit = lambda: None
self.ref = None
self._active_mesh_context = None
self._subset_of: Optional[weakref.ReferenceType["DeviceMesh"]] = None

def define_remotely(self):
if self.ref is None:
Expand Down Expand Up @@ -227,8 +229,17 @@ def _labels(self) -> Tuple[str, ...]:
    def _new_with_shape(self, shape: Shape) -> "DeviceMesh":
        """Create a view of this mesh restricted to ``shape``.

        The new mesh reuses this mesh's client and exit hook, and records
        this mesh as its parent so subset relationships can be traced.
        """
        mesh = DeviceMesh(self.client, shape.ndslice, tuple(shape.labels))
        mesh.exit = self.exit
        # Weak back-reference so subset views don't keep the parent alive.
        # NOTE(review): relies on `weakref` imported as
        # `from torch.utils.weak import weakref` — confirm that module
        # re-exports the stdlib `weakref`.
        mesh._subset_of = weakref.ref(self)
        return mesh

def _is_subset_of(self, other: "DeviceMesh") -> bool:
p = self
while p is not None:
if p is other:
return True
p = None if p._subset_of is None else p._subset_of()
return False

def __call__(self, **kwargs) -> "DeviceMesh":
"""
device_mesh(batch=3) or device_mesh(batch=slice(3, None))
Expand Down
Loading