diff --git a/monarch_extension/src/mesh_controller.rs b/monarch_extension/src/mesh_controller.rs index b8253eea..e3b8affb 100644 --- a/monarch_extension/src/mesh_controller.rs +++ b/monarch_extension/src/mesh_controller.rs @@ -511,7 +511,7 @@ impl History { .getattr("RemoteException") .unwrap(); let pickle = py - .import("monarch.actor_mesh") + .import("monarch._src.actor.actor_mesh") .unwrap() .getattr("_pickle") .unwrap(); diff --git a/monarch_tensor_worker/src/stream.rs b/monarch_tensor_worker/src/stream.rs index c477b513..1ee4e7a7 100644 --- a/monarch_tensor_worker/src/stream.rs +++ b/monarch_tensor_worker/src/stream.rs @@ -990,7 +990,7 @@ impl StreamActor { }) .and_then(|result| -> Result { let pickle = py - .import("monarch.actor_mesh") + .import("monarch._src.actor.actor_mesh") .unwrap() .getattr("_pickle") .unwrap(); diff --git a/python/monarch/__init__.py b/python/monarch/__init__.py index 0a7a24af..7a0b1330 100644 --- a/python/monarch/__init__.py +++ b/python/monarch/__init__.py @@ -29,7 +29,8 @@ if TYPE_CHECKING: from monarch import timer - from monarch.allocator import LocalAllocator, ProcessAllocator + from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator + from monarch._src.actor.shape import NDSlice, Shape from monarch.common._coalescing import coalescing from monarch.common.device_mesh import ( @@ -50,11 +51,9 @@ from monarch.common.pipe import create_pipe, Pipe, remote_generator from monarch.common.remote import remote from monarch.common.selection import Selection - from monarch.common.shape import NDSlice, Shape from monarch.common.stream import get_active_stream, Stream from monarch.common.tensor import reduce, reduce_, Tensor from monarch.fetch import fetch_shard, inspect, show - from monarch.future import ActorFuture from monarch.gradient_generator import grad_function, grad_generator from monarch.notebook import mast_mesh, reserve_torchx as mast_reserve from monarch.python_local_mesh import python_local_mesh @@ -79,8 +78,8 
@@ "function_resolvers": ("monarch.common.function", "resolvers"), "Future": ("monarch.common.future", "Future"), "RemoteException": ("monarch.common.invocation", "RemoteException"), - "Shape": ("monarch.common.shape", "Shape"), - "NDSlice": ("monarch.common.shape", "NDSlice"), + "Shape": ("monarch._src.actor.shape", "Shape"), + "NDSlice": ("monarch._src.actor.shape", "NDSlice"), "Selection": ("monarch.common.selection", "Selection"), "OpaqueRef": ("monarch.common.opaque_ref", "OpaqueRef"), "create_pipe": ("monarch.common.pipe", "create_pipe"), @@ -112,9 +111,8 @@ "Simulator": ("monarch.simulator.interface", "Simulator"), "world_mesh": ("monarch.world_mesh", "world_mesh"), "timer": ("monarch.timer", "timer"), - "ProcessAllocator": ("monarch.allocator", "ProcessAllocator"), - "LocalAllocator": ("monarch.allocator", "LocalAllocator"), - "ActorFuture": ("monarch.future", "ActorFuture"), + "ProcessAllocator": ("monarch._src.actor.allocator", "ProcessAllocator"), + "LocalAllocator": ("monarch._src.actor.allocator", "LocalAllocator"), "builtins": ("monarch.builtins", "builtins"), } @@ -183,7 +181,6 @@ def __getattr__(name): "timer", "ProcessAllocator", "LocalAllocator", - "ActorFuture", "builtins", ] assert sorted(__all__) == sorted(_public_api) diff --git a/python/monarch/_rust_bindings/__init__.pyi b/python/monarch/_rust_bindings/__init__.pyi index e69de29b..18fe7133 100644 --- a/python/monarch/_rust_bindings/__init__.pyi +++ b/python/monarch/_rust_bindings/__init__.pyi @@ -0,0 +1,8 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# True iff the rust extension was built with the tensor engine feature. +def has_tensor_engine() -> bool: ... 
diff --git a/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi b/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi index c77337c2..30840ac9 100644 --- a/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi +++ b/python/monarch/_rust_bindings/monarch_hyperactor/alloc.pyi @@ -73,7 +73,7 @@ class RemoteAllocatorBase: def __new__( cls, world_id: str, - initializer: "monarch.allocator.RemoteAllocInitializer", # pyre-ignore[11] defined in monarch/python/monarch/allocator.py + initializer: "monarch._src.actor.allocator.RemoteAllocInitializer", # pyre-ignore[11] heartbeat_interval: timedelta = timedelta(seconds=5), ) -> Self: """ diff --git a/python/monarch/_src/__init__.py b/python/monarch/_src/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/monarch/_src/actor/__init__.py b/python/monarch/_src/actor/__init__.py new file mode 100644 index 00000000..f29d648d --- /dev/null +++ b/python/monarch/_src/actor/__init__.py @@ -0,0 +1,9 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Monarch Actor API +""" diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py new file mode 100644 index 00000000..687add3e --- /dev/null +++ b/python/monarch/_src/actor/actor_mesh.py @@ -0,0 +1,793 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +# pyre-unsafe + +import collections +import contextvars +import functools +import inspect +import logging +import random +import sys +import traceback + +from dataclasses import dataclass +from traceback import extract_tb, StackSummary +from typing import ( + Any, + AsyncGenerator, + Awaitable, + Callable, + cast, + Concatenate, + Dict, + Generic, + Iterable, + List, + Literal, + NamedTuple, + Optional, + ParamSpec, + Tuple, + Type, + TYPE_CHECKING, + TypeVar, +) + +from monarch._rust_bindings.hyperactor_extension.telemetry import enter_span, exit_span + +from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage +from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh +from monarch._rust_bindings.monarch_hyperactor.mailbox import ( + Mailbox, + OncePortReceiver, + OncePortRef, + PortReceiver as HyPortReceiver, + PortRef, +) +from monarch._rust_bindings.monarch_hyperactor.proc import ActorId +from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape +from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator +from monarch._src.actor.future import Future +from monarch._src.actor.pdb_wrapper import remote_breakpointhook + +from monarch._src.actor.pickle import flatten, unpickle + +from monarch._src.actor.shape import MeshTrait, NDSlice + +if TYPE_CHECKING: + from monarch._src.actor.debugger import DebugClient + +logger: logging.Logger = logging.getLogger(__name__) + +Allocator = ProcessAllocator | LocalAllocator + +try: + from __manifest__ import fbmake # noqa + + IN_PAR = True +except ImportError: + IN_PAR = False + +T1 = TypeVar("T1") +T2 = TypeVar("T2") + + +class Point(HyPoint, collections.abc.Mapping): + pass + + +@dataclass +class MonarchContext: + mailbox: Mailbox + proc_id: str + point: Point + + @staticmethod + def get() -> "MonarchContext": + return _context.get() + + +_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar( + 
"monarch.actor_mesh._context" +) + + +T = TypeVar("T") +P = ParamSpec("P") +R = TypeVar("R") +A = TypeVar("A") + +# keep this load balancing deterministic, but +# equally distributed. +_load_balancing_seed = random.Random(4) + + +Selection = Literal["all", "choose"] # TODO: replace with real selection objects + + +# standin class for whatever is the serializable python object we use +# to name an actor mesh. Hacked up today because ActorMesh +# isn't plumbed to non-clients +class _ActorMeshRefImpl: + def __init__( + self, + mailbox: Mailbox, + hy_actor_mesh: Optional[PythonActorMesh], + shape: Shape, + actor_ids: List[ActorId], + ) -> None: + self._mailbox = mailbox + self._actor_mesh = hy_actor_mesh + self._shape = shape + self._please_replace_me_actor_ids = actor_ids + + @staticmethod + def from_hyperactor_mesh( + mailbox: Mailbox, hy_actor_mesh: PythonActorMesh + ) -> "_ActorMeshRefImpl": + shape: Shape = hy_actor_mesh.shape + return _ActorMeshRefImpl( + mailbox, + hy_actor_mesh, + hy_actor_mesh.shape, + [cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))], + ) + + @staticmethod + def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl": + return _ActorMeshRefImpl(mailbox, None, singleton_shape, [actor_id]) + + @staticmethod + def from_actor_ref_with_shape( + ref: "_ActorMeshRefImpl", shape: Shape + ) -> "_ActorMeshRefImpl": + return _ActorMeshRefImpl( + ref._mailbox, None, shape, ref._please_replace_me_actor_ids + ) + + def __getstate__( + self, + ) -> Tuple[Shape, List[ActorId], Mailbox]: + return self._shape, self._please_replace_me_actor_ids, self._mailbox + + def __setstate__( + self, + state: Tuple[Shape, List[ActorId], Mailbox], + ) -> None: + self._actor_mesh = None + self._shape, self._please_replace_me_actor_ids, self._mailbox = state + + def send(self, rank: int, message: PythonMessage) -> None: + actor = self._please_replace_me_actor_ids[rank] + self._mailbox.post(actor, message) + + def cast( + self, + message: 
PythonMessage, + selection: Selection, + ) -> None: + # TODO: use the actual actor mesh when available. We cannot currently use it + # directly because we risk bifurcating the message delivery paths from the same + # client, since slicing the mesh will produce a reference, which calls actors + # directly. The reason these paths are bifurcated is that actor meshes will + # use multicasting, while direct actor comms do not. Separately we need to decide + # whether actor meshes are ordered with actor references. + # + # The fix is to provide a first-class reference into Python, and always call "cast" + # on it, including for load balanced requests. + if selection == "choose": + idx = _load_balancing_seed.randrange(len(self._shape)) + actor_rank = self._shape.ndslice[idx] + self._mailbox.post(self._please_replace_me_actor_ids[actor_rank], message) + return + elif selection == "all": + # replace me with actual remote actor mesh + call_shape = Shape( + self._shape.labels, NDSlice.new_row_major(self._shape.ndslice.sizes) + ) + for i, rank in enumerate(self._shape.ranks()): + self._mailbox.post_cast( + self._please_replace_me_actor_ids[rank], + i, + call_shape, + message, + ) + else: + raise ValueError(f"invalid selection: {selection}") + + def __len__(self) -> int: + return len(self._shape) + + +class Endpoint(Generic[P, R]): + def __init__( + self, + actor_mesh_ref: _ActorMeshRefImpl, + name: str, + impl: Callable[Concatenate[Any, P], Awaitable[R]], + mailbox: Mailbox, + ) -> None: + self._actor_mesh = actor_mesh_ref + self._name = name + self._signature: inspect.Signature = inspect.signature(impl) + self._mailbox = mailbox + + # the following are all 'adverbs' or different ways to handle the + # return values of this endpoint. Adverbs should only ever take *args, **kwargs + # of the original call. 
If we want to add syntax sugar for something that needs additional + # arguments, it should be implemented as a function independent of endpoint like `send` + # and `Accumulator` + def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: + """ + Load balanced sends a message to one chosen actor and awaits a result. + + Load balanced RPC-style entrypoint for request/response messaging. + """ + p: Port[R] + r: PortReceiver[R] + p, r = port(self, once=True) + # pyre-ignore + send(self, args, kwargs, port=p, selection="choose") + return r.recv() + + def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: + if len(self._actor_mesh) != 1: + raise ValueError( + f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}" + ) + return self.choose(*args, **kwargs) + + def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]": + p: Port[R] + r: RankedPortReceiver[R] + p, r = ranked_port(self) + # pyre-ignore + send(self, args, kwargs, port=p) + + async def process() -> ValueMesh[R]: + results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9] + for _ in range(len(self._actor_mesh)): + rank, value = await r.recv() + results[rank] = value + call_shape = Shape( + self._actor_mesh._shape.labels, + NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes), + ) + return ValueMesh(call_shape, results) + + def process_blocking() -> ValueMesh[R]: + results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9] + for _ in range(len(self._actor_mesh)): + rank, value = r.recv().get() + results[rank] = value + call_shape = Shape( + self._actor_mesh._shape.labels, + NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes), + ) + return ValueMesh(call_shape, results) + + return Future(process, process_blocking) + + async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]: + """ + Broadcasts to all actors and yields their responses as a stream / generator.
+ + This enables processing results from multiple actors incrementally as + they become available. Returns an async generator of response values. + """ + p, r = port(self) + # pyre-ignore + send(self, args, kwargs, port=p) + for _ in range(len(self._actor_mesh)): + yield await r.recv() + + def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None: + """ + Fire-and-forget broadcast to all actors without waiting for actors to + acknowledge receipt. + + In other words, the return of this method does not guarantee the + delivery of the message. + """ + # pyre-ignore + send(self, args, kwargs) + + +class Accumulator(Generic[P, R, A]): + def __init__( + self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A] + ) -> None: + self._endpoint: Endpoint[P, R] = endpoint + self._identity: A = identity + self._combine: Callable[[A, R], A] = combine + + def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]": + gen: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs) + + async def impl() -> A: + value = self._identity + async for x in gen: + value = self._combine(value, x) + return value + + return Future(impl) + + +class ValueMesh(MeshTrait, Generic[R]): + """ + Container of return values, indexed by rank.
+ """ + + def __init__(self, shape: Shape, values: List[R]) -> None: + self._shape = shape + self._values = values + + def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]": + return ValueMesh(shape, self._values) + + def item(self, **kwargs) -> R: + coordinates = [kwargs.pop(label) for label in self._labels] + if kwargs: + raise KeyError(f"item has extra dimensions: {list(kwargs.keys())}") + + return self._values[self._ndslice.nditem(coordinates)] + + def __iter__(self): + for rank in self._shape.ranks(): + yield Point(rank, self._shape), self._values[rank] + + def __len__(self) -> int: + return len(self._shape) + + def __repr__(self) -> str: + return f"ValueMesh({self._shape})" + + @property + def _ndslice(self) -> NDSlice: + return self._shape.ndslice + + @property + def _labels(self) -> Iterable[str]: + return self._shape.labels + + +def send( + endpoint: Endpoint[P, R], + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + port: "Optional[Port]" = None, + selection: Selection = "all", +) -> None: + """ + Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh. + + This sends the message to all actors but does not wait for any result. 
+ """ + endpoint._signature.bind(None, *args, **kwargs) + message = PythonMessage( + endpoint._name, + _pickle((args, kwargs)), + None if port is None else port._port_ref, + None, + ) + endpoint._actor_mesh.cast(message, selection) + + +class EndpointProperty(Generic[P, R]): + def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None: + self._method = method + + def __get__(self, instance, owner) -> Endpoint[P, R]: + # this is a total lie, but we have to actually + # recognize this was defined as an endpoint, + # and also lookup the method + return cast(Endpoint[P, R], self) + + +def endpoint( + method: Callable[Concatenate[Any, P], Awaitable[R]], +) -> EndpointProperty[P, R]: + return EndpointProperty(method) + + +class Port(Generic[R]): + def __init__( + self, port_ref: PortRef | OncePortRef, mailbox: Mailbox, rank: Optional[int] + ) -> None: + self._port_ref = port_ref + self._mailbox = mailbox + self._rank = rank + + def send(self, method: str, obj: R) -> None: + self._port_ref.send( + self._mailbox, + PythonMessage(method, _pickle(obj), None, self._rank), + ) + + +R = TypeVar("R") + +T = TypeVar("T") + +if TYPE_CHECKING: + # Python <= 3.10 cannot inherit from Generic[R] and NamedTuple at the same time. + # we only need it for type checking though, so copypasta it until 3.11. 
+ class PortTuple(NamedTuple, Generic[R]): + sender: "Port[R]" + receiver: "PortReceiver[R]" + + @staticmethod + def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": + handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() + port_ref = handle.bind() + return PortTuple( + Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver) + ) +else: + + class PortTuple(NamedTuple): + sender: "Port[Any]" + receiver: "PortReceiver[Any]" + + @staticmethod + def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": + handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() + port_ref = handle.bind() + return PortTuple( + Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver) + ) + + +# advanced lower-level API for sending messages. This is intentionally +# not part of the Endpoint API because the way it accepts arguments +# and handles concerns is different. +def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]": + return PortTuple.create(endpoint._mailbox, once) + + +def ranked_port( + endpoint: Endpoint[P, R], once: bool = False +) -> Tuple["Port[R]", "RankedPortReceiver[R]"]: + p, receiver = port(endpoint, once) + return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver) + + +class PortReceiver(Generic[R]): + def __init__( + self, + mailbox: Mailbox, + receiver: HyPortReceiver | OncePortReceiver, + ) -> None: + self._mailbox: Mailbox = mailbox + self._receiver: HyPortReceiver | OncePortReceiver = receiver + + async def _recv(self) -> R: + return self._process(await self._receiver.recv()) + + def _blocking_recv(self) -> R: + return self._process(self._receiver.blocking_recv()) + + def _process(self, msg: PythonMessage) -> R: + # TODO: Try to do something more structured than a cast here + payload = cast(R, unpickle(msg.message, self._mailbox)) + if msg.method == "result": + return payload + else: + assert msg.method == "exception" + # pyre-ignore + raise
payload + + def recv(self) -> "Future[R]": + return Future(lambda: self._recv(), self._blocking_recv) + + +class RankedPortReceiver(PortReceiver[Tuple[int, R]]): + def _process(self, msg: PythonMessage) -> Tuple[int, R]: + if msg.rank is None: + raise ValueError("RankedPort receiver got a message without a rank") + return msg.rank, super()._process(msg) + + +singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[])) + + +class _Actor: + """ + This is the message handling implementation of a Python actor. + + The layering goes: + Rust `PythonActor` -> `_Actor` -> user-provided `Actor` instance + + Messages are received from the Rust backend, and forwarded to the `handle` + methods on this class. + + This class wraps the actual `Actor` instance provided by the user, and + routes messages to it, managing argument serialization/deserialization and + error handling. + """ + + def __init__(self) -> None: + self.instance: object | None = None + + async def handle( + self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag + ) -> None: + return await self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag) + + async def handle_cast( + self, + mailbox: Mailbox, + rank: int, + shape: Shape, + message: PythonMessage, + panic_flag: PanicFlag, + ) -> None: + port = ( + Port(message.response_port, mailbox, rank) + if message.response_port + else None + ) + try: + ctx: MonarchContext = MonarchContext( + mailbox, mailbox.actor_id.proc_id, Point(rank, shape) + ) + _context.set(ctx) + + args, kwargs = unpickle(message.message, mailbox) + + if message.method == "__init__": + Class, *args = args + self.instance = Class(*args, **kwargs) + return None + + if self.instance is None: + # This could happen because of the following reasons. Both + # indicates a possible bug in the framework: + # 1. the execution of the previous message for "__init__" failed, + # but that error is not surfaced to the caller. + # - TODO(T229200522): there is a known bug. 
fix it. + # 2. this message is delivered to this actor before the previous + # message of "__init__" is delivered. Out-of-order delivery + # should never happen. It indicates either a bug in the + # message delivery mechanism, or the framework accidentally + # mixed the usage of cast and direct send. + raise AssertionError( + f""" + actor object is missing when executing method {message.method} + on actor {mailbox.actor_id} + """ + ) + the_method = getattr(self.instance, message.method)._method + + if inspect.iscoroutinefunction(the_method): + + async def instrumented(): + enter_span( + the_method.__module__, + message.method, + str(ctx.mailbox.actor_id), + ) + try: + result = await the_method(self.instance, *args, **kwargs) + except Exception as e: + logging.critical( + "Unahndled exception in actor endpoint", + exc_info=e, + ) + raise e + exit_span() + return result + + result = await instrumented() + else: + enter_span( + the_method.__module__, message.method, str(ctx.mailbox.actor_id) + ) + result = the_method(self.instance, *args, **kwargs) + exit_span() + + if port is not None: + port.send("result", result) + except Exception as e: + traceback.print_exc() + s = ActorError(e) + + # The exception is delivered to exactly one of: + # (1) our caller, (2) our supervisor + if port is not None: + port.send("exception", s) + else: + raise s from None + except BaseException as e: + # A BaseException can be thrown in the case of a Rust panic. + # In this case, we need a way to signal the panic to the Rust side. 
+ # See [Panics in async endpoints] + try: + panic_flag.signal_panic(e) + except Exception: + # The channel might be closed if the Rust side has already detected the error + pass + raise + + +def _is_mailbox(x: object) -> bool: + return isinstance(x, Mailbox) + + +def _pickle(obj: object) -> bytes: + _, msg = flatten(obj, _is_mailbox) + return msg + + +class Actor(MeshTrait): + @functools.cached_property + def logger(cls) -> logging.Logger: + lgr = logging.getLogger(cls.__class__.__name__) + lgr.setLevel(logging.DEBUG) + return lgr + + @property + def _ndslice(self) -> NDSlice: + raise NotImplementedError( + "actor implementations are not meshes, but we can't convince the typechecker of it..." + ) + + @property + def _labels(self) -> Tuple[str, ...]: + raise NotImplementedError( + "actor implementations are not meshes, but we can't convince the typechecker of it..." + ) + + def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": + raise NotImplementedError( + "actor implementations are not meshes, but we can't convince the typechecker of it..." + ) + + @endpoint # pyre-ignore + def _set_debug_client(self, client: "DebugClient") -> None: + point = MonarchContext.get().point + # For some reason, using a lambda instead of functools.partial + # confuses the pdb wrapper implementation. 
+ sys.breakpointhook = functools.partial( # pyre-ignore + remote_breakpointhook, + point.rank, + point.shape.coordinates(point.rank), + MonarchContext.get().mailbox.actor_id, + client, + ) + + +class ActorMeshRef(MeshTrait, Generic[T]): + def __init__( + self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox + ) -> None: + self.__name__: str = Class.__name__ + self._class: Type[T] = Class + self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref + self._mailbox: Mailbox = mailbox + for attr_name in dir(self._class): + attr_value = getattr(self._class, attr_name, None) + if isinstance(attr_value, EndpointProperty): + setattr( + self, + attr_name, + Endpoint( + self._actor_mesh_ref, + attr_name, + attr_value._method, + self._mailbox, + ), + ) + + def __getattr__(self, name: str) -> Any: + # This method is called when an attribute is not found + # For linting purposes, we need to tell the type checker that any attribute + # could be an endpoint that's dynamically added at runtime + # At runtime, we still want to raise AttributeError for truly missing attributes + + # Check if this is a method on the underlying class + if hasattr(self._class, name): + attr = getattr(self._class, name) + if isinstance(attr, EndpointProperty): + # Dynamically create the endpoint + endpoint = Endpoint( + self._actor_mesh_ref, + name, + attr._method, + self._mailbox, + ) + # Cache it for future use + setattr(self, name, endpoint) + return endpoint + + # If we get here, it's truly not found + raise AttributeError( + f"'{self.__class__.__name__}' object has no attribute '{name}'" + ) + + def _create( + self, + args: Iterable[Any], + kwargs: Dict[str, Any], + ) -> None: + async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None: + return None + + ep = Endpoint( + self._actor_mesh_ref, + "__init__", + null_func, + self._mailbox, + ) + send(ep, (self._class, *args), kwargs) + + def __reduce_ex__( + self, protocol: ... 
+ ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]": + return ActorMeshRef, ( + self._class, + self._actor_mesh_ref, + self._mailbox, + ) + + @property + def _ndslice(self) -> NDSlice: + return self._actor_mesh_ref._shape.ndslice + + @property + def _labels(self) -> Iterable[str]: + return self._actor_mesh_ref._shape.labels + + def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": + return ActorMeshRef( + self._class, + _ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape), + self._mailbox, + ) + + def __repr__(self) -> str: + return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})" + + +class ActorError(Exception): + """ + Deterministic problem with the user's code. + For example, an OOM resulting in trying to allocate too much GPU memory, or violating + some invariant enforced by the various APIs. + """ + + def __init__( + self, + exception: Exception, + message: str = "A remote actor call has failed asynchronously.", + ) -> None: + self.exception = exception + self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__) + self.message = message + + def __str__(self) -> str: + exe = str(self.exception) + actor_mesh_ref_tb = "".join(traceback.format_list(self.actor_mesh_ref_frames)) + return ( + f"{self.message}\n" + f"Traceback of where the remote call failed (most recent call last):\n{actor_mesh_ref_tb}{type(self.exception).__name__}: {exe}" + ) + + +def current_actor_name() -> str: + return str(MonarchContext.get().mailbox.actor_id) + + +def current_rank() -> Point: + ctx = MonarchContext.get() + return ctx.point + + +def current_size() -> Dict[str, int]: + ctx = MonarchContext.get() + return dict(zip(ctx.point.shape.labels, ctx.point.shape.ndslice.sizes)) diff --git a/python/monarch/allocator.py b/python/monarch/_src/actor/allocator.py similarity index 99% rename from python/monarch/allocator.py rename to python/monarch/_src/actor/allocator.py index cf39b6db..552fba71 100644 --- 
a/python/monarch/allocator.py +++ b/python/monarch/_src/actor/allocator.py @@ -10,7 +10,6 @@ import logging from typing import final, Optional -from monarch import ActorFuture as Future from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension Alloc, AllocSpec, @@ -22,6 +21,8 @@ RemoteAllocatorBase, ) +from monarch._src.actor.future import Future + ALLOC_LABEL_PROC_MESH_NAME = "procmesh.monarch.meta.com/name" logger: logging.Logger = logging.getLogger(__name__) diff --git a/python/monarch/_src/actor/bootstrap_main.py b/python/monarch/_src/actor/bootstrap_main.py new file mode 100644 index 00000000..5b377ac2 --- /dev/null +++ b/python/monarch/_src/actor/bootstrap_main.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +This is the main function for bootstrapping a new process using a ProcessAllocator. +""" + +import asyncio +import importlib.resources +import logging +import os +import sys + +# Import torch to avoid import-time races if a spawned actor tries to import torch. +try: + import torch # @manual +except ImportError: + pass + + +async def main(): + from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main + + await bootstrap_main() + + +def invoke_main(): + # if this is invoked with the stdout piped somewhere, then print + # changes its buffering behavior. So we default to the standard + # behavior of std out as if it were a terminal. + sys.stdout.reconfigure(line_buffering=True) + global bootstrap_main + + # TODO: figure out what from worker_main.py we should reproduce here.
+ from monarch._src.actor.telemetry import TracingForwarder + + if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1": + raise RuntimeError("Error during bootstrap for testing") + + # forward logs to rust tracing. Defaults to on. + if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1": + logging.root.addHandler(TracingForwarder(level=logging.DEBUG)) + + try: + with ( + importlib.resources.path("monarch", "py-spy") as pyspy, + ): + if pyspy.exists(): + os.environ["PYSPY_BIN"] = str(pyspy) + # fallback to using local py-spy + except Exception as e: + logging.warning(f"Failed to set up py-spy: {e}") + + # Start an event loop for PythonActors to use. + asyncio.run(main()) + + +if __name__ == "__main__": + invoke_main() # pragma: no cover diff --git a/python/monarch/code_sync/__init__.py b/python/monarch/_src/actor/code_sync/__init__.py similarity index 100% rename from python/monarch/code_sync/__init__.py rename to python/monarch/_src/actor/code_sync/__init__.py diff --git a/python/monarch/code_sync/auto_reload.py b/python/monarch/_src/actor/code_sync/auto_reload.py similarity index 98% rename from python/monarch/code_sync/auto_reload.py rename to python/monarch/_src/actor/code_sync/auto_reload.py index d5bce8cf..50994154 100644 --- a/python/monarch/code_sync/auto_reload.py +++ b/python/monarch/_src/actor/code_sync/auto_reload.py @@ -16,8 +16,8 @@ from types import ModuleType from typing import Dict, List, Optional, Tuple -from monarch.actor_mesh import Actor, endpoint -from monarch.code_sync import WorkspaceLocation +from monarch._src.actor.actor_mesh import Actor, endpoint +from monarch._src.actor.code_sync import WorkspaceLocation class SysAuditHookGuard(contextlib.AbstractContextManager): diff --git a/python/monarch/debugger.py b/python/monarch/_src/actor/debugger.py similarity index 98% rename from python/monarch/debugger.py rename to python/monarch/_src/actor/debugger.py index 97ac548a..2906c63f 100644 --- a/python/monarch/debugger.py +++ 
b/python/monarch/_src/actor/debugger.py @@ -11,11 +11,9 @@ from typing import Dict, List, Tuple, Union from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch.actor_mesh import Actor, ActorMeshRef, endpoint - -from monarch.pdb_wrapper import DebuggerWrite - -from monarch.proc_mesh import local_proc_mesh +from monarch._src.actor.actor_mesh import Actor, ActorMeshRef, endpoint +from monarch._src.actor.pdb_wrapper import DebuggerWrite +from monarch._src.actor.proc_mesh import local_proc_mesh from tabulate import tabulate diff --git a/python/monarch/common/_device_utils.py b/python/monarch/_src/actor/device_utils.py similarity index 100% rename from python/monarch/common/_device_utils.py rename to python/monarch/_src/actor/device_utils.py diff --git a/python/monarch/future.py b/python/monarch/_src/actor/future.py similarity index 98% rename from python/monarch/future.py rename to python/monarch/_src/actor/future.py index a13971d1..2f12c80a 100644 --- a/python/monarch/future.py +++ b/python/monarch/_src/actor/future.py @@ -28,7 +28,7 @@ async def _aincomplete(impl, self): # TODO: consolidate with monarch.common.future -class ActorFuture(Generic[R]): +class Future(Generic[R]): def __init__(self, impl, blocking_impl=None): if blocking_impl is None: blocking_impl = partial(asyncio.run, impl()) diff --git a/python/monarch/pdb_wrapper.py b/python/monarch/_src/actor/pdb_wrapper.py similarity index 98% rename from python/monarch/pdb_wrapper.py rename to python/monarch/_src/actor/pdb_wrapper.py index c200e146..87031dd4 100644 --- a/python/monarch/pdb_wrapper.py +++ b/python/monarch/_src/actor/pdb_wrapper.py @@ -17,7 +17,7 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ActorId if TYPE_CHECKING: - from monarch.debugger import DebugClient + from monarch._src.actor.debugger import DebugClient @dataclass diff --git a/python/monarch/common/pickle_flatten.py b/python/monarch/_src/actor/pickle.py similarity index 52% rename from 
python/monarch/common/pickle_flatten.py rename to python/monarch/_src/actor/pickle.py index 557b5399..5b2167df 100644 --- a/python/monarch/common/pickle_flatten.py +++ b/python/monarch/_src/actor/pickle.py @@ -5,11 +5,17 @@ # LICENSE file in the root directory of this source tree. import io +import itertools import pickle +from contextlib import contextmanager, nullcontext from typing import Any, Callable, Iterable, List, Tuple import cloudpickle -import torch + +try: + import torch # @manual +except ImportError: + torch = None class _Pickler(cloudpickle.Pickler): @@ -45,6 +51,42 @@ def flatten(obj: Any, filter: Callable[[Any], bool]) -> Tuple[List[Any], bytes]: def unflatten(data: bytes, values: Iterable[Any]) -> Any: - with torch.utils._python_dispatch._disable_current_modes(): + if torch is not None: + context_manager = torch.utils._python_dispatch._disable_current_modes + else: + context_manager = nullcontext + + with context_manager(): up = _Unpickler(data, values) return up.load() + + +@contextmanager +def load_tensors_on_cpu(): + # Ensure that any tensors load from CPU via monkeypatching how Storages are + # loaded. + old = torch.storage._load_from_bytes + try: + torch.storage._load_from_bytes = lambda b: torch.load( + io.BytesIO(b), map_location="cpu", weights_only=False + ) + yield + finally: + torch.storage._load_from_bytes = old + + +def unpickle(data: bytes, mailbox) -> Any: + if torch is not None: + context_manager = load_tensors_on_cpu + else: + context_manager = nullcontext + + with context_manager(): + # regardless of the mailboxes of the remote objects + # they all become the local mailbox. 
+ return unflatten(data, itertools.repeat(mailbox)) + + +def pickle_(obj: object, filter: Callable[[Any], bool]) -> bytes: + _, msg = flatten(obj, filter) + return msg diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py new file mode 100644 index 00000000..d2745013 --- /dev/null +++ b/python/monarch/_src/actor/proc_mesh.py @@ -0,0 +1,363 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import os +import sys +from contextlib import AbstractContextManager + +from typing import ( + Any, + cast, + Dict, + List, + Optional, + Sequence, + Type, + TYPE_CHECKING, + TypeVar, +) + +from monarch._rust_bindings import has_tensor_engine + +from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension + Alloc, + AllocConstraints, + AllocSpec, +) +from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox +from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ( + ProcMesh as HyProcMesh, + ProcMeshMonitor, +) +from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice +from monarch._src.actor.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef +from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator +from monarch._src.actor.code_sync import RsyncMeshClient, WorkspaceLocation +from monarch._src.actor.code_sync.auto_reload import AutoReloadActor + +from monarch._src.actor.device_utils import _local_device_count +from monarch._src.actor.future import Future +from monarch._src.actor.shape import MeshTrait + + +if TYPE_CHECKING: + Tensor = Any + DeviceMesh = Any + RDMAManager = Any + + +T = TypeVar("T") +try: + from __manifest__ import fbmake # noqa + + IN_PAR = True +except 
ImportError: + IN_PAR = False + + +async def _allocate_nonblocking(alloc: Alloc) -> "ProcMesh": + return ProcMesh(await HyProcMesh.allocate_nonblocking(alloc)) + + +def _allocate_blocking(alloc: Alloc) -> "ProcMesh": + return ProcMesh(HyProcMesh.allocate_blocking(alloc)) + + +class ProcMesh(MeshTrait): + def __init__( + self, + hy_proc_mesh: HyProcMesh, + _mock_shape: Optional[Shape] = None, + _device_mesh: Optional["DeviceMesh"] = None, + ) -> None: + self._proc_mesh = hy_proc_mesh + self._mock_shape: Optional[Shape] = _mock_shape + # type: ignore[21] + self._rdma_manager: Optional["RDMAManager"] = None + self._mailbox: Mailbox = self._proc_mesh.client + self._rsync_mesh_client: Optional[RsyncMeshClient] = None + self._auto_reload_actor: Optional[AutoReloadActor] = None + self._maybe_device_mesh: Optional["DeviceMesh"] = _device_mesh + self._stopped = False + if _mock_shape is None and has_tensor_engine(): + # type: ignore[21] + from monarch.rdma import RDMAManager # @manual + + # type: ignore[21] + self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager) + + @property + def _shape(self) -> Shape: + return self._proc_mesh.shape if self._mock_shape is None else self._mock_shape + + @property + def _ndslice(self) -> Slice: + return self._shape.ndslice + + @property + def _labels(self) -> List[str]: + return self._shape.labels + + def _new_with_shape(self, shape: Shape) -> "ProcMesh": + device_mesh = ( + None + if self._device_mesh is None + else self._device_mesh._new_with_shape(shape) + ) + return ProcMesh(self._proc_mesh, _mock_shape=shape, _device_mesh=device_mesh) + + def spawn( + self, name: str, Class: Type[T], *args: Any, **kwargs: Any + ) -> Future[ActorMeshRef[T]]: + if self._mock_shape is not None: + raise NotImplementedError("NYI: spawn on slice of a proc mesh.") + return Future( + lambda: self._spawn_nonblocking(name, Class, *args, **kwargs), + lambda: self._spawn_blocking(name, Class, *args, **kwargs), + ) + + async def monitor(self) -> 
ProcMeshMonitor: + """ + Get a monitor (async iterator) of the proc mesh, it is used to + monitor the status of the proc mesh. This function can be called at most once. + + Note: This API is experimental and subject to change. + + Example: + + async def monitor_loop(monitor): + async for event in monitor: + await handle_exception_event(event) + + # Kick off in background + asyncio.create_task(monitor_loop(monitor)) + """ + return await self._proc_mesh.monitor() + + @classmethod + def from_alloc(self, alloc: Alloc) -> Future["ProcMesh"]: + return Future( + lambda: _allocate_nonblocking(alloc), + lambda: _allocate_blocking(alloc), + ) + + def _spawn_blocking( + self, name: str, Class: Type[T], *args: Any, **kwargs: Any + ) -> T: + if not issubclass(Class, Actor): + raise ValueError( + f"{Class} must subclass monarch.service.Actor to spawn it." + ) + + actor_mesh = self._proc_mesh.spawn_blocking(name, _Actor) + service = ActorMeshRef( + Class, + _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh), + self._mailbox, + ) + # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by + # doing `ActorMeshRef(Class, actor_handle)` but not calling _create. + service._create(args, kwargs) + return cast(T, service) + + def __repr__(self) -> str: + return repr(self._proc_mesh) + + def __str__(self) -> str: + return str(self._proc_mesh) + + async def _spawn_nonblocking( + self, name: str, Class: Type[T], *args: Any, **kwargs: Any + ) -> T: + if not issubclass(Class, Actor): + raise ValueError( + f"{Class} must subclass monarch.service.Actor to spawn it." 
+ ) + + actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor) + service = ActorMeshRef( + Class, + _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh), + self._mailbox, + ) + # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by + # doing `ActorMeshRef(Class, actor_handle)` but not calling _create. + service._create(args, kwargs) + return cast(T, service) + + @property + def _device_mesh(self) -> "DeviceMesh": + if not has_tensor_engine(): + raise RuntimeError( + "DeviceMesh is not available because tensor_engine was not compiled (USE_TENSOR_ENGINE=0)" + ) + + # type: ignore[21] + from monarch.mesh_controller import spawn_tensor_engine # @manual + + if self._maybe_device_mesh is None: + if self._mock_shape is not None: + raise NotImplementedError( + "NYI: activating a proc mesh must first happen on the root proc_mesh until we fix spawning on submeshes." + ) + # type: ignore[21] + self._maybe_device_mesh = spawn_tensor_engine(self) + return self._maybe_device_mesh + + # pyre-ignore + def activate(self) -> AbstractContextManager: + return self._device_mesh.activate() + + def rank_tensor(self, dim: str | Sequence[str]) -> "Tensor": + return self._device_mesh.rank(dim) + + def rank_tensors(self) -> Dict[str, "Tensor"]: + return self._device_mesh.ranks + + async def sync_workspace(self, auto_reload: bool = False) -> None: + if self._rsync_mesh_client is None: + # TODO(agallagher): We need some way to configure and pass this + # in -- right now we're assuming the `gpu` dimension, which isn't + # correct. + assert set(self._proc_mesh.shape.labels).issubset({"gpus", "hosts"}) + # The workspace shape (i.e. only perform one rsync per host). 
+ workspace_shape = self.slice(gpus=slice(0, 1, 1))._mock_shape + assert workspace_shape is not None + # TODO(agallagher): We should probably hide this behind something + # like a `Workspace` class and support abstracting/configuring + # different sync methods. + self._rsync_mesh_client = RsyncMeshClient.spawn_blocking( + proc_mesh=self._proc_mesh, + shape=workspace_shape, + # TODO(agallagher): Is there a better way to infer/set the local + # workspace dir, rather than use PWD? + local_workspace=os.getcwd(), + remote_workspace=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"), + ) + self._auto_reload_actor = self._spawn_blocking( + "auto_reload", + AutoReloadActor, + WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"), + ) + assert self._rsync_mesh_client is not None + await self._rsync_mesh_client.sync_workspace() + if auto_reload: + assert self._auto_reload_actor is not None + await self._auto_reload_actor.reload.call() + + async def stop(self) -> None: + await self._proc_mesh.stop() + self._stopped = True + + async def __aenter__(self) -> "ProcMesh": + if self._stopped: + raise RuntimeError("`ProcMesh` has already been stopped") + return self + + async def __aexit__( + self, exc_type: object, exc_val: object, exc_tb: object + ) -> None: + # In case there are multiple nested "async with" statements, we only + # want it to close once. + if not self._stopped: + await self.stop() + + # Finalizer to check if the proc mesh was closed properly. + def __del__(self) -> None: + if not self._stopped: + import warnings + + warnings.warn( + f"unstopped ProcMesh {self!r}", + ResourceWarning, + stacklevel=2, + source=self, + ) + # Cannot call stop here because it is async. 
+ + +async def local_proc_mesh_nonblocking( + *, gpus: Optional[int] = None, hosts: int = 1 +) -> ProcMesh: + if gpus is None: + gpus = _local_device_count() + spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) + allocator = LocalAllocator() + alloc = await allocator.allocate(spec) + return await ProcMesh.from_alloc(alloc) + + +def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh: + if gpus is None: + gpus = _local_device_count() + spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) + allocator = LocalAllocator() + alloc = allocator.allocate(spec).get() + return ProcMesh.from_alloc(alloc).get() + + +def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]: + return Future( + lambda: local_proc_mesh_nonblocking(gpus=gpus, hosts=hosts), + lambda: local_proc_mesh_blocking(gpus=gpus, hosts=hosts), + ) + + +_BOOTSTRAP_MAIN = "monarch.bootstrap_main" + + +def _get_bootstrap_args() -> tuple[str, Optional[list[str]], dict[str, str]]: + if IN_PAR: + cmd = sys.argv[0] + args = None + env = { + "PAR_MAIN_OVERRIDE": _BOOTSTRAP_MAIN, + } + else: + cmd = sys.executable + args = ["-m", _BOOTSTRAP_MAIN] + env = {} + + return cmd, args, env + + +async def proc_mesh_nonblocking( + *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None +) -> ProcMesh: + if gpus is None: + gpus = _local_device_count() + spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) + env = env or {} + cmd, args, base_env = _get_bootstrap_args() + env.update(base_env) + allocator = ProcessAllocator(cmd, args, env) + alloc = await allocator.allocate(spec) + return await ProcMesh.from_alloc(alloc) + + +def proc_mesh_blocking( + *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None +) -> ProcMesh: + if gpus is None: + gpus = _local_device_count() + spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) + env = env or {} + cmd, args, base_env = 
_get_bootstrap_args() + env.update(base_env) + allocator = ProcessAllocator(cmd, args, env) + alloc = allocator.allocate(spec).get() + return ProcMesh.from_alloc(alloc).get() + + +def proc_mesh( + *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None +) -> Future[ProcMesh]: + return Future( + lambda: proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env), + lambda: proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env), + ) diff --git a/python/monarch/common/shape.py b/python/monarch/_src/actor/shape.py similarity index 100% rename from python/monarch/common/shape.py rename to python/monarch/_src/actor/shape.py diff --git a/python/monarch/telemetry.py b/python/monarch/_src/actor/telemetry/__init__.py similarity index 100% rename from python/monarch/telemetry.py rename to python/monarch/_src/actor/telemetry/__init__.py diff --git a/python/monarch/telemetry/rust_span_tracing.py b/python/monarch/_src/actor/telemetry/rust_span_tracing.py similarity index 100% rename from python/monarch/telemetry/rust_span_tracing.py rename to python/monarch/_src/actor/telemetry/rust_span_tracing.py diff --git a/python/monarch/_testing.py b/python/monarch/_testing.py index 67bd02f8..2e9a9475 100644 --- a/python/monarch/_testing.py +++ b/python/monarch/_testing.py @@ -13,10 +13,10 @@ from typing import Any, Callable, Dict, Generator, Literal, Optional import monarch_supervisor +from monarch._src.actor.shape import NDSlice from monarch.common.client import Client from monarch.common.device_mesh import DeviceMesh from monarch.common.invocation import DeviceException, RemoteException -from monarch.common.shape import NDSlice from monarch.controller.backend import ProcessBackend from monarch.mesh_controller import spawn_tensor_engine from monarch.proc_mesh import proc_mesh, ProcMesh diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py new file mode 100644 index 00000000..582a6f07 --- /dev/null +++ b/python/monarch/actor/__init__.py @@ 
-0,0 +1,38 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +""" +Monarch Actor API - Public interface for actor functionality. +""" + +from monarch._src.actor.actor_mesh import ( + Accumulator, + Actor, + ActorError, + current_actor_name, + current_rank, + current_size, + endpoint, + MonarchContext, + ValueMesh, +) +from monarch._src.actor.future import Future +from monarch._src.actor.proc_mesh import proc_mesh, ProcMesh + +__all__ = [ + "Accumulator", + "Actor", + "ActorError", + "current_actor_name", + "current_rank", + "current_size", + "endpoint", + "MonarchContext", + "ValueMesh", + "proc_mesh", + "ProcMesh", + "Future", +] diff --git a/python/monarch/actor_mesh.py b/python/monarch/actor_mesh.py index a7f00cff..2343f10f 100644 --- a/python/monarch/actor_mesh.py +++ b/python/monarch/actor_mesh.py @@ -4,815 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-unsafe +import warnings -import collections -import contextvars -import functools -import inspect -import io -import itertools -import logging -import random -import sys -import traceback -from contextlib import contextmanager - -from dataclasses import dataclass -from traceback import extract_tb, StackSummary -from typing import ( - Any, - AsyncGenerator, - Awaitable, - Callable, - cast, - Concatenate, - Dict, - Generic, - Iterable, - List, - Literal, - NamedTuple, - Optional, - ParamSpec, - Tuple, - Type, - TYPE_CHECKING, - TypeVar, -) - -import monarch - -import torch -from monarch import ActorFuture as Future -from monarch._rust_bindings.hyperactor_extension.telemetry import enter_span, exit_span - -from monarch._rust_bindings.monarch_hyperactor.actor import PanicFlag, PythonMessage -from monarch._rust_bindings.monarch_hyperactor.actor_mesh import PythonActorMesh -from monarch._rust_bindings.monarch_hyperactor.mailbox import ( - Mailbox, - OncePortReceiver, - OncePortRef, - PortReceiver as HyPortReceiver, - PortRef, +warnings.warn( + "monarch.actor_mesh is deprecated, please import from monarch.actor instead.", + DeprecationWarning, + stacklevel=2, ) -from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._rust_bindings.monarch_hyperactor.shape import Point as HyPoint, Shape - -from monarch.common.pickle_flatten import flatten, unflatten -from monarch.common.shape import MeshTrait, NDSlice -from monarch.pdb_wrapper import remote_breakpointhook - -if TYPE_CHECKING: - from monarch.debugger import DebugClient - -logger: logging.Logger = logging.getLogger(__name__) - -Allocator = monarch.ProcessAllocator | monarch.LocalAllocator - -try: - from __manifest__ import fbmake # noqa - - IN_PAR = True -except ImportError: - IN_PAR = False - -T1 = TypeVar("T1") -T2 = TypeVar("T2") - - -class Point(HyPoint, collections.abc.Mapping): - pass - - -@dataclass -class MonarchContext: - mailbox: Mailbox - proc_id: str - point: Point - - 
@staticmethod - def get() -> "MonarchContext": - return _context.get() - - -_context: contextvars.ContextVar[MonarchContext] = contextvars.ContextVar( - "monarch.actor_mesh._context" -) - - -T = TypeVar("T") -P = ParamSpec("P") -R = TypeVar("R") -A = TypeVar("A") - -# keep this load balancing deterministic, but -# equally distributed. -_load_balancing_seed = random.Random(4) - - -Selection = Literal["all", "choose"] # TODO: replace with real selection objects - - -# standin class for whatever is the serializable python object we use -# to name an actor mesh. Hacked up today because ActorMesh -# isn't plumbed to non-clients -class _ActorMeshRefImpl: - def __init__( - self, - mailbox: Mailbox, - hy_actor_mesh: Optional[PythonActorMesh], - shape: Shape, - actor_ids: List[ActorId], - ) -> None: - self._mailbox = mailbox - self._actor_mesh = hy_actor_mesh - self._shape = shape - self._please_replace_me_actor_ids = actor_ids - - @staticmethod - def from_hyperactor_mesh( - mailbox: Mailbox, hy_actor_mesh: PythonActorMesh - ) -> "_ActorMeshRefImpl": - shape: Shape = hy_actor_mesh.shape - return _ActorMeshRefImpl( - mailbox, - hy_actor_mesh, - hy_actor_mesh.shape, - [cast(ActorId, hy_actor_mesh.get(i)) for i in range(len(shape))], - ) - - @staticmethod - def from_actor_id(mailbox: Mailbox, actor_id: ActorId) -> "_ActorMeshRefImpl": - return _ActorMeshRefImpl(mailbox, None, singleton_shape, [actor_id]) - - @staticmethod - def from_actor_ref_with_shape( - ref: "_ActorMeshRefImpl", shape: Shape - ) -> "_ActorMeshRefImpl": - return _ActorMeshRefImpl( - ref._mailbox, None, shape, ref._please_replace_me_actor_ids - ) - - def __getstate__( - self, - ) -> Tuple[Shape, List[ActorId], Mailbox]: - return self._shape, self._please_replace_me_actor_ids, self._mailbox - - def __setstate__( - self, - state: Tuple[Shape, List[ActorId], Mailbox], - ) -> None: - self._actor_mesh = None - self._shape, self._please_replace_me_actor_ids, self._mailbox = state - - def send(self, rank: int, 
message: PythonMessage) -> None: - actor = self._please_replace_me_actor_ids[rank] - self._mailbox.post(actor, message) - - def cast( - self, - message: PythonMessage, - selection: Selection, - ) -> None: - # TODO: use the actual actor mesh when available. We cannot currently use it - # directly because we risk bifurcating the message delivery paths from the same - # client, since slicing the mesh will produce a reference, which calls actors - # directly. The reason these paths are bifurcated is that actor meshes will - # use multicasting, while direct actor comms do not. Separately we need to decide - # whether actor meshes are ordered with actor references. - # - # The fix is to provide a first-class reference into Python, and always call "cast" - # on it, including for load balanced requests. - if selection == "choose": - idx = _load_balancing_seed.randrange(len(self._shape)) - actor_rank = self._shape.ndslice[idx] - self._mailbox.post(self._please_replace_me_actor_ids[actor_rank], message) - return - elif selection == "all": - # replace me with actual remote actor mesh - call_shape = Shape( - self._shape.labels, NDSlice.new_row_major(self._shape.ndslice.sizes) - ) - for i, rank in enumerate(self._shape.ranks()): - self._mailbox.post_cast( - self._please_replace_me_actor_ids[rank], - i, - call_shape, - message, - ) - else: - raise ValueError(f"invalid selection: {selection}") - - def __len__(self) -> int: - return len(self._shape) - - -class Endpoint(Generic[P, R]): - def __init__( - self, - actor_mesh_ref: _ActorMeshRefImpl, - name: str, - impl: Callable[Concatenate[Any, P], Awaitable[R]], - mailbox: Mailbox, - ) -> None: - self._actor_mesh = actor_mesh_ref - self._name = name - self._signature: inspect.Signature = inspect.signature(impl) - self._mailbox = mailbox - - # the following are all 'adverbs' or different ways to handle the - # return values of this endpoint. Adverbs should only ever take *args, **kwargs - # of the original call. 
If we want to add syntax sugar for something that needs additional - # arguments, it should be implemented as function indepdendent of endpoint like `send` - # and `Accumulator` - def choose(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: - """ - Load balanced sends a message to one chosen actor and awaits a result. - - Load balanced RPC-style entrypoint for request/response messaging. - """ - p: Port[R] - r: PortReceiver[R] - p, r = port(self, once=True) - # pyre-ignore - send(self, args, kwargs, port=p, selection="choose") - return r.recv() - - def call_one(self, *args: P.args, **kwargs: P.kwargs) -> Future[R]: - if len(self._actor_mesh) != 1: - raise ValueError( - f"Can only use 'call_one' on a single Actor but this actor has shape {self._actor_mesh._shape}" - ) - return self.choose(*args, **kwargs) - - def call(self, *args: P.args, **kwargs: P.kwargs) -> "Future[ValueMesh[R]]": - p: Port[R] - r: RankedPortReceiver[R] - p, r = ranked_port(self) - # pyre-ignore - send(self, args, kwargs, port=p) - - async def process() -> ValueMesh[R]: - results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9] - for _ in range(len(self._actor_mesh)): - rank, value = await r.recv() - results[rank] = value - call_shape = Shape( - self._actor_mesh._shape.labels, - NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes), - ) - return ValueMesh(call_shape, results) - - def process_blocking() -> ValueMesh[R]: - results: List[R] = [None] * len(self._actor_mesh) # pyre-fixme[9] - for _ in range(len(self._actor_mesh)): - rank, value = r.recv().get() - results[rank] = value - call_shape = Shape( - self._actor_mesh._shape.labels, - NDSlice.new_row_major(self._actor_mesh._shape.ndslice.sizes), - ) - return ValueMesh(call_shape, results) - - return Future(process, process_blocking) - - async def stream(self, *args: P.args, **kwargs: P.kwargs) -> AsyncGenerator[R, R]: - """ - Broadcasts to all actors and yields their responses as a stream / generator. 
- - This enables processing results from multiple actors incrementally as - they become available. Returns an async generator of response values. - """ - p, r = port(self) - # pyre-ignore - send(self, args, kwargs, port=p) - for _ in range(len(self._actor_mesh)): - yield await r.recv() - - def broadcast(self, *args: P.args, **kwargs: P.kwargs) -> None: - """ - Fire-and-forget broadcast to all actors without waiting for actors to - acknowledge receipt. - - In other words, the return of this method does not guarrantee the - delivery of the message. - """ - # pyre-ignore - send(self, args, kwargs) - - -class Accumulator(Generic[P, R, A]): - def __init__( - self, endpoint: Endpoint[P, R], identity: A, combine: Callable[[A, R], A] - ) -> None: - self._endpoint: Endpoint[P, R] = endpoint - self._identity: A = identity - self._combine: Callable[[A, R], A] = combine - - def accumulate(self, *args: P.args, **kwargs: P.kwargs) -> "Future[A]": - gen: AsyncGenerator[R, R] = self._endpoint.stream(*args, **kwargs) - - async def impl() -> A: - value = self._identity - async for x in gen: - value = self._combine(value, x) - return value - - return Future(impl) - - -class ValueMesh(MeshTrait, Generic[R]): - """ - Container of return values, indexed by rank. 
- """ - - def __init__(self, shape: Shape, values: List[R]) -> None: - self._shape = shape - self._values = values - - def _new_with_shape(self, shape: Shape) -> "ValueMesh[R]": - return ValueMesh(shape, self._values) - - def item(self, **kwargs) -> R: - coordinates = [kwargs.pop(label) for label in self._labels] - if kwargs: - raise KeyError(f"item has extra dimensions: {list(kwargs.keys())}") - - return self._values[self._ndslice.nditem(coordinates)] - - def __iter__(self): - for rank in self._shape.ranks(): - yield Point(rank, self._shape), self._values[rank] - - def __len__(self) -> int: - return len(self._shape) - - def __repr__(self) -> str: - return f"ValueMesh({self._shape})" - - @property - def _ndslice(self) -> NDSlice: - return self._shape.ndslice - - @property - def _labels(self) -> Iterable[str]: - return self._shape.labels - - -def send( - endpoint: Endpoint[P, R], - args: Tuple[Any, ...], - kwargs: Dict[str, Any], - port: "Optional[Port]" = None, - selection: Selection = "all", -) -> None: - """ - Fire-and-forget broadcast invocation of the endpoint across all actors in the mesh. - - This sends the message to all actors but does not wait for any result. 
- """ - endpoint._signature.bind(None, *args, **kwargs) - message = PythonMessage( - endpoint._name, - _pickle((args, kwargs)), - None if port is None else port._port_ref, - None, - ) - endpoint._actor_mesh.cast(message, selection) - - -class EndpointProperty(Generic[P, R]): - def __init__(self, method: Callable[Concatenate[Any, P], Awaitable[R]]) -> None: - self._method = method - - def __get__(self, instance, owner) -> Endpoint[P, R]: - # this is a total lie, but we have to actually - # recognize this was defined as an endpoint, - # and also lookup the method - return cast(Endpoint[P, R], self) - - -def endpoint( - method: Callable[Concatenate[Any, P], Awaitable[R]], -) -> EndpointProperty[P, R]: - return EndpointProperty(method) - - -class Port(Generic[R]): - def __init__( - self, port_ref: PortRef | OncePortRef, mailbox: Mailbox, rank: Optional[int] - ) -> None: - self._port_ref = port_ref - self._mailbox = mailbox - self._rank = rank - - def send(self, method: str, obj: R) -> None: - self._port_ref.send( - self._mailbox, - PythonMessage(method, _pickle(obj), None, self._rank), - ) - - -R = TypeVar("R") - -T = TypeVar("T") - -if TYPE_CHECKING: - # Python <= 3.10 cannot inherit from Generic[R] and NamedTuple at the same time. - # we only need it for type checking though, so copypasta it until 3.11. 
- class PortTuple(NamedTuple, Generic[R]): - sender: "Port[R]" - receiver: "PortReceiver[R]" - - @staticmethod - def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": - handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() - port_ref = handle.bind() - return PortTuple( - Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver) - ) -else: - - class PortTuple(NamedTuple): - sender: "Port[Any]" - receiver: "PortReceiver[Any]" - - @staticmethod - def create(mailbox: Mailbox, once: bool = False) -> "PortTuple[Any]": - handle, receiver = mailbox.open_once_port() if once else mailbox.open_port() - port_ref = handle.bind() - return PortTuple( - Port(port_ref, mailbox, rank=None), PortReceiver(mailbox, receiver) - ) - - -# advance lower-level API for sending messages. This is intentially -# not part of the Endpoint API because they way it accepts arguments -# and handles concerns is different. -def port(endpoint: Endpoint[P, R], once: bool = False) -> "PortTuple[R]": - return PortTuple.create(endpoint._mailbox, once) - - -def ranked_port( - endpoint: Endpoint[P, R], once: bool = False -) -> Tuple["Port[R]", "RankedPortReceiver[R]"]: - p, receiver = port(endpoint, once) - return p, RankedPortReceiver[R](receiver._mailbox, receiver._receiver) - - -class PortReceiver(Generic[R]): - def __init__( - self, - mailbox: Mailbox, - receiver: HyPortReceiver | OncePortReceiver, - ) -> None: - self._mailbox: Mailbox = mailbox - self._receiver: HyPortReceiver | OncePortReceiver = receiver - - async def _recv(self) -> R: - return self._process(await self._receiver.recv()) - - def _blocking_recv(self) -> R: - return self._process(self._receiver.blocking_recv()) - - def _process(self, msg: PythonMessage) -> R: - # TODO: Try to do something more structured than a cast here - payload = cast(R, _unpickle(msg.message, self._mailbox)) - if msg.method == "result": - return payload - else: - assert msg.method == "exception" - # pyre-ignore - raise 
payload - - def recv(self) -> "Future[R]": - return Future(lambda: self._recv(), self._blocking_recv) - - -class RankedPortReceiver(PortReceiver[Tuple[int, R]]): - def _process(self, msg: PythonMessage) -> Tuple[int, R]: - if msg.rank is None: - raise ValueError("RankedPort receiver got a message without a rank") - return msg.rank, super()._process(msg) - - -singleton_shape = Shape([], NDSlice(offset=0, sizes=[], strides=[])) - - -class _Actor: - """ - This is the message handling implementation of a Python actor. - - The layering goes: - Rust `PythonActor` -> `_Actor` -> user-provided `Actor` instance - - Messages are received from the Rust backend, and forwarded to the `handle` - methods on this class. - - This class wraps the actual `Actor` instance provided by the user, and - routes messages to it, managing argument serialization/deserialization and - error handling. - """ - - def __init__(self) -> None: - self.instance: object | None = None - - async def handle( - self, mailbox: Mailbox, message: PythonMessage, panic_flag: PanicFlag - ) -> None: - return await self.handle_cast(mailbox, 0, singleton_shape, message, panic_flag) - - async def handle_cast( - self, - mailbox: Mailbox, - rank: int, - shape: Shape, - message: PythonMessage, - panic_flag: PanicFlag, - ) -> None: - port = ( - Port(message.response_port, mailbox, rank) - if message.response_port - else None - ) - try: - ctx: MonarchContext = MonarchContext( - mailbox, mailbox.actor_id.proc_id, Point(rank, shape) - ) - _context.set(ctx) - - args, kwargs = _unpickle(message.message, mailbox) - - if message.method == "__init__": - Class, *args = args - self.instance = Class(*args, **kwargs) - return None - - if self.instance is None: - # This could happen because of the following reasons. Both - # indicates a possible bug in the framework: - # 1. the execution of the previous message for "__init__" failed, - # but that error is not surfaced to the caller. - # - TODO(T229200522): there is a known bug. 
fix it. - # 2. this message is delivered to this actor before the previous - # message of "__init__" is delivered. Out-of-order delivery - # should never happen. It indicates either a bug in the - # message delivery mechanism, or the framework accidentally - # mixed the usage of cast and direct send. - raise AssertionError( - f""" - actor object is missing when executing method {message.method} - on actor {mailbox.actor_id} - """ - ) - the_method = getattr(self.instance, message.method)._method - - if inspect.iscoroutinefunction(the_method): - - async def instrumented(): - enter_span( - the_method.__module__, - message.method, - str(ctx.mailbox.actor_id), - ) - try: - result = await the_method(self.instance, *args, **kwargs) - except Exception as e: - logging.critical( - "Unahndled exception in actor endpoint", - exc_info=e, - ) - raise e - exit_span() - return result - - result = await instrumented() - else: - enter_span( - the_method.__module__, message.method, str(ctx.mailbox.actor_id) - ) - result = the_method(self.instance, *args, **kwargs) - exit_span() - - if port is not None: - port.send("result", result) - except Exception as e: - traceback.print_exc() - s = ActorError(e) - - # The exception is delivered to exactly one of: - # (1) our caller, (2) our supervisor - if port is not None: - port.send("exception", s) - else: - raise s from None - except BaseException as e: - # A BaseException can be thrown in the case of a Rust panic. - # In this case, we need a way to signal the panic to the Rust side. 
- # See [Panics in async endpoints] - try: - panic_flag.signal_panic(e) - except Exception: - # The channel might be closed if the Rust side has already detected the error - pass - raise - - -def _is_mailbox(x: object) -> bool: - return isinstance(x, Mailbox) - - -def _pickle(obj: object) -> bytes: - _, msg = flatten(obj, _is_mailbox) - return msg - - -@contextmanager -def _load_tensors_on_cpu(): - # Ensure that any tensors load from CPU via monkeypatching how Storages are - # loaded. - old = torch.storage._load_from_bytes - try: - torch.storage._load_from_bytes = lambda b: torch.load( - io.BytesIO(b), map_location="cpu", weights_only=False - ) - yield - finally: - torch.storage._load_from_bytes = old - - -def _unpickle(data: bytes, mailbox: Mailbox) -> Any: - with _load_tensors_on_cpu(): - # regardless of the mailboxes of the remote objects - # they all become the local mailbox. - return unflatten(data, itertools.repeat(mailbox)) - - -class Actor(MeshTrait): - @functools.cached_property - def logger(cls) -> logging.Logger: - lgr = logging.getLogger(cls.__class__.__name__) - lgr.setLevel(logging.DEBUG) - return lgr - - @property - def _ndslice(self) -> NDSlice: - raise NotImplementedError( - "actor implementations are not meshes, but we can't convince the typechecker of it..." - ) - - @property - def _labels(self) -> Tuple[str, ...]: - raise NotImplementedError( - "actor implementations are not meshes, but we can't convince the typechecker of it..." - ) - - def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": - raise NotImplementedError( - "actor implementations are not meshes, but we can't convince the typechecker of it..." - ) - - @endpoint # pyre-ignore - def _set_debug_client(self, client: "DebugClient") -> None: - point = MonarchContext.get().point - # For some reason, using a lambda instead of functools.partial - # confuses the pdb wrapper implementation. 
- sys.breakpointhook = functools.partial( # pyre-ignore - remote_breakpointhook, - point.rank, - point.shape.coordinates(point.rank), - MonarchContext.get().mailbox.actor_id, - client, - ) - - -class ActorMeshRef(MeshTrait, Generic[T]): - def __init__( - self, Class: Type[T], actor_mesh_ref: _ActorMeshRefImpl, mailbox: Mailbox - ) -> None: - self.__name__: str = Class.__name__ - self._class: Type[T] = Class - self._actor_mesh_ref: _ActorMeshRefImpl = actor_mesh_ref - self._mailbox: Mailbox = mailbox - for attr_name in dir(self._class): - attr_value = getattr(self._class, attr_name, None) - if isinstance(attr_value, EndpointProperty): - setattr( - self, - attr_name, - Endpoint( - self._actor_mesh_ref, - attr_name, - attr_value._method, - self._mailbox, - ), - ) - - def __getattr__(self, name: str) -> Any: - # This method is called when an attribute is not found - # For linting purposes, we need to tell the type checker that any attribute - # could be an endpoint that's dynamically added at runtime - # At runtime, we still want to raise AttributeError for truly missing attributes - - # Check if this is a method on the underlying class - if hasattr(self._class, name): - attr = getattr(self._class, name) - if isinstance(attr, EndpointProperty): - # Dynamically create the endpoint - endpoint = Endpoint( - self._actor_mesh_ref, - name, - attr._method, - self._mailbox, - ) - # Cache it for future use - setattr(self, name, endpoint) - return endpoint - - # If we get here, it's truly not found - raise AttributeError( - f"'{self.__class__.__name__}' object has no attribute '{name}'" - ) - - def _create( - self, - args: Iterable[Any], - kwargs: Dict[str, Any], - ) -> None: - async def null_func(*_args: Iterable[Any], **_kwargs: Dict[str, Any]) -> None: - return None - - ep = Endpoint( - self._actor_mesh_ref, - "__init__", - null_func, - self._mailbox, - ) - send(ep, (self._class, *args), kwargs) - - def __reduce_ex__( - self, protocol: ... 
- ) -> "Tuple[Type[ActorMeshRef], Tuple[Any, ...]]": - return ActorMeshRef, ( - self._class, - self._actor_mesh_ref, - self._mailbox, - ) - - @property - def _ndslice(self) -> NDSlice: - return self._actor_mesh_ref._shape.ndslice - - @property - def _labels(self) -> Iterable[str]: - return self._actor_mesh_ref._shape.labels - - def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": - return ActorMeshRef( - self._class, - _ActorMeshRefImpl.from_actor_ref_with_shape(self._actor_mesh_ref, shape), - self._mailbox, - ) - - def __repr__(self) -> str: - return f"ActorMeshRef(class={self._class}, shape={self._actor_mesh_ref._shape})" - - -class ActorError(Exception): - """ - Deterministic problem with the user's code. - For example, an OOM resulting in trying to allocate too much GPU memory, or violating - some invariant enforced by the various APIs. - """ - - def __init__( - self, - exception: Exception, - message: str = "A remote actor call has failed asynchronously.", - ) -> None: - self.exception = exception - self.actor_mesh_ref_frames: StackSummary = extract_tb(exception.__traceback__) - self.message = message - - def __str__(self) -> str: - exe = str(self.exception) - actor_mesh_ref_tb = "".join(traceback.format_list(self.actor_mesh_ref_frames)) - return ( - f"{self.message}\n" - f"Traceback of where the remote call failed (most recent call last):\n{actor_mesh_ref_tb}{type(self.exception).__name__}: {exe}" - ) - - -def current_actor_name() -> str: - return str(MonarchContext.get().mailbox.actor_id) - - -def current_rank() -> Point: - ctx = MonarchContext.get() - return ctx.point - -def current_size() -> Dict[str, int]: - ctx = MonarchContext.get() - return dict(zip(ctx.point.shape.labels, ctx.point.shape.ndslice.sizes)) +from monarch._src.actor.actor_mesh import * # noqa diff --git a/python/monarch/bootstrap_main.py b/python/monarch/bootstrap_main.py index 77befe9b..5fa77b45 100644 --- a/python/monarch/bootstrap_main.py +++ b/python/monarch/bootstrap_main.py @@ 
-4,56 +4,17 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -""" -This is the main function for the boostrapping a new process using a ProcessAllocator. -""" +import warnings -import asyncio -import importlib.resources -import logging -import os -import sys +warnings.warn( + "monarch.bootstrap_main is deprecated, please use from monarch._src.actor.bootstrap_main instead.", + DeprecationWarning, + stacklevel=2, +) -# Import torch to avoid import-time races if a spawned actor tries to import torch. -import torch # noqa[F401] - - -async def main(): - from monarch._rust_bindings.monarch_hyperactor.bootstrap import bootstrap_main - - await bootstrap_main() - - -def invoke_main(): - # if this is invoked with the stdout piped somewhere, then print - # changes its buffering behavior. So we default to the standard - # behavior of std out as if it were a terminal. - sys.stdout.reconfigure(line_buffering=True) - global bootstrap_main - - # TODO: figure out what from worker_main.py we should reproduce here. - from monarch.telemetry import TracingForwarder - - if os.environ.get("MONARCH_ERROR_DURING_BOOTSTRAP_FOR_TESTING") == "1": - raise RuntimeError("Error during bootstrap for testing") - - # forward logs to rust tracing. Defaults to on. - if os.environ.get("MONARCH_PYTHON_LOG_TRACING", "1") == "1": - logging.root.addHandler(TracingForwarder(level=logging.DEBUG)) - - try: - with ( - importlib.resources.path("monarch", "py-spy") as pyspy, - ): - if pyspy.exists(): - os.environ["PYSPY_BIN"] = str(pyspy) - # fallback to using local py-spy - except Exception as e: - logging.warning(f"Failed to set up py-spy: {e}") - - # Start an event loop for PythonActors to use. 
- asyncio.run(main()) +from monarch._src.actor.bootstrap_main import * # noqa if __name__ == "__main__": + # noqa invoke_main() # pragma: no cover diff --git a/python/monarch/common/client.py b/python/monarch/common/client.py index bfc6ca73..c9bf546f 100644 --- a/python/monarch/common/client.py +++ b/python/monarch/common/client.py @@ -37,6 +37,7 @@ LogLevel, WorldState, ) +from monarch._src.actor.shape import NDSlice from monarch.common import messages from monarch.common.borrows import Borrow, StorageAliases from monarch.common.controller_api import LogMessage, MessageResult, TController @@ -47,7 +48,6 @@ from monarch.common.recording import flatten_messages, Recording from monarch.common.reference import Ref, Referenceable -from monarch.common.shape import NDSlice from monarch.common.stream import StreamRef from monarch.common.tensor import Tensor from monarch.common.tree import tree_map diff --git a/python/monarch/common/controller_api.py b/python/monarch/common/controller_api.py index 2baa271c..b0a42e84 100644 --- a/python/monarch/common/controller_api.py +++ b/python/monarch/common/controller_api.py @@ -13,9 +13,10 @@ WorldState, ) +from monarch._src.actor.shape import NDSlice + from monarch.common.invocation import DeviceException, RemoteException, Seq from monarch.common.reference import Ref -from monarch.common.shape import NDSlice from monarch.common.tensor import Tensor diff --git a/python/monarch/common/device_mesh.py b/python/monarch/common/device_mesh.py index da02da05..86efc9e2 100644 --- a/python/monarch/common/device_mesh.py +++ b/python/monarch/common/device_mesh.py @@ -28,7 +28,7 @@ import monarch.common.messages as messages import torch -from monarch.common.shape import MeshTrait +from monarch._src.actor.shape import MeshTrait, NDSlice, Shape from torch.utils._python_dispatch import TorchDispatchMode from torch.utils._pytree import tree_map @@ -37,7 +37,6 @@ from .context_manager import activate_first_context_manager from .messages import Dims 
from .reference import Referenceable -from .shape import NDSlice, Shape from .stream import Stream from .tensor import MeshSliceTensor, Tensor diff --git a/python/monarch/common/messages.py b/python/monarch/common/messages.py index 5cfe6604..5e600912 100644 --- a/python/monarch/common/messages.py +++ b/python/monarch/common/messages.py @@ -22,13 +22,14 @@ ) from monarch._rust_bindings.monarch_extension import tensor_worker + +from monarch._src.actor.shape import NDSlice from monarch.common.function import ResolvableFromCloudpickle, ResolvableFunction from monarch.common.invocation import DeviceException, RemoteException from monarch.common.reference import Referenceable from monarch.common.tree import flattener from pyre_extensions import none_throws -from .shape import NDSlice from .tensor_factory import TensorFactory if TYPE_CHECKING: diff --git a/python/monarch/common/recording.py b/python/monarch/common/recording.py index b1a9e68d..4417ad90 100644 --- a/python/monarch/common/recording.py +++ b/python/monarch/common/recording.py @@ -10,9 +10,9 @@ from collections import defaultdict from typing import cast, Dict, Generator, List, NamedTuple, Tuple, TYPE_CHECKING, Union -from monarch.common.reference import Ref +from monarch._src.actor.shape import iter_ranks -from monarch.common.shape import iter_ranks +from monarch.common.reference import Ref from monarch.common.tensor import InputChecker @@ -21,8 +21,9 @@ if TYPE_CHECKING: from monarch.common.client import Client +from monarch._src.actor.shape import NDSlice + from .reference import Referenceable -from .shape import NDSlice from .tensor import Tensor logger = logging.getLogger(__name__) diff --git a/python/monarch/common/tensor.py b/python/monarch/common/tensor.py index c379f4af..29ebd6b9 100644 --- a/python/monarch/common/tensor.py +++ b/python/monarch/common/tensor.py @@ -40,12 +40,13 @@ if TYPE_CHECKING: from monarch.common.device_mesh import DeviceMesh +from monarch._src.actor.shape import NDSlice + from 
.fake import fake_call from .function import Propagator, ResolvableFunction from .invocation import Invocation from .messages import Dims from .reference import Referenceable -from .shape import NDSlice from .stream import Stream from .tree import flatten diff --git a/python/monarch/controller/backend.py b/python/monarch/controller/backend.py index 1b4e02b1..09b577cc 100644 --- a/python/monarch/controller/backend.py +++ b/python/monarch/controller/backend.py @@ -13,9 +13,9 @@ from abc import ABC, abstractmethod from typing import List, NamedTuple, Optional, Sequence, Tuple -from monarch.common import messages +from monarch._src.actor.shape import iter_ranks, Slices as Ranks -from monarch.common.shape import iter_ranks, Slices as Ranks +from monarch.common import messages from monarch_supervisor import ( Context, FunctionCall, diff --git a/python/monarch/controller/controller.py b/python/monarch/controller/controller.py index 1d669b35..d96b2ac5 100644 --- a/python/monarch/controller/controller.py +++ b/python/monarch/controller/controller.py @@ -19,11 +19,12 @@ ActorId, ) +from monarch._src.actor.shape import NDSlice + from monarch.common import messages from monarch.common.controller_api import LogMessage, MessageResult from monarch.common.invocation import DeviceException, Seq from monarch.common.reference import Ref -from monarch.common.shape import NDSlice from monarch.common.tensor import Tensor from monarch.controller import debugger diff --git a/python/monarch/controller/rust_backend/controller.py b/python/monarch/controller/rust_backend/controller.py index f29189ba..f6dfecf4 100644 --- a/python/monarch/controller/rust_backend/controller.py +++ b/python/monarch/controller/rust_backend/controller.py @@ -29,11 +29,12 @@ ) from monarch._rust_bindings.monarch_messages.debugger import DebuggerAction + +from monarch._src.actor.shape import NDSlice from monarch.common.controller_api import LogMessage, MessageResult from monarch.common.device_mesh import no_mesh from 
monarch.common.invocation import DeviceException, RemoteException from monarch.common.messages import SupportsToRustMessage -from monarch.common.shape import NDSlice from monarch.common.tensor import Tensor from monarch.controller.debugger import read as debugger_read, write as debugger_write from pyre_extensions import none_throws diff --git a/python/monarch/mesh_controller.py b/python/monarch/mesh_controller.py index a21f3b76..78c7739b 100644 --- a/python/monarch/mesh_controller.py +++ b/python/monarch/mesh_controller.py @@ -34,11 +34,11 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) +from monarch._src.actor.shape import NDSlice from monarch.actor_mesh import Port, PortTuple from monarch.common import messages from monarch.common.controller_api import TController from monarch.common.invocation import Seq -from monarch.common.shape import NDSlice from monarch.common.stream import StreamRef from monarch.common.tensor import Tensor @@ -48,7 +48,7 @@ from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ( ProcMesh as HyProcMesh, ) - from monarch.proc_mesh import ProcMesh + from monarch.actor import ProcMesh from monarch._rust_bindings.monarch_hyperactor.shape import Point diff --git a/python/monarch/proc_mesh.py b/python/monarch/proc_mesh.py index ddf1957c..5523a2bf 100644 --- a/python/monarch/proc_mesh.py +++ b/python/monarch/proc_mesh.py @@ -4,356 +4,12 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
-# pyre-strict +import warnings -import os -import sys -from contextlib import AbstractContextManager - -from typing import ( - Any, - cast, - Dict, - List, - Optional, - Sequence, - Type, - TYPE_CHECKING, - TypeVar, -) - -if TYPE_CHECKING: - import torch - -import monarch -from monarch import ActorFuture as Future - -# Conditionally import DeviceMesh and spawn_tensor_engine only if tensor_engine is available -# pyre-ignore[21] -from monarch._rust_bindings import has_tensor_engine - -from monarch._rust_bindings.hyperactor_extension.alloc import ( # @manual=//monarch/monarch_extension:monarch_extension # @manual=//monarch/monarch_extension:monarch_extension - Alloc, - AllocConstraints, - AllocSpec, +warnings.warn( + "monarch.proc_mesh is deprecated, please import from monarch.actor instead.", + DeprecationWarning, + stacklevel=2, ) -from monarch._rust_bindings.monarch_hyperactor.mailbox import Mailbox -from monarch._rust_bindings.monarch_hyperactor.proc_mesh import ( - ProcMesh as HyProcMesh, - ProcMeshMonitor, -) -from monarch._rust_bindings.monarch_hyperactor.shape import Shape, Slice -from monarch.actor_mesh import _Actor, _ActorMeshRefImpl, Actor, ActorMeshRef -from monarch.code_sync import RsyncMeshClient, WorkspaceLocation -from monarch.code_sync.auto_reload import AutoReloadActor -from monarch.common._device_utils import _local_device_count -from monarch.common.shape import MeshTrait -from monarch.rdma import RDMAManager - -if has_tensor_engine(): - from monarch.common.device_mesh import DeviceMesh - from monarch.mesh_controller import spawn_tensor_engine -else: - DeviceMesh = None - spawn_tensor_engine = None - -T = TypeVar("T") -try: - from __manifest__ import fbmake # noqa - - IN_PAR = True -except ImportError: - IN_PAR = False - - -async def _allocate_nonblocking(alloc: Alloc) -> "ProcMesh": - return ProcMesh(await HyProcMesh.allocate_nonblocking(alloc)) - - -def _allocate_blocking(alloc: Alloc) -> "ProcMesh": - return 
ProcMesh(HyProcMesh.allocate_blocking(alloc)) - - -class ProcMesh(MeshTrait): - def __init__( - self, - hy_proc_mesh: HyProcMesh, - _mock_shape: Optional[Shape] = None, - _device_mesh: Optional[DeviceMesh] = None, - ) -> None: - self._proc_mesh = hy_proc_mesh - self._mock_shape: Optional[Shape] = _mock_shape - self._mailbox: Mailbox = self._proc_mesh.client - self._rdma_manager: Optional[RDMAManager] = None - self._rsync_mesh_client: Optional[RsyncMeshClient] = None - self._auto_reload_actor: Optional[AutoReloadActor] = None - self._maybe_device_mesh: Optional[DeviceMesh] = _device_mesh - self._stopped = False - if _mock_shape is None: - self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager) - - @property - def _shape(self) -> Shape: - return self._proc_mesh.shape if self._mock_shape is None else self._mock_shape - - @property - def _ndslice(self) -> Slice: - return self._shape.ndslice - - @property - def _labels(self) -> List[str]: - return self._shape.labels - - def _new_with_shape(self, shape: Shape) -> "ProcMesh": - device_mesh = ( - None - if self._device_mesh is None - else self._device_mesh._new_with_shape(shape) - ) - return ProcMesh(self._proc_mesh, _mock_shape=shape, _device_mesh=device_mesh) - - def spawn( - self, name: str, Class: Type[T], *args: Any, **kwargs: Any - ) -> Future[ActorMeshRef[T]]: - if self._mock_shape is not None: - raise NotImplementedError("NYI: spawn on slice of a proc mesh.") - return Future( - lambda: self._spawn_nonblocking(name, Class, *args, **kwargs), - lambda: self._spawn_blocking(name, Class, *args, **kwargs), - ) - - async def monitor(self) -> ProcMeshMonitor: - """ - Get a monitor (async iterator) of the proc mesh, it is used to - monitor the status of the proc mesh. This function can be called at most once. - - Note: This API is experimental and subject to change. 
- - Example: - - async def monitor_loop(monitor): - async for event in monitor: - await handle_exception_event(event) - - # Kick off in background - asyncio.create_task(monitor_loop(monitor)) - """ - return await self._proc_mesh.monitor() - - @classmethod - def from_alloc(self, alloc: Alloc) -> Future["ProcMesh"]: - return Future( - lambda: _allocate_nonblocking(alloc), - lambda: _allocate_blocking(alloc), - ) - - def _spawn_blocking( - self, name: str, Class: Type[T], *args: Any, **kwargs: Any - ) -> T: - if not issubclass(Class, Actor): - raise ValueError( - f"{Class} must subclass monarch.service.Actor to spawn it." - ) - - actor_mesh = self._proc_mesh.spawn_blocking(name, _Actor) - service = ActorMeshRef( - Class, - _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh), - self._mailbox, - ) - # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by - # doing `ActorMeshRef(Class, actor_handle)` but not calling _create. - service._create(args, kwargs) - return cast(T, service) - - def __repr__(self) -> str: - return repr(self._proc_mesh) - - def __str__(self) -> str: - return str(self._proc_mesh) - - async def _spawn_nonblocking( - self, name: str, Class: Type[T], *args: Any, **kwargs: Any - ) -> T: - if not issubclass(Class, Actor): - raise ValueError( - f"{Class} must subclass monarch.service.Actor to spawn it." - ) - - actor_mesh = await self._proc_mesh.spawn_nonblocking(name, _Actor) - service = ActorMeshRef( - Class, - _ActorMeshRefImpl.from_hyperactor_mesh(self._mailbox, actor_mesh), - self._mailbox, - ) - # useful to have this separate, because eventually we can reconstitute ActorMeshRef objects across pickling by - # doing `ActorMeshRef(Class, actor_handle)` but not calling _create. 
- service._create(args, kwargs) - return cast(T, service) - - @property - def _device_mesh(self) -> "DeviceMesh": - if spawn_tensor_engine is None: - raise RuntimeError( - "DeviceMesh is not available because tensor_engine was not compiled (USE_TENSOR_ENGINE=0)" - ) - if self._maybe_device_mesh is None: - if self._mock_shape is not None: - raise NotImplementedError( - "NYI: activating a proc mesh must first happen on the root proc_mesh until we fix spawning on submeshes." - ) - self._maybe_device_mesh = spawn_tensor_engine(self) - return self._maybe_device_mesh - - # pyre-ignore - def activate(self) -> AbstractContextManager: - return self._device_mesh.activate() - - def rank_tensor(self, dim: str | Sequence[str]) -> "torch.Tensor": - return self._device_mesh.rank(dim) - - def rank_tensors(self) -> Dict[str, "torch.Tensor"]: - return self._device_mesh.ranks - - async def sync_workspace(self, auto_reload: bool = False) -> None: - if self._rsync_mesh_client is None: - # TODO(agallagher): We need some way to configure and pass this - # in -- right now we're assuming the `gpu` dimension, which isn't - # correct. - assert set(self._proc_mesh.shape.labels).issubset({"gpus", "hosts"}) - # The workspace shape (i.e. only perform one rsync per host). - workspace_shape = self.slice(gpus=slice(0, 1, 1))._mock_shape - assert workspace_shape is not None - # TODO(agallagher): We should probably hide this behind something - # like a `Workspace` class and support abstracting/configuring - # different sync methods. - self._rsync_mesh_client = RsyncMeshClient.spawn_blocking( - proc_mesh=self._proc_mesh, - shape=workspace_shape, - # TODO(agallagher): Is there a better way to infer/set the local - # workspace dir, rather than use PWD? 
- local_workspace=os.getcwd(), - remote_workspace=WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"), - ) - self._auto_reload_actor = self._spawn_blocking( - "auto_reload", - AutoReloadActor, - WorkspaceLocation.FromEnvVar("WORKSPACE_DIR"), - ) - assert self._rsync_mesh_client is not None - await self._rsync_mesh_client.sync_workspace() - if auto_reload: - assert self._auto_reload_actor is not None - await self._auto_reload_actor.reload.call() - - async def stop(self) -> None: - await self._proc_mesh.stop() - self._stopped = True - - async def __aenter__(self) -> "ProcMesh": - if self._stopped: - raise RuntimeError("`ProcMesh` has already been stopped") - return self - - async def __aexit__( - self, exc_type: object, exc_val: object, exc_tb: object - ) -> None: - # In case there are multiple nested "async with" statements, we only - # want it to close once. - if not self._stopped: - await self.stop() - - # Finalizer to check if the proc mesh was closed properly. - def __del__(self) -> None: - if not self._stopped: - import warnings - - warnings.warn( - f"unstopped ProcMesh {self!r}", - ResourceWarning, - stacklevel=2, - source=self, - ) - # Cannot call stop here because it is async. 
- - -async def local_proc_mesh_nonblocking( - *, gpus: Optional[int] = None, hosts: int = 1 -) -> ProcMesh: - if gpus is None: - gpus = _local_device_count() - spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) - allocator = monarch.LocalAllocator() - alloc = await allocator.allocate(spec) - return await ProcMesh.from_alloc(alloc) - - -def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh: - if gpus is None: - gpus = _local_device_count() - spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) - allocator = monarch.LocalAllocator() - alloc = allocator.allocate(spec).get() - return ProcMesh.from_alloc(alloc).get() - - -def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]: - return Future( - lambda: local_proc_mesh_nonblocking(gpus=gpus, hosts=hosts), - lambda: local_proc_mesh_blocking(gpus=gpus, hosts=hosts), - ) - - -_BOOTSTRAP_MAIN = "monarch.bootstrap_main" - - -def _get_bootstrap_args() -> tuple[str, Optional[list[str]], dict[str, str]]: - if IN_PAR: - cmd = sys.argv[0] - args = None - env = { - "PAR_MAIN_OVERRIDE": _BOOTSTRAP_MAIN, - } - else: - cmd = sys.executable - args = ["-m", _BOOTSTRAP_MAIN] - env = {} - - return cmd, args, env - - -async def proc_mesh_nonblocking( - *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None -) -> ProcMesh: - if gpus is None: - gpus = _local_device_count() - spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) - env = env or {} - cmd, args, base_env = _get_bootstrap_args() - env.update(base_env) - allocator = monarch.ProcessAllocator(cmd, args, env) - alloc = await allocator.allocate(spec) - return await ProcMesh.from_alloc(alloc) - - -def proc_mesh_blocking( - *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None -) -> ProcMesh: - if gpus is None: - gpus = _local_device_count() - spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) - env = env or {} - cmd, args, 
base_env = _get_bootstrap_args() - env.update(base_env) - allocator = monarch.ProcessAllocator(cmd, args, env) - alloc = allocator.allocate(spec).get() - return ProcMesh.from_alloc(alloc).get() - -def proc_mesh( - *, gpus: Optional[int] = None, hosts: int = 1, env: Optional[dict[str, str]] = None -) -> Future[ProcMesh]: - return Future( - lambda: proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env), - lambda: proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env), - ) +from monarch._src.actor.proc_mesh import * # noqa diff --git a/python/monarch/python_local_mesh.py b/python/monarch/python_local_mesh.py index 39c56709..c3de5980 100644 --- a/python/monarch/python_local_mesh.py +++ b/python/monarch/python_local_mesh.py @@ -11,7 +11,7 @@ from typing import Optional, TYPE_CHECKING import monarch_supervisor -from monarch.common._device_utils import _local_device_count +from monarch._src.actor.device_utils import _local_device_count from monarch.common.fake import fake_call from monarch.common.invocation import DeviceException, RemoteException from monarch.world_mesh import world_mesh diff --git a/python/monarch/rdma.py b/python/monarch/rdma.py index 3ddbe855..b9cc771a 100644 --- a/python/monarch/rdma.py +++ b/python/monarch/rdma.py @@ -12,8 +12,7 @@ import torch from monarch._rust_bindings.monarch_hyperactor.proc import ActorId - -from monarch.actor_mesh import ( +from monarch._src.actor.actor_mesh import ( _ActorMeshRefImpl, Actor, ActorMeshRef, diff --git a/python/monarch/rust_backend_mesh.py b/python/monarch/rust_backend_mesh.py index 14ae325f..665ccb2b 100644 --- a/python/monarch/rust_backend_mesh.py +++ b/python/monarch/rust_backend_mesh.py @@ -20,11 +20,12 @@ init_proc, Proc, ) + +from monarch._src.actor.shape import NDSlice from monarch.common.client import Client from monarch.common.device_mesh import DeviceMesh, DeviceMeshStatus from monarch.common.invocation import DeviceException, RemoteException from monarch.common.mast import MastJob -from 
monarch.common.shape import NDSlice from monarch.controller.rust_backend.controller import RustController TORCHX_MAST_TASK_GROUP_NAME = "script" diff --git a/python/monarch/sim_mesh.py b/python/monarch/sim_mesh.py index 008cc165..b91361bf 100644 --- a/python/monarch/sim_mesh.py +++ b/python/monarch/sim_mesh.py @@ -40,6 +40,8 @@ init_proc, Proc, ) + +from monarch._src.actor.shape import NDSlice from monarch.common.client import Client from monarch.common.constants import ( SIM_MESH_CLIENT_SUPERVISION_UPDATE_INTERVAL, @@ -50,7 +52,6 @@ from monarch.common.future import Future, T from monarch.common.invocation import DeviceException, RemoteException from monarch.common.messages import Dims -from monarch.common.shape import NDSlice from monarch.controller.rust_backend.controller import RustController from monarch.rust_backend_mesh import MeshWorld diff --git a/python/monarch/simulator/command_history.py b/python/monarch/simulator/command_history.py index 42d53bfe..30755579 100644 --- a/python/monarch/simulator/command_history.py +++ b/python/monarch/simulator/command_history.py @@ -12,9 +12,9 @@ from typing import List, NamedTuple, Optional, Sequence import torch +from monarch._src.actor.shape import NDSlice from monarch.common import messages -from monarch.common.shape import NDSlice from monarch.simulator.ir import IRGraph from monarch.simulator.tensor import DTensorRef from monarch.simulator.utils import clean_name, file_path_with_iter diff --git a/python/monarch/simulator/interface.py b/python/monarch/simulator/interface.py index 59420c71..96b779b3 100644 --- a/python/monarch/simulator/interface.py +++ b/python/monarch/simulator/interface.py @@ -6,9 +6,10 @@ from typing import Union +from monarch._src.actor.shape import NDSlice + from monarch.common.client import Client as _Client from monarch.common.device_mesh import DeviceMesh -from monarch.common.shape import NDSlice from monarch.simulator.ir import IRGraph from monarch.simulator.simulator import ( diff --git 
a/python/monarch/simulator/mock_controller.py b/python/monarch/simulator/mock_controller.py index 4a9d9fea..605ac58b 100644 --- a/python/monarch/simulator/mock_controller.py +++ b/python/monarch/simulator/mock_controller.py @@ -25,6 +25,7 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) +from monarch._src.actor.shape import iter_ranks, NDSlice, Slices as Ranks from monarch.common import messages @@ -32,7 +33,6 @@ from monarch.common.device_mesh import no_mesh from monarch.common.invocation import Invocation, RemoteException, Seq from monarch.common.reference import Ref -from monarch.common.shape import iter_ranks, NDSlice, Slices as Ranks from monarch.common.tree import flatten if TYPE_CHECKING: diff --git a/python/monarch/simulator/simulator.py b/python/monarch/simulator/simulator.py index 5e4083b9..9a7c79c9 100644 --- a/python/monarch/simulator/simulator.py +++ b/python/monarch/simulator/simulator.py @@ -43,12 +43,12 @@ from monarch._rust_bindings.monarch_hyperactor.proc import ( # @manual=//monarch/monarch_extension:monarch_extension ActorId, ) +from monarch._src.actor.shape import iter_ranks, NDSlice from monarch.common import messages from monarch.common.controller_api import LogMessage, MessageResult from monarch.common.device_mesh import DeviceMesh from monarch.common.function import ResolvableFunction, ResolvableFunctionFromPath from monarch.common.invocation import DeviceException -from monarch.common.shape import iter_ranks, NDSlice from monarch.simulator.command_history import CommandHistory, DTensorRef from monarch.simulator.config import META_VAL from monarch.simulator.ir import IRGraph diff --git a/python/monarch/worker/worker.py b/python/monarch/worker/worker.py index 3eb62c23..89b48924 100644 --- a/python/monarch/worker/worker.py +++ b/python/monarch/worker/worker.py @@ -37,13 +37,13 @@ import torch.fx import zmq import zmq.asyncio +from monarch._src.actor.shape import 
NDSlice from monarch.common import messages from monarch.common.function import ResolvableFunction from monarch.common.messages import DependentOnError, Dims from monarch.common.process_group import SingleControllerProcessGroupWrapper from monarch.common.reference import Ref, Referenceable -from monarch.common.shape import NDSlice from monarch.common.tensor_factory import TensorFactory from monarch.common.tree import flatten, flattener from monarch_supervisor import get_message_queue, Letter diff --git a/python/monarch/world_mesh.py b/python/monarch/world_mesh.py index b8639c83..52698eb9 100644 --- a/python/monarch/world_mesh.py +++ b/python/monarch/world_mesh.py @@ -8,10 +8,11 @@ from typing import List +from monarch._src.actor.shape import NDSlice + from monarch.common.client import Client from monarch.common.device_mesh import DeviceMesh -from monarch.common.shape import NDSlice from monarch.controller.backend import ProcessBackend diff --git a/python/tests/code_sync/test_auto_reload.py b/python/tests/code_sync/test_auto_reload.py index 98916f7c..4a85b88b 100644 --- a/python/tests/code_sync/test_auto_reload.py +++ b/python/tests/code_sync/test_auto_reload.py @@ -17,7 +17,7 @@ import pytest -from monarch.code_sync.auto_reload import AutoReloader, SysAuditImportHook +from monarch._src.actor.code_sync.auto_reload import AutoReloader, SysAuditImportHook def write_text(path: Path, content: str): diff --git a/python/tests/test_allocator.py b/python/tests/test_allocator.py index 5539fe73..7de4b3b8 100644 --- a/python/tests/test_allocator.py +++ b/python/tests/test_allocator.py @@ -32,13 +32,13 @@ ChannelAddr, ChannelTransport, ) -from monarch.actor_mesh import Actor, current_rank, current_size, endpoint, ValueMesh -from monarch.allocator import ( +from monarch._src.actor.allocator import ( ALLOC_LABEL_PROC_MESH_NAME, RemoteAllocator, StaticRemoteAllocInitializer, TorchXRemoteAllocInitializer, ) +from monarch.actor_mesh import Actor, current_rank, current_size, 
endpoint, ValueMesh from monarch.proc_mesh import ProcMesh from monarch.tools.mesh_spec import MeshSpec, ServerSpec from monarch.tools.network import get_sockaddr diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index b0124822..c8052a80 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -17,20 +17,20 @@ import pytest import torch +from monarch._src.actor.debugger import init_debugging -from monarch.actor_mesh import ( +from monarch._src.actor.proc_mesh import local_proc_mesh +from monarch.actor import ( Accumulator, Actor, current_actor_name, current_rank, current_size, endpoint, + Future, MonarchContext, + proc_mesh, ) -from monarch.debugger import init_debugging -from monarch.future import ActorFuture - -from monarch.proc_mesh import local_proc_mesh, proc_mesh from monarch.rdma import RDMABuffer needs_cuda = pytest.mark.skipif( @@ -469,9 +469,9 @@ def _patch_output(msg): nonlocal outputs outputs.append(msg) - with patch("monarch.debugger._debugger_input", side_effect=input_mock), patch( - "monarch.debugger._debugger_output", new=_patch_output - ): + with patch( + "monarch._src.actor.debugger._debugger_input", side_effect=input_mock + ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): proc = await proc_mesh(hosts=2, gpus=2) debugee = await proc.spawn("debugee", DebugeeActor) debug_client = await init_debugging(debugee) @@ -545,7 +545,7 @@ def _patch_output(msg): breakpoints = await debug_client.list.call_one() assert len(breakpoints) == 0 - with pytest.raises(monarch.actor_mesh.ActorError, match="ValueError: bad rank"): + with pytest.raises(monarch.actor.ActorError, match="ValueError: bad rank"): await fut @@ -654,13 +654,13 @@ async def incr(): # can use async implementation from sync # if no non-blocking is provided - f = ActorFuture(incr) + f = Future(incr) assert f.get() == 1 assert v == 1 assert f.get() == 1 assert asyncio.run(awaitit(f)) == 1 - f = 
ActorFuture(incr) + f = Future(incr) assert asyncio.run(awaitit(f)) == 2 assert f.get() == 2 @@ -670,7 +670,7 @@ def incr2(): return v # Use non-blocking optimization if provided - f = ActorFuture(incr, incr2) + f = Future(incr, incr2) assert f.get() == 4 assert asyncio.run(awaitit(f)) == 4 @@ -679,7 +679,7 @@ async def nope(): v += 1 raise ValueError("nope") - f = ActorFuture(nope) + f = Future(nope) with pytest.raises(ValueError): f.get() @@ -701,7 +701,7 @@ def nope(): v += 1 raise ValueError("nope") - f = ActorFuture(incr, nope) + f = Future(incr, nope) with pytest.raises(ValueError): f.get() @@ -723,7 +723,7 @@ def nope(): async def seven(): return 7 - f = ActorFuture(seven) + f = Future(seven) assert 7 == f.get(timeout=0.001) @@ -731,7 +731,7 @@ async def neverfinish(): f = asyncio.Future() await f - f = ActorFuture(neverfinish) + f = Future(neverfinish) with pytest.raises(asyncio.exceptions.TimeoutError): f.get(timeout=0.1)