# ---- python/monarch/actor_mesh.py ----
import sys
import traceback

from abc import ABC, abstractmethod

from dataclasses import dataclass
from traceback import extract_tb, StackSummary


class EndpointInterface(Generic[P, R], ABC):
    """
    Abstract interface listing the messaging methods every endpoint type exposes.

    Both the regular ``Endpoint`` and ``StackedEndpoint`` derive from this class
    so that callers can use either one through the same four methods.
    """

    @abstractmethod
    def choose(self, *args, **kwargs) -> Future[R]:
        """Load balanced sends a message to one chosen actor and awaits a result."""

    @abstractmethod
    def call(self, *args, **kwargs) -> "Future":
        """Sends a message to all actors and collects all results."""

    @abstractmethod
    def stream(self, *args, **kwargs) -> AsyncGenerator[R, R]:
        """Broadcasts to all actors and yields their responses as a stream.

        Note: concrete classes are expected to implement this as
        ``async def stream`` (an async generator). The abstract declaration
        deliberately omits the ``async`` keyword to avoid type checking errors
        when combined with ``@abstractmethod``.
        """

    @abstractmethod
    def broadcast(self, *args, **kwargs) -> None:
        """Fire-and-forget broadcast to all actors."""


# NOTE: ``Endpoint`` (its body is outside this excerpt) is declared as
#     class Endpoint(Generic[P, R], EndpointInterface[P, R])
# i.e. the regular endpoint implements the interface above.


class StackedEndpoint(Generic[P, R], EndpointInterface[P, R]):
    """
    A collection of endpoints stacked together.

    Operations are fanned out across the wrapped endpoints as if they were a
    single entity, through the same interface as ``Endpoint``.
    """

    def __init__(self, endpoints: list[Endpoint]) -> None:
        # Take a copy so that later mutation of the caller's list cannot
        # silently change which endpoints this object fans out to.
        self.endpoints = list(endpoints)

    def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]:
        """Load balanced sends a message to one chosen actor from all stacked actors.

        One of the stacked endpoints is picked at random and its own
        ``choose`` is invoked with the given arguments.

        Raises:
            ValueError: if this stacked endpoint wraps no endpoints.
        """
        if not self.endpoints:
            # randrange(0) would raise ValueError with an opaque message; keep
            # the exception type but say what actually went wrong.
            raise ValueError("cannot choose from a StackedEndpoint with no endpoints")
        idx = _load_balancing_seed.randrange(len(self.endpoints))
        return self.endpoints[idx].choose(*args, **kwargs)

    def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[list[ValueMesh[R]]]":
        """Sends a message to all actors in all stacked endpoints and collects results.

        Returns a future resolving to a list with one entry per stacked
        endpoint, in the same order as ``self.endpoints``.
        """
        # Issue every call up front; the futures are only awaited afterwards.
        futures = [endpoint.call(*args, **kwargs) for endpoint in self.endpoints]

        async def process() -> list[ValueMesh[R]]:
            return [await future for future in futures]

        def process_blocking() -> list[ValueMesh[R]]:
            return [future.get() for future in futures]

        return Future(process, process_blocking)

    async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]:
        """Broadcasts to all actors in all stacked endpoints and yields responses as a stream.

        Endpoints are drained one after another, in ``self.endpoints`` order.
        """
        for endpoint in self.endpoints:
            async for result in endpoint.stream(*args, **kwargs):
                yield result

    def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None:
        """Fire-and-forget broadcast to all actors in all stacked endpoints."""
        for endpoint in self.endpoints:
            endpoint.broadcast(*args, **kwargs)


# NOTE: ``Accumulator.__init__`` (rest of the class is outside this excerpt)
# takes ``endpoint: EndpointInterface[P, R]`` and stores it on
# ``self._endpoint``, so it accepts a regular or a stacked endpoint.
# ---- python/monarch/actor_mesh.py (continued) ----
def current_size() -> Dict[str, int]:
    """Return a dict pairing each shape label of the current context's point with its size."""
    ctx = MonarchContext.get()
    return dict(zip(ctx.point.shape.labels, ctx.point.shape.ndslice.sizes))


class StackedActorMeshRef(MeshTrait, Generic[T]):
    """Reference to several actor meshes exposed through one shared interface.

    For every ``EndpointProperty`` found on ``interface``, an attribute of the
    same name is set on this object holding a ``StackedEndpoint`` over the
    matching endpoints of the stacked actors.
    """

    def __init__(self, *actors: ActorMeshRef[T], interface=None) -> None:
        self._actors = actors
        self._interface = interface

        # Create endpoints for all methods in the interface
        for attr_name in dir(interface):
            if not isinstance(getattr(interface, attr_name, None), EndpointProperty):
                continue
            # Get the corresponding endpoint from each mesh; meshes that do
            # not expose the attribute are skipped.
            endpoints = [
                getattr(mesh, attr_name)
                for mesh in self._actors
                if hasattr(mesh, attr_name)
            ]
            # Create a stacked endpoint with all the collected endpoints
            if endpoints:
                setattr(self, attr_name, StackedEndpoint(endpoints))

    @property
    def _ndslice(self) -> NDSlice:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    @property
    def _labels(self) -> Tuple[str, ...]:
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )

    def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
        raise NotImplementedError(
            "actor implementations are not meshes, but we can't convince the typechecker of it..."
        )


def _common_ancestor(*actors: ActorMeshRef[T]) -> Optional[Type]:
    """Finds the common ancestor class of a list of actor mesh references.

    This determines the most specific common base class shared by all
    provided actors, based on the MRO of each actor's ``_class``.

    Args:
        *actors: Variable number of ActorMeshRef instances to analyze

    Returns:
        Optional[Type]: The most specific common ancestor class, or None if no
            actors were provided or no common ancestor exists

    Example:
        ```python
        # Find common ancestor of two counter actors
        counter_a = proc.spawn("counter_a", CounterA, 0).get()
        counter_b = proc.spawn("counter_b", CounterB, 0).get()
        common_class = _common_ancestor(counter_a, counter_b)  # Returns Counter
        ```
    """
    if not actors:
        return None
    all_mros = [inspect.getmro(actor._class) for actor in actors]
    first_mro = all_mros[0]
    common_bases = set(first_mro).intersection(*all_mros[1:])
    if not common_bases:
        return None

    def specificity(cls: Type) -> Tuple[int, int]:
        # Primary key: earliest position of the class in any actor's MRO.
        # Secondary key: position in the first actor's MRO. Every common base
        # is in that MRO at a unique index, so ties between unrelated bases
        # (possible with multiple inheritance) no longer depend on the
        # arbitrary iteration order of the set.
        return (min(mro.index(cls) for mro in all_mros), first_mro.index(cls))

    return min(common_bases, key=specificity)


def stack(
    *actors: ActorMeshRef[T], interface: Optional[Type] = None
) -> StackedActorMeshRef[T]:
    """Stacks multiple actor mesh references into a single unified interface.

    This allows you to combine multiple actors that share a common interface
    into a single object that can be used to interact with all of them simultaneously.
    When methods are called on the stacked actor, they are distributed to all
    underlying actors according to the endpoint's behavior (choose, call, stream, etc).

    Args:
        *actors: Variable number of ActorMeshRef instances to stack together
        interface: Optional class that defines the interface to expose. If not provided,
            the common ancestor class of all actors will be used.

    Returns:
        StackedActorMeshRef: A reference that provides access to all stacked actors
            through a unified interface.

    Raises:
        TypeError: If any of the provided objects is not an ActorMeshRef, or if
            the resolved interface is missing or is the bare ``Actor`` class
            (whether inferred or passed explicitly).

    Example:
        ```python
        # Stack two counter actors together
        counter1 = proc1.spawn("counter1", Counter, 0).get()
        counter2 = proc2.spawn("counter2", Counter, 0).get()
        stacked = stack(counter1, counter2)

        # Call methods on all actors at once
        stacked.incr.broadcast()  # Increments both counters
        ```

    """
    for actor in actors:
        if not isinstance(actor, ActorMeshRef):
            raise TypeError(
                f"stack must be provided with Monarch Actors, got {type(actor)}"
            )
    if interface is None:
        interface = _common_ancestor(*actors)

    # The bare Actor base class is rejected as an interface.
    if interface is None or interface is Actor:
        raise TypeError(
            "No common ancestor found for the given actors. Please provide an interface explicitly."
        )
    logging.debug("Stacking actors %s with interface %s", actors, interface)
    return StackedActorMeshRef(*actors, interface=interface)


# ---- python/tests/test_python_actors.py ----
import torch
from monarch.actor_mesh import (
    _common_ancestor,
    Accumulator,
    Actor,
    current_actor_name,
    current_rank,
    current_size,
    endpoint,
    EndpointInterface,
    MonarchContext,
    StackedActorMeshRef,
)
from monarch.debugger import init_debugging
from monarch.future import ActorFuture


class CounterA(Counter):
    """Counter whose ``incr`` adds ``step`` and which also offers ``decr``."""

    def __init__(self, v: int, step: int = 1):
        super().__init__(v)
        self.step = step

    @endpoint
    async def decr(self):
        """Decrement the counter by step.
        This is a function that is unique to CounterA.
        """
        self.v -= self.step

    @endpoint
    async def incr(self):
        self.v += self.step


class CounterB(Counter):
    """Counter whose ``incr`` multiplies by ``multiplier`` and which also offers ``reset``."""

    def __init__(self, v: int, multiplier: int = 2):
        super().__init__(v)
        self.multiplier = multiplier

    @endpoint
    async def reset(self):
        """Reset the counter to 0.
        This is a function that is unique to CounterB.
        """
        self.v = 0

    @endpoint
    async def incr(self):
        self.v *= self.multiplier


class CounterC(Actor):
    """Counter-like actor that derives from ``Actor`` directly, not from ``Counter``."""

    def __init__(self, v: int, increment: int = 1):
        self.v = v
        self.increment = increment

    @endpoint
    async def incr(self):
        self.v += self.increment

    @endpoint
    async def value(self) -> int:
        return self.v

    @endpoint
    async def double(self):
        self.v *= 2


class CounterD(CounterA):
    """CounterA subclass adding ``multiply`` and overriding ``decr``."""

    def __init__(self, v: int, step: int = 1, factor: int = 3):
        super().__init__(v, step)
        self.factor = factor

    @endpoint
    async def multiply(self):
        """Multiply the counter by factor.
        This is a function that is unique to CounterD.
        """
        self.v *= self.factor

    @endpoint
    async def decr(self):
        """Override decr to decrement by step * factor."""
        self.v -= self.step * self.factor


def _assert_endpoints_implement_interface(regular_endpoint, stacked_endpoint) -> None:
    """Assert both endpoints expose every public callable of EndpointInterface."""
    public_methods = [
        name
        for name in dir(EndpointInterface)
        if not name.startswith("_") and callable(getattr(EndpointInterface, name))
    ]
    for method_name in public_methods:
        assert hasattr(regular_endpoint, method_name), f"Endpoint missing {method_name}"
        assert hasattr(
            stacked_endpoint, method_name
        ), f"StackedEndpoint missing {method_name}"


def test_common_ancestor():
    proc = local_proc_mesh(gpus=1).get()

    # Test with same class
    counter1 = proc.spawn("counter1", Counter, 0).get()
    counter2 = proc.spawn("counter2", Counter, 0).get()
    assert _common_ancestor(counter1, counter2) == Counter

    # Test with parent-child relationship
    counter_a = proc.spawn("counter_a", CounterA, 0).get()
    assert _common_ancestor(counter1, counter_a) == Counter

    # Test with siblings
    counter_b = proc.spawn("counter_b", CounterB, 0).get()
    assert _common_ancestor(counter_a, counter_b) == Counter

    # Test with unrelated classes
    counter_c = proc.spawn("counter_c", CounterC, 0).get()
    assert _common_ancestor(counter_a, counter_c) == Actor

    # Test with empty list
    assert _common_ancestor() is None

    # Test with mixed hierarchy
    assert _common_ancestor(counter1, counter_a, counter_b) == Counter
    assert _common_ancestor(counter_a, counter_b, counter_c) == Actor


def test_identical_actor_stack():
    proc1 = local_proc_mesh(gpus=1).get()
    proc2 = local_proc_mesh(gpus=1).get()

    counter1 = proc1.spawn("counter1", Counter, 0).get()
    counter2 = proc2.spawn("counter2", Counter, 0).get()

    stacked = monarch.actor_mesh.stack(counter1, counter2)
    assert stacked is not None
    assert isinstance(stacked, StackedActorMeshRef)

    stacked.incr.call().get()
    assert counter1.value.choose().get() == 1
    assert counter2.value.choose().get() == 1


def test_heterogeneous_actor_stack():
    """Test stacking actors of different types that share a common ancestor."""
    proc = local_proc_mesh(gpus=1).get()

    # Create different types of counters
    counter = proc.spawn("counter", Counter, 0).get()
    counter_a = proc.spawn("counter_a", CounterA, 0).get()
    counter_b = proc.spawn("counter_b", CounterB, 0).get()

    # Stack them together - they should use Counter as the common interface
    stacked = monarch.actor_mesh.stack(counter, counter_a, counter_b)

    # Verify the stacked actor has the common endpoints
    assert hasattr(stacked, "incr")
    assert hasattr(stacked, "value")

    # Verify unique endpoints are not accessible on the stacked actor
    assert not hasattr(stacked, "decr")  # CounterA specific
    assert not hasattr(stacked, "reset")  # CounterB specific

    # Test that the common endpoints work
    stacked.incr.call().get()

    # Verify each actor was affected according to its implementation
    assert counter.value.choose().get() == 1  # Regular counter: +1
    assert counter_a.value.choose().get() == 1  # CounterA: +step (default 1)
    assert counter_b.value.choose().get() == 0  # CounterB: *multiplier (default 2)


def test_stack_with_custom_interface():
    """Test stacking actors with a specified interface."""
    proc = local_proc_mesh(gpus=1).get()

    # Create different types of counters
    counter_a = proc.spawn("counter_a", CounterA, 0).get()
    counter_d = proc.spawn("counter_d", CounterD, 0).get()

    # Without specifying interface, they would use CounterA as common ancestor
    # But we want to use Counter interface instead
    stacked = monarch.actor_mesh.stack(counter_a, counter_d, interface=Counter)

    # Verify the stacked actor has only Counter endpoints
    assert hasattr(stacked, "incr")
    assert hasattr(stacked, "value")

    # Verify CounterA/CounterD specific endpoints are not accessible
    assert not hasattr(stacked, "decr")  # Should not be available
    assert not hasattr(stacked, "multiply")  # CounterD specific

    # Test that the common endpoints work
    stacked.incr.call().get()

    # Verify each actor was affected according to its implementation
    assert counter_a.value.choose().get() == 1  # CounterA: +step (default 1)
    assert counter_d.value.choose().get() == 1  # CounterD: +step (default 1)


def test_stacked_endpoint_consistency():
    """Tests that the StackedEndpoint shares the same APIs as EndPoint."""

    proc1 = local_proc_mesh(gpus=1).get()
    proc2 = local_proc_mesh(gpus=1).get()
    counter1 = proc1.spawn("counter1", Counter, 0).get()
    counter2 = proc2.spawn("counter2", Counter, 0).get()

    stacked = monarch.actor_mesh.stack(counter1, counter2)

    # Verify both endpoints implement all methods from EndpointInterface
    _assert_endpoints_implement_interface(counter1.incr, stacked.incr)


def test_stacked_endpoint_choose():
    """Tests that the StackedEndpoint.choose method works correctly."""
    proc1 = local_proc_mesh(gpus=1).get()
    proc2 = local_proc_mesh(gpus=1).get()
    counter1 = proc1.spawn("counter1", Counter, 0).get()
    counter2 = proc2.spawn("counter2", Counter, 0).get()

    stacked = monarch.actor_mesh.stack(counter1, counter2)
    stacked_endpoint = stacked.incr

    # Test choose
    stacked_endpoint.choose().get()
    # At least one counter should be incremented
    assert counter1.value.choose().get() + counter2.value.choose().get() >= 1


def test_stacked_endpoint_call():
    """Tests that the StackedEndpoint.call method works correctly."""
    proc1 = local_proc_mesh(gpus=1).get()
    proc2 = local_proc_mesh(gpus=1).get()
    counter1 = proc1.spawn("counter1", Counter, 0).get()
    counter2 = proc2.spawn("counter2", Counter, 0).get()

    stacked = monarch.actor_mesh.stack(counter1, counter2)
    stacked_endpoint = stacked.incr

    # Test call: one result entry is expected per stacked endpoint
    result = stacked_endpoint.call().get()
    assert isinstance(result, list)
    assert len(result) == 2

    # Verify both counters were incremented
    assert counter1.value.choose().get() == 1
    assert counter2.value.choose().get() == 1


def test_stacked_endpoint_broadcast():
    """Tests that the StackedEndpoint.broadcast method works correctly."""
    proc1 = local_proc_mesh(gpus=1).get()
    proc2 = local_proc_mesh(gpus=1).get()
    counter1 = proc1.spawn("counter1", Counter, 0).get()
    counter2 = proc2.spawn("counter2", Counter, 0).get()

    stacked = monarch.actor_mesh.stack(counter1, counter2)
    stacked_endpoint = stacked.incr

    # Test broadcast
    stacked_endpoint.broadcast()
    # Both counters should be incremented
    assert counter1.value.choose().get() == 1
    assert counter2.value.choose().get() == 1


# NOTE(review): no asyncio marker is visible on this coroutine test; presumably
# the test configuration collects bare ``async def`` tests - confirm.
async def test_stacked_endpoint_stream():
    """Tests that the StackedEndpoint.stream method works correctly."""
    proc1 = await local_proc_mesh(gpus=1)
    proc2 = await local_proc_mesh(gpus=1)
    counter1 = await proc1.spawn("counter1", Counter, 0)
    counter2 = await proc2.spawn("counter2", Counter, 0)

    stacked = monarch.actor_mesh.stack(counter1, counter2)
    stacked_endpoint = stacked.incr

    # Test stream: record one marker per streamed response
    async def test_stream():
        results = []
        async for _ in stacked_endpoint.stream():
            results.append(1)
        return results

    results = await test_stream()
    assert len(results) == 2

    # Verify both counters were incremented
    assert counter1.value.choose().get() == 1
    assert counter2.value.choose().get() == 1


def test_stacked_endpoint_interface_compliance():
    """Tests that StackedEndpoint implements all methods from EndpointInterface."""
    proc1 = local_proc_mesh(gpus=1).get()
    proc2 = local_proc_mesh(gpus=1).get()
    counter1 = proc1.spawn("counter1", Counter, 0).get()
    counter2 = proc2.spawn("counter2", Counter, 0).get()

    stacked = monarch.actor_mesh.stack(counter1, counter2)

    # Verify both endpoints implement all methods from EndpointInterface
    _assert_endpoints_implement_interface(counter1.incr, stacked.incr)


def test_stacked_actor_with_accumulator():
    """Tests that Accumulator works correctly with StackedActor endpoints."""
    proc1 = local_proc_mesh(gpus=1).get()
    proc2 = local_proc_mesh(gpus=1).get()

    counter1 = proc1.spawn("counter1", Counter, 5).get()
    counter2 = proc2.spawn("counter2", Counter, 10).get()
    stacked = monarch.actor_mesh.stack(counter1, counter2)

    acc = Accumulator(stacked.value, 0, operator.add)
    result = acc.accumulate().get()
    assert result == 15  # 5 + 10

    stacked.incr.broadcast()
    result = acc.accumulate().get()
    assert result == 17  # 6 + 11 after both counters were incremented