Skip to content

Adds monarch.actor_mesh.stack #396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
213 changes: 210 additions & 3 deletions python/monarch/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
import sys
import traceback

from abc import ABC, abstractmethod

from dataclasses import dataclass
from traceback import extract_tb, StackSummary
from typing import (
Expand Down Expand Up @@ -204,7 +206,41 @@ def __len__(self) -> int:
return len(self._shape)


class Endpoint(Generic[P, R]):
class EndpointInterface(Generic[P, R], ABC):
"""
Interface defining the common methods that all endpoint types must implement.
This ensures consistent behavior between regular endpoints and stacked endpoints.
"""

@abstractmethod
def choose(self, *args, **kwargs) -> Future[R]:
"""Load balanced sends a message to one chosen actor and awaits a result."""
pass

@abstractmethod
def call(self, *args, **kwargs) -> "Future":
"""Sends a message to all actors and collects all results."""
pass

@abstractmethod
def stream(self, *args, **kwargs) -> AsyncGenerator[R, R]:
"""Broadcasts to all actors and yields their responses as a stream.

Note: This method is intended to be implemented as an async method in concrete
classes, but the abstract definition must omit the 'async' keyword to avoid
type checking errors @abstractmethod.
Implementations should use `async def stream` in their signatures.

"""
pass

@abstractmethod
def broadcast(self, *args, **kwargs) -> None:
"""Fire-and-forget broadcast to all actors."""
pass


class Endpoint(Generic[P, R], EndpointInterface[P, R]):
def __init__(
self,
actor_mesh_ref: _ActorMeshRefImpl,
Expand Down Expand Up @@ -298,11 +334,61 @@ def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
send(self, args, kwargs)


class StackedEndpoint(Generic[P, R], EndpointInterface[P, R]):
"""
A class that represents a collection of endpoints stacked together.

This class allows for operations to be performed across multiple endpoints
as if they were a single entity.

This provides the same interface as the Endpoint class.

"""

def __init__(self, endpoints: list[Endpoint]) -> None:
self.endpoints = endpoints

def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
"""Load balanced sends a message to one chosen actor from all stacked actors."""
idx = _load_balancing_seed.randrange(len(self.endpoints))
endpoint = self.endpoints[idx]
return endpoint.choose(*args, **kwargs)

def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[list[ValueMesh[R]]]":
"""Sends a message to all actors in all stacked endpoints and collects results."""
futures = [endpoint.call(*args, **kwargs) for endpoint in self.endpoints]

async def process() -> list[ValueMesh[R]]:
results = []
for future in futures:
results.append(await future)
return results

def process_blocking() -> list[ValueMesh[R]]:
return [future.get() for future in futures]

return Future(process, process_blocking)

async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
"""Broadcasts to all actors in all stacked endpoints and yields responses as a stream."""
for endpoint in self.endpoints:
async for result in endpoint.stream(*args, **kwargs):
yield result

def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
"""Fire-and-forget broadcast to all actors in all stacked endpoints."""
for endpoint in self.endpoints:
endpoint.broadcast(*args, **kwargs)


class Accumulator(Generic[P, R, A]):
def __init__(
self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A]
self,
endpoint: EndpointInterface[P, R],
identity: A,
combine: Callable[[A, R], A],
) -> None:
self._endpoint: Endpoint[P, R] = endpoint
self._endpoint: EndpointInterface[P, R] = endpoint
self._identity: A = identity
self._combine: Callable[[A, R], A] = combine

Expand Down Expand Up @@ -772,3 +858,124 @@ def current_rank() -> Point:
def current_size() -> Dict[str, int]:
ctx = MonarchContext.get()
return dict(zip(ctx.point.shape.labels, ctx.point.shape.ndslice.sizes))


class StackedActorMeshRef(MeshTrait, Generic[T]):
def __init__(self, *actors: ActorMeshRef[T], interface=None) -> None:
self._actors = actors
self._interface = interface

# Create endpoints for all methods in the interface
for attr_name in dir(interface):
attr_value = getattr(interface, attr_name, None)
if isinstance(attr_value, EndpointProperty):
# Get the corresponding endpoint from each mesh
endpoints = []
for mesh in self._actors:
if hasattr(mesh, attr_name):
endpoints.append(getattr(mesh, attr_name))

# Create a stacked endpoint with all the collected endpoints
if endpoints:
setattr(self, attr_name, StackedEndpoint(endpoints))

@property
def _ndslice(self) -> NDSlice:
raise NotImplementedError(
"actor implementations are not meshes, but we can't convince the typechecker of it..."
)

@property
def _labels(self) -> Tuple[str, ...]:
raise NotImplementedError(
"actor implementations are not meshes, but we can't convince the typechecker of it..."
)

def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
raise NotImplementedError(
"actor implementations are not meshes, but we can't convince the typechecker of it..."
)


def _common_ancestor(*actors: ActorMeshRef[T]) -> Optional[Type]:
"""Finds the common ancestor class of a list of actor mesh references.

This determines the most specific common base class shared by all
provided actors.

Args:
*actors: Variable number of ActorMeshRef instances to analyze

Returns:
Optional[Type]: The most specific common ancestor class, or None if no
actors were provided or no common ancestor exists

Example:
```python
# Find common ancestor of two counter actors
counter_a = proc.spawn("counter_a", CounterA, 0).get()
counter_b = proc.spawn("counter_b", CounterB, 0).get()
common_class = _common_ancestor(counter_a, counter_b) # Returns Counter
```
"""
if not actors:
return None
base_classes = [obj._class for obj in actors]
all_mros = [inspect.getmro(cls) for cls in base_classes]
common_bases = set(all_mros[0]).intersection(*all_mros[1:])
if common_bases:
return min(
common_bases, key=lambda cls: min(mro.index(cls) for mro in all_mros)
)
return None


def stack(
*actors: ActorMeshRef[T], interface: Optional[Type] = None
) -> StackedActorMeshRef[T]:
"""Stacks multiple actor mesh references into a single unified interface.

This allows you to combine multiple actors that share a common interface
into a single object that can be used to interact with all of them simultaneously.
When methods are called on the stacked actor, they are distributed to all
underlying actors according to the endpoint's behavior (choose, call, stream, etc).

Args:
*actors: Variable number of ActorMeshRef instances to stack together
interface: Optional class that defines the interface to expose. If not provided,
the common ancestor class of all actors will be used.

Returns:
StackedActorMeshRef: A reference that provides access to all stacked actors
through a unified interface.

Raises:
TypeError: If any of the provided objects is not an ActorMeshRef, or if
no common ancestor can be found and no interface is provided.

Example:
```python
# Stack two counter actors together
counter1 = proc1.spawn("counter1", Counter, 0).get()
counter2 = proc2.spawn("counter2", Counter, 0).get()
stacked = stack(counter1, counter2)

# Call methods on all actors at once
stacked.incr.broadcast() # Increments both counters
```

"""
for actor in actors:
if not isinstance(actor, ActorMeshRef):
raise TypeError(
"stack be provided with Monarch Actors, got {}".format(type(actor))
)
if interface is None:
interface = _common_ancestor(*actors)

if interface is None or interface == Actor:
raise TypeError(
"No common ancestor found for the given actors. Please provide an interface explicitly."
)
logging.debug("Stacking actors %s with interface %s", actors, interface)
return StackedActorMeshRef(*actors, interface=interface)
Loading