diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index f22c9646..49884121 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -12,7 +12,6 @@ import inspect import logging import random -import sys import traceback from dataclasses import dataclass @@ -53,15 +52,12 @@ from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.future import Future -from monarch._src.actor.pdb_wrapper import remote_breakpointhook +from monarch._src.actor.pdb_wrapper import PdbWrapper from monarch._src.actor.pickle import flatten, unpickle from monarch._src.actor.shape import MeshTrait, NDSlice -if TYPE_CHECKING: - from monarch._src.actor.debugger import DebugClient - logger: logging.Logger = logging.getLogger(__name__) Allocator = ProcessAllocator | LocalAllocator @@ -97,6 +93,23 @@ def get() -> "MonarchContext": ) +@dataclass +class DebugContext: + pdb_wrapper: Optional[PdbWrapper] = None + + @staticmethod + def get() -> "DebugContext": + return _debug_context.get() + + @staticmethod + def set(debug_context: "DebugContext") -> None: + _debug_context.set(debug_context) + + +_debug_context: contextvars.ContextVar[DebugContext] = contextvars.ContextVar( + "monarch.actor_mesh._debug_context" +) + T = TypeVar("T") P = ParamSpec("P") R = TypeVar("R") @@ -538,6 +551,8 @@ async def handle_cast( ) _context.set(ctx) + DebugContext.set(DebugContext()) + args, kwargs = unpickle(message.message, mailbox) if message.method == "__init__": @@ -574,9 +589,10 @@ async def instrumented(): ) try: result = await the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() except Exception as e: logging.critical( - "Unahndled exception in actor endpoint", + "Unhandled exception in actor endpoint", exc_info=e, ) raise e @@ -589,11 +605,13 @@ async def instrumented(): the_method.__module__, message.method, str(ctx.mailbox.actor_id) ) result = the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() exit_span() if port is not None: port.send("result", result) except Exception as e: + self._post_mortem_debug(e.__traceback__) traceback.print_exc() s = ActorError(e) @@ -604,6 +622,7 @@ async def instrumented(): else: raise s from None except BaseException as e: + self._post_mortem_debug(e.__traceback__) # A BaseException can be thrown in the case of a Rust panic. # In this case, we need a way to signal the panic to the Rust side. # See [Panics in async endpoints] @@ -614,6 +633,29 @@ async def instrumented(): pass raise + def _maybe_exit_debugger(self, do_continue=True) -> None: + if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: + if do_continue: + pdb_wrapper.clear_all_breaks() + pdb_wrapper.do_continue("") + pdb_wrapper.end_debug_session() + DebugContext.set(DebugContext()) + + def _post_mortem_debug(self, exc_tb) -> None: + from monarch._src.actor.debugger import DebugManager + + if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: + ctx = MonarchContext.get() + pdb_wrapper = PdbWrapper( + ctx.point.rank, + ctx.point.shape.coordinates(ctx.point.rank), + ctx.mailbox.actor_id, + DebugManager.ref().get_debug_client.call_one().get(), + ) + DebugContext.set(DebugContext(pdb_wrapper)) + pdb_wrapper.post_mortem(exc_tb) + self._maybe_exit_debugger(do_continue=False) + def _is_mailbox(x: object) -> bool: return isinstance(x, Mailbox) @@ -648,19 +690,6 @@ def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": "actor implementations are not meshes, but we can't convince the typechecker of it..." ) - @endpoint # pyre-ignore - def _set_debug_client(self, client: "DebugClient") -> None: - point = MonarchContext.get().point - # For some reason, using a lambda instead of functools.partial - # confuses the pdb wrapper implementation. - sys.breakpointhook = functools.partial( # pyre-ignore - remote_breakpointhook, - point.rank, - point.shape.coordinates(point.rank), - MonarchContext.get().mailbox.actor_id, - client, - ) - class ActorMeshRef(MeshTrait, Generic[T]): def __init__( diff --git a/python/monarch/_src/actor/bootstrap_main.py b/python/monarch/_src/actor/bootstrap_main.py index 5b377ac2..b94a8a96 100644 --- a/python/monarch/_src/actor/bootstrap_main.py +++ b/python/monarch/_src/actor/bootstrap_main.py @@ -54,6 +54,10 @@ def invoke_main(): except Exception as e: logging.warning(f"Failed to set up py-spy: {e}") + from monarch._src.actor.debugger import remote_breakpointhook + + sys.breakpointhook = remote_breakpointhook + # Start an event loop for PythonActors to use. asyncio.run(main()) diff --git a/python/monarch/_src/actor/debugger.py b/python/monarch/_src/actor/debugger.py index 2906c63f..ed66601e 100644 --- a/python/monarch/_src/actor/debugger.py +++ b/python/monarch/_src/actor/debugger.py @@ -4,23 +4,32 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import asyncio +import functools +import inspect import logging +import os import sys from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from typing import cast, Dict, Generator, List, Tuple, Union from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._src.actor.actor_mesh import Actor, ActorMeshRef, endpoint -from monarch._src.actor.pdb_wrapper import DebuggerWrite -from monarch._src.actor.proc_mesh import local_proc_mesh +from monarch._src.actor.actor_mesh import ( + _ActorMeshRefImpl, + Actor, + ActorMeshRef, + DebugContext, + endpoint, + MonarchContext, +) +from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper from tabulate import tabulate logger = logging.getLogger(__name__) - -CANCEL_TOKEN = object() +_DEBUG_MANAGER_ACTOR_NAME = "debug_manager" async def _debugger_input(prompt=""): @@ -148,93 +157,170 @@ async def debugger_write(self, write: DebuggerWrite) -> None: await self._message_queue.put(("write", write)) +RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]] + + +_debug_input_parser = None + + +# Wrap the parser in a function so that jobs don't have to import lark +# unless they want to use the debugger. +def _get_debug_input_parser(): + global _debug_input_parser + if _debug_input_parser is None: + from lark import Lark + + _debug_input_parser = Lark( + """ + rank_list: INT "," INT ("," INT)* + start: INT? + stop: INT? + step: INT? + rank_range: start ":" stop (":" step)? + dim: CNAME "=" (rank_range | "(" rank_list ")" | INT) + dims: dim ("," dim)* + ranks: "ranks(" (dims | rank_range | rank_list | INT) ")" + pdb_command: /\\w+.*/ + cast: "cast" ranks pdb_command + help: "h" | "help" + attach: ("a" | "attach") INT + cont: "c" | "continue" + quit: "q" | "quit" + list: "l" | "list" + command: attach | list | cast | help | cont | quit + + %import common.INT + %import common.CNAME + %import common.WS + %ignore WS + """, + start="command", + ) + return _debug_input_parser + + +_debug_input_transformer = None + + +# Wrap the transformer in a function so that jobs don't have to import lark +# unless they want to use the debugger. +def _get_debug_input_transformer(): + global _debug_input_transformer + if _debug_input_transformer is None: + from lark import Transformer + from lark.lexer import Token + + class _IntoDebugCommandTransformer(Transformer): + def rank_list(self, items: List[Token]) -> List[int]: + return [int(item.value) for item in items] + + def start(self, items: List[Token]) -> int: + if len(items) == 0: + return 0 + return int(items[0].value) + + def stop(self, items: List[Token]) -> int: + if len(items) == 0: + return sys.maxsize + return int(items[0].value) + + def step(self, items: List[Token]) -> int: + if len(items) == 0: + return 1 + return int(items[0].value) + + def rank_range(self, items: List[int]) -> range: + return range(*items) + + def dim( + self, items: Tuple[Token, Union[range, List[int], Token]] + ) -> Tuple[str, Union[range, List[int], int]]: + if isinstance(items[1], range): + return (items[0].value, cast(range, items[1])) + elif isinstance(items[1], list): + return (items[0].value, cast(List[int], items[1])) + else: + return (items[0].value, int(cast(Token, items[1]).value)) + + def dims( + self, items: List[Tuple[str, Union[range, List[int], int]]] + ) -> Dict[str, Union[range, List[int], int]]: + return {dim[0]: dim[1] for dim in items} + + def ranks(self, items: List[Union[RanksType, Token]]) -> RanksType: + if isinstance(items[0], Token): + return int(cast(Token, items[0]).value) + return cast(RanksType, items[0]) + + def pdb_command(self, items: List[Token]) -> str: + return items[0].value + + def help(self, _items: List[Token]) -> "Help": + return Help() + + def attach(self, items: List[Token]) -> "Attach": + return Attach(int(items[0].value)) + + def cont(self, _items: List[Token]) -> "Continue": + return Continue() + + def quit(self, _items: List[Token]) -> "Quit": + return Quit() + + def cast(self, items: Tuple[RanksType, str]) -> "Cast": + return Cast(items[0], items[1]) + + def list(self, items: List[Token]) -> "ListCommand": + return ListCommand() + + def command(self, items: List["DebugCommand"]) -> "DebugCommand": + return items[0] + + _debug_input_transformer = _IntoDebugCommandTransformer() + return _debug_input_transformer + + class DebugCommand: @staticmethod def parse(line: str) -> Union["DebugCommand", None]: - parts = line.strip("\n").split(" ") - if len(parts) == 0: + try: + tree = _get_debug_input_parser().parse(line) + return _get_debug_input_transformer().transform(tree) + except Exception as e: + print(f"Error parsing input: {e}") return None - command = parts[0] - match command: - case "attach": - return Attach._parse(parts) - case "list": - return ListCommand() - case "quit": - return Quit() - case "cast": - return Cast._parse(parts) - case "help": - return Help() - case "continue": - return Continue() - case _: - print( - f"Unknown command {command}. Expected: attach | list | quit | cast | continue | help" - ) - return None @dataclass class Attach(DebugCommand): rank: int - @classmethod - def _parse(cls, parts: List[str]) -> "Attach": - if len(parts) != 2: - raise ValueError("Invalid attach command. Expected: attach ") - try: - rank = int(parts[1]) - except ValueError: - raise ValueError(f"Invalid rank {parts[1]}. Expected: int") - return cls(rank) - +@dataclass class ListCommand(DebugCommand): pass +@dataclass class Quit(DebugCommand): pass +@dataclass class Help(DebugCommand): pass +@dataclass class Continue(DebugCommand): pass @dataclass class Cast(DebugCommand): - ranks: List[int] | None + ranks: RanksType command: str - @classmethod - def _parse(cls, parts: List[str]) -> "Cast": - if len(parts) < 3: - raise ValueError( - "Invalid cast command. Expected: cast { | *} " - ) - str_ranks = parts[1] - command = " ".join(parts[2:]) - if str_ranks == "*": - return cls(None, command) - else: - str_ranks = str_ranks.split(",") - if len(str_ranks) == 0: - raise ValueError( - "Invalid rank list for cast. Expected at least one rank." - ) - ranks = [] - for rank in str_ranks: - try: - ranks.append(int(rank)) - except ValueError: - raise ValueError(f"Invalid rank {rank}. Expected: int") - return cls(ranks, command) - class DebugClient(Actor): """ @@ -253,10 +339,10 @@ async def wait_pending_session(self): @endpoint async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]]: - table_data = [] + session_info = [] for _, session in self.sessions.items(): info = session.get_info() - table_data.append( + session_info.append( ( info.rank, info.coords, @@ -266,17 +352,30 @@ async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]] info.lineno, ) ) - table_data = sorted(table_data, key=lambda r: r[0]) - - headers = ["Rank", "Coords", "Hostname", "Actor ID", "Function", "Line No."] - print(tabulate(table_data, headers=headers, tablefmt="grid")) - - return table_data + table_info = sorted(session_info, key=lambda r: r[0]) + print( + tabulate( + table_info, + headers=[ + "Rank", + "Coords", + "Hostname", + "Actor ID", + "Function", + "Line No.", + ], + tablefmt="grid", + ) + ) + return table_info @endpoint async def enter(self) -> None: - # pyre-ignore - await getattr(self, "list")._method(self) # noqa + await asyncio.sleep(0.5) + logger.info("Remote breakpoint hit. Entering monarch debugger...") + print("\n\n************************ MONARCH DEBUGGER ************************") + print("Enter 'help' for a list of commands.") + print("Enter 'list' to show all active breakpoints.\n") while True: try: @@ -288,7 +387,10 @@ async def enter(self) -> None: print("\tlist - list all debug sessions") print("\tquit - exit the debugger, leaving all sessions in place") print( - "\tcast { | *} - send a command to a comma-separated list of ranks, or all ranks" + "\tcast ranks(...) - send a command to a set of ranks.\n" + "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n" + "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n" + "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6)))." ) print( "\tcontinue - tell all ranks to continue execution, then exit the debugger" @@ -300,41 +402,75 @@ async def enter(self) -> None: else: await self.sessions[command.rank].attach() elif isinstance(command, ListCommand): - await getattr(self, "list")._method(self) # noqa + # pyre-ignore + await self.list._method(self) elif isinstance(command, Continue): - # Make sure all ranks have exited their debug sessions. - # If we sent "quit", it would raise BdbQuit, crashing - # the process, which probably isn't what we want. + # Clear all breakpoints and make sure all ranks have + # exited their debug sessions. If we sent "quit", it + # would raise BdbQuit, crashing the process, which + # probably isn't what we want. + await self._cast_input_and_wait("clear") while len(self.sessions) > 0: - tasks = [] - for rank in self.sessions: - tasks.append( - self.sessions[rank].attach("c", suppress_output=True) - ) - await asyncio.gather(*tasks) + await self._cast_input_and_wait("c") return elif isinstance(command, Quit): return elif isinstance(command, Cast): - if command.ranks is None: - ranks = self.sessions.keys() - else: - ranks = command.ranks - tasks = [] - for rank in ranks: - if rank in self.sessions: - tasks.append( - self.sessions[rank].attach( - command.command, - suppress_output=True, - ) - ) - else: - print(f"No debug session for rank {rank}") - await asyncio.gather(*tasks) + await self._cast_input_and_wait(command.command, command.ranks) except Exception as e: print(f"Error processing command: {e}") + async def _cast_input_and_wait( + self, + command: str, + ranks: RanksType | None = None, + ) -> None: + if ranks is None: + ranks = self.sessions.keys() + elif isinstance(ranks, dict): + ranks = self._iter_ranks_dict(ranks) + elif isinstance(ranks, range): + ranks = self._iter_ranks_range(ranks) + elif isinstance(ranks, int): + ranks = [ranks] + tasks = [] + for rank in ranks: + if rank in self.sessions: + tasks.append( + self.sessions[rank].attach( + command, + suppress_output=True, + ) + ) + else: + print(f"No debug session for rank {rank}") + await asyncio.gather(*tasks) + + def _iter_ranks_dict( + self, dims: Dict[str, Union[range, List[int], int]] + ) -> Generator[int, None, None]: + for rank, session in self.sessions.items(): + include_rank = True + for dim, ranks in dims.items(): + if dim not in session.coords: + include_rank = False + break + elif ( + isinstance(ranks, range) or isinstance(ranks, list) + ) and session.coords[dim] not in ranks: + include_rank = False + break + elif isinstance(ranks, int) and session.coords[dim] != ranks: + include_rank = False + break + if include_rank: + yield rank + + def _iter_ranks_range(self, rng: range) -> Generator[int, None, None]: + for rank in self.sessions.keys(): + if rank in rng: + yield rank + ########################################################################## # Debugger APIs # @@ -368,10 +504,57 @@ async def debugger_write(self, rank: int, write: DebuggerWrite) -> None: await session.debugger_write(write) -async def init_debugging( - actor_mesh: ActorMeshRef, -) -> ActorMeshRef[DebugClient]: - debugger_proc_mesh = await local_proc_mesh(gpus=1, hosts=1) - debug_client_mesh = await debugger_proc_mesh.spawn("debug_client", DebugClient) - await actor_mesh._set_debug_client.call(debug_client_mesh) - return debug_client_mesh +class DebugManager(Actor): + @staticmethod + @functools.cache + def ref() -> "DebugManager": + ctx = MonarchContext.get() + return cast( + DebugManager, + ActorMeshRef( + DebugManager, + _ActorMeshRefImpl.from_actor_id( + ctx.mailbox, + ActorId.from_string( + f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]" + ), + ), + ctx.mailbox, + ), + ) + + def __init__(self, debug_client: DebugClient) -> None: + self._debug_client = debug_client + + # pyre-ignore + @endpoint + def get_debug_client(self) -> DebugClient: + return self._debug_client + + +def remote_breakpointhook(): + frame = inspect.currentframe() + assert frame is not None + frame = frame.f_back + assert frame is not None + file = frame.f_code.co_filename + line = frame.f_lineno + module = frame.f_globals.get("__name__", "__main__") + if module == "__main__" and not os.path.exists(file): + raise NotImplementedError( + f"Remote debugging not supported for breakpoint at {file}:{line} because " + f"it is defined inside __main__, and the file does not exist on the host. " + "In this case, cloudpickle serialization does not interact nicely with pdb. " + "To debug your code, move it out of __main__ and into a module that " + "exists on both your client and worker processes." + ) + + ctx = MonarchContext.get() + pdb_wrapper = PdbWrapper( + ctx.point.rank, + ctx.point.shape.coordinates(ctx.point.rank), + ctx.mailbox.actor_id, + DebugManager.ref().get_debug_client.call_one().get(), + ) + DebugContext.set(DebugContext(pdb_wrapper)) + pdb_wrapper.set_trace(frame) diff --git a/python/monarch/_src/actor/pdb_wrapper.py b/python/monarch/_src/actor/pdb_wrapper.py index 87031dd4..a9cf56b7 100644 --- a/python/monarch/_src/actor/pdb_wrapper.py +++ b/python/monarch/_src/actor/pdb_wrapper.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import bdb import inspect import io @@ -45,35 +46,38 @@ def __init__( super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self)) self._first = True - def setup(self, *args, **kwargs): - r = super().setup(*args, **kwargs) - if self._first: - self._first = False - # when we enter the debugger, we want to present the user's stack frame - # not the nested one inside session.run. This means that the local - # variables are what gets printed, etc. To do this - # we first execute up 2 to get to that frame. - self.do_up(2) - return r - - def set_continue(self) -> None: - r = super().set_continue() - if not self.breaks: - # no more breakpoints so this debugger will not - # be used again, and we detach from the controller io. - self.client_ref.debugger_session_end.call_one(self.rank).get() - # break cycle with itself before we exit - self.stdin = sys.stdin - self.stdout = sys.stdout - return r - - def set_trace(self): + def set_trace(self, frame): self.client_ref.debugger_session_start.call_one( self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id ).get() if self.header: self.message(self.header) - super().set_trace() + super().set_trace(frame) + + def do_clear(self, arg): + if not arg: + # Sending `clear` without any argument specified will + # request confirmation from the user using the `input` function, + # which bypasses our ReadWrapper and causes a hang on the client. + # To avoid this, we just clear all breakpoints instead without + # confirmation. + super().clear_all_breaks() + else: + super().do_clear(arg) + + def end_debug_session(self): + self.client_ref.debugger_session_end.call_one(self.rank).get() + # Once the debug client actor is notified of the session being over, + # we need to prevent any additional requests being sent for the session + # by redirecting stdin and stdout. + self.stdin = sys.stdin + self.stdout = sys.stdout + + def post_mortem(self, exc_tb): + self._first = False + # See builtin implementation of pdb.post_mortem() for reference. + self.reset() + self.interaction(None, exc_tb) class ReadWrapper(io.RawIOBase): @@ -126,10 +130,3 @@ def write(self, s: str): def flush(self): pass - - -def remote_breakpointhook( - rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient" -): - ds = PdbWrapper(rank, coords, actor_id, client_ref) - ds.set_trace() diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 483ddf00..0df67df9 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -38,6 +38,11 @@ from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.code_sync import RsyncMeshClient, WorkspaceLocation from monarch._src.actor.code_sync.auto_reload import AutoReloadActor +from monarch._src.actor.debugger import ( + _DEBUG_MANAGER_ACTOR_NAME, + DebugClient, + DebugManager, +) from monarch._src.actor.device_utils import _local_device_count from monarch._src.actor.future import Future @@ -83,11 +88,13 @@ def __init__( hy_proc_mesh: HyProcMesh, _mock_shape: Optional[Shape] = None, _device_mesh: Optional["DeviceMesh"] = None, + _is_initializing_debugger: bool = False, ) -> None: self._proc_mesh = hy_proc_mesh self._mock_shape: Optional[Shape] = _mock_shape # type: ignore[21] self._rdma_manager: Optional["RDMAManager"] = None + self._debug_manager: Optional[DebugManager] = None self._mailbox: Mailbox = self._proc_mesh.client self._rsync_mesh_client: Optional[RsyncMeshClient] = None self._auto_reload_actor: Optional[AutoReloadActor] = None @@ -96,6 +103,10 @@ def __init__( if _mock_shape is None and HAS_TENSOR_ENGINE: # type: ignore[21] self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager) + if not _is_initializing_debugger: + self._debug_manager = self._spawn_blocking( + _DEBUG_MANAGER_ACTOR_NAME, DebugManager, debug_client() + ) @property def _shape(self) -> Shape: @@ -296,13 +307,21 @@ async def local_proc_mesh_nonblocking( return await ProcMesh.from_alloc(alloc) -def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh: +def local_proc_mesh_blocking( + *, + gpus: Optional[int] = None, + hosts: int = 1, + _is_initializing_debugger: bool = False, +) -> ProcMesh: if gpus is None: gpus = _local_device_count() spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) allocator = LocalAllocator() alloc = allocator.allocate(spec).get() - return ProcMesh.from_alloc(alloc).get() + return ProcMesh( + HyProcMesh.allocate_blocking(alloc), + _is_initializing_debugger=_is_initializing_debugger, + ) def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]: @@ -371,3 +390,33 @@ def proc_mesh( lambda: proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env), lambda: proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env), ) + + +_debug_proc_mesh: Optional["ProcMesh"] = None + + +# Lazy init of the debug proc mesh so that importing monarch.proc_mesh +# doesn't trigger the debug client to spawn, which could cause confusing +# logs. This is defined in proc_mesh.py instead of debugger.py for +# circular import reasons. +def _get_debug_proc_mesh() -> "ProcMesh": + global _debug_proc_mesh + if _debug_proc_mesh is None: + _debug_proc_mesh = local_proc_mesh_blocking( + gpus=1, hosts=1, _is_initializing_debugger=True + ) + return _debug_proc_mesh + + +_debug_client_mesh: Optional[ActorMeshRef[DebugClient]] = None + + +# Lazy init for the same reason as above. This is defined in proc_mesh.py +# instead of debugger.py for circular import reasons. +def debug_client() -> ActorMeshRef[DebugClient]: + global _debug_client_mesh + if _debug_client_mesh is None: + _debug_client_mesh = ( + _get_debug_proc_mesh().spawn("debug_client", DebugClient).get() + ) + return _debug_client_mesh diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py index a2720198..0a108a2b 100644 --- a/python/monarch/actor/__init__.py +++ b/python/monarch/actor/__init__.py @@ -23,7 +23,13 @@ ValueMesh, ) from monarch._src.actor.future import Future -from monarch._src.actor.proc_mesh import local_proc_mesh, proc_mesh, ProcMesh +from monarch._src.actor.proc_mesh import ( + debug_client, + local_proc_mesh, + proc_mesh, + ProcMesh, +) + __all__ = [ "Accumulator", @@ -42,4 +48,5 @@ "ProcMesh", "send", "ValueMesh", + "debug_client", ] diff --git a/python/tests/test_debugger.py b/python/tests/test_debugger.py new file mode 100644 index 00000000..a4750d70 --- /dev/null +++ b/python/tests/test_debugger.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import asyncio +import re +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import monarch +import monarch.actor as actor + +import pytest + +import torch + +from monarch._src.actor.actor_mesh import Actor, endpoint, MonarchContext +from monarch._src.actor.debugger import ( + Attach, + Cast, + Continue, + DebugClient, + DebugCommand, + DebugSession, + Help, + ListCommand, + Quit, +) + +from monarch._src.actor.proc_mesh import proc_mesh + +needs_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) + + +def _bad_rank(): + raise ValueError("bad rank") + + +def _debugee_actor_internal(rank): + if rank == 0: + breakpoint() # noqa + rank += 1 + rank += 1 + return rank + elif rank == 1: + breakpoint() # noqa + rank += 2 + rank += 2 + return rank + elif rank == 2: + breakpoint() # noqa + rank += 3 + rank += 3 + _bad_rank() + elif rank == 3: + breakpoint() # noqa + rank += 4 + rank += 4 + return rank + + +class DebugeeActor(Actor): + @endpoint + async def to_debug(self): + rank = MonarchContext.get().point.rank + return _debugee_actor_internal(rank) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +async def test_debug() -> None: + input_mock = AsyncMock() + input_mock.side_effect = [ + "attach 1", + "n", + "n", + "n", + "n", + "detach", + "attach 1", + "detach", + "quit", + "cast ranks(0,3) n", + "cast ranks(0,3) n", + # Attaching to 0 and 3 ensures that when we call "list" + # the next time, their function/lineno info will be + # up-to-date. + "attach 0", + "detach", + "attach 3", + "detach", + "quit", + "attach 2", + "c", + "detach", + "quit", + "attach 2", + "bt", + "c", + "quit", + "continue", + ] + + outputs = [] + + def _patch_output(msg): + nonlocal outputs + outputs.append(msg) + + with patch( + "monarch._src.actor.debugger._debugger_input", side_effect=input_mock + ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): + proc = await proc_mesh(hosts=2, gpus=2) + debugee = await proc.spawn("debugee", DebugeeActor) + debug_client = actor.debug_client() + + fut = debugee.to_debug.call() + await debug_client.wait_pending_session.call_one() + breakpoints = [] + for i in range(10): + breakpoints = await debug_client.list.call_one() + if len(breakpoints) == 4: + break + await asyncio.sleep(1) + if i == 9: + raise RuntimeError("timed out waiting for breakpoints") + + initial_linenos = {} + for i in range(len(breakpoints)): + rank, coords, _, _, function, lineno = breakpoints[i] + initial_linenos[rank] = lineno + assert rank == i + assert coords == {"hosts": rank // 2, "gpus": rank % 2} + assert function == "test_debugger._debugee_actor_internal" + assert lineno == breakpoints[0][5] + 5 * rank + + await debug_client.enter.call_one() + + # Check that when detaching and re-attaching to a session, the last portion of the output is repeated + expected_last_output = [ + r"--Return--", + r"\n", + r"> (/.*/)+test_debugger.py\(\d+\)to_debug\(\)->5\n-> return _debugee_actor_internal\(rank\)", + r"\n", + r"\(Pdb\) ", + ] + output_len = len(expected_last_output) + assert outputs[-2 * output_len : -output_len] == outputs[-output_len:] + for real_output, expected_output in zip( + outputs[-output_len:], expected_last_output + ): + assert re.match(expected_output, real_output) is not None + + breakpoints = await debug_client.list.call_one() + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i][4] == "test_debugger.to_debug" + else: + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + + await debug_client.enter.call_one() + + breakpoints = await debug_client.list.call_one() + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i][4] == "test_debugger.to_debug" + elif i in (0, 3): + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + 2 + else: + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + + await debug_client.enter.call_one() + + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 4 + # Expect post-mortem debugging for rank 2 + assert breakpoints[2][4] == "test_debugger._bad_rank" + + await debug_client.enter.call_one() + + expected_last_output = [ + r"\s*(/.*/)+test_debugger.py\(\d+\)_debugee_actor_internal\(\)\n-> _bad_rank\(\)", + r"\n", + r'> (/.*/)+test_debugger.py\(\d+\)_bad_rank\(\)\n-> raise ValueError\("bad rank"\)', + r"\n", + r"\(Pdb\) ", + ] + + for output, expected_output in zip( + outputs[-len(expected_last_output) :], expected_last_output + ): + assert re.match(expected_output, output) is not None + + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 3 + for i, rank in enumerate((0, 1, 3)): + assert breakpoints[i][0] == rank + + await debug_client.enter.call_one() + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 0 + + with pytest.raises( + monarch._src.actor.actor_mesh.ActorError, match="ValueError: bad rank" + ): + await fut + + +async def test_cast_input_and_wait() -> None: + debug_client = DebugClient() + + mock_sessions = {} + for host in range(3): + for gpu in range(8): + rank = host * 8 + gpu + mock_session = MagicMock(spec=DebugSession) + mock_session.attach = AsyncMock() + mock_session.rank = rank + mock_session.coords = {"hosts": host, "gpus": gpu} + mock_sessions[rank] = mock_session + + debug_client.sessions = mock_sessions + + # Cast to a single rank + await debug_client._cast_input_and_wait("n", 2) + mock_sessions[2].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank != 2: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to a list of ranks + ranks = [1, 3, 5] + await debug_client._cast_input_and_wait("n", ranks) + for rank in ranks: + mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank not in ranks: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to a range of ranks + ranks = range(2, 24, 3) + await debug_client._cast_input_and_wait("n", ranks) + for rank in ranks: + mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank not in ranks: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to all ranks + await debug_client._cast_input_and_wait("n", None) + for session in mock_sessions.values(): + session.attach.assert_called_once_with("n", suppress_output=True) + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a single value + await debug_client._cast_input_and_wait("n", {"hosts": 1}) + for session in mock_sessions.values(): + if session.coords["hosts"] == 1: + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a list + await debug_client._cast_input_and_wait("n", {"hosts": [0, 2]}) + for _rank, session in mock_sessions.items(): + if session.coords["hosts"] in [0, 2]: + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a range + await debug_client._cast_input_and_wait("n", {"gpus": range(5, 8)}) + for session in mock_sessions.values(): + if session.coords["gpus"] in range(5, 8): + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using multiple dimension filters + await debug_client._cast_input_and_wait( + "n", {"hosts": [1, 3], "gpus": range(0, sys.maxsize, 3)} + ) + for session in mock_sessions.values(): + if session.coords["hosts"] in [1, 3] and session.coords["gpus"] in range( + 0, sys.maxsize, 3 + ): + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast with non-existent dimension + await debug_client._cast_input_and_wait("n", {"hosts": 0, "gpus": 0, "foo": 0}) + for session in mock_sessions.values(): + session.attach.assert_not_called() + + +@pytest.mark.parametrize( + ["user_input", "expected_output"], + [ + ("attach 1", Attach(1)), + ("a 100", Attach(100)), + ("list", ListCommand()), + ("l", ListCommand()), + ("help", Help()), + ("h", Help()), + ("quit", Quit()), + ("q", Quit()), + ("continue", Continue()), + ("c", Continue()), + ("cast ranks(123) b 25", Cast(ranks=123, command="b 25")), + ("cast ranks(12,34,56) b 25", Cast(ranks=[12, 34, 56], command="b 25")), + ("cast ranks(:) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")), + ("cast ranks(:123) b 25", Cast(ranks=range(0, 123), command="b 25")), + ("cast ranks(123:) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")), + ("cast ranks(123:456) b 25", Cast(ranks=range(123, 456), command="b 25")), + ("cast ranks(::) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")), + ( + "cast ranks(::123) b 25", + Cast(ranks=range(0, sys.maxsize, 123), command="b 25"), + ), + ("cast ranks(123::) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")), + ("cast ranks(:123:) b 25", Cast(ranks=range(0, 123), command="b 25")), + ("cast ranks(:456:123) b 25", Cast(ranks=range(0, 456, 123), command="b 25")), + ( + "cast ranks(456::123) b 25", + Cast(ranks=range(456, sys.maxsize, 123), command="b 25"), + ), + ("cast ranks(123:456:) b 25", Cast(ranks=range(123, 456), command="b 25")), + ( + "cast ranks(456:789:123) b 25", + Cast(ranks=range(456, 789, 123), command="b 25"), + ), + ("cast ranks(dim1=123) up 2", Cast(ranks={"dim1": 123}, command="up 2")), + ( + "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2) up 2", + Cast( + ranks={ + "dim1": 123, + "dim2": [12, 34, 56], + "dim3": range(15, sys.maxsize, 2), + }, + command="up 2", + ), + ), + ], +) +async def test_debug_command_parser_valid_inputs(user_input, expected_output): + assert DebugCommand.parse(user_input) == expected_output + + +@pytest.mark.parametrize( + "invalid_input", + [ + "", + "attch 1", + "attach", + "cast rnks(123) b 25", + "cast ranks() b 25", + "cast ranks(1ab) b 25", + "cast ranks(1,a,3) b 25", + "cast ranks(a:2:4) b 25", + "cast ranks(1,2,3", + "cast ranks(1,2,3)) b 25", + "cast ranks(1,) b 25", + "cast ranks(1,2,) b 25", + "cast ranks(,1,2) b 25", + "cast ranks(1,,2) b 25", + "cast ranks(:::) b 25", + "cast ranks(:123::) b 25", + "cast ranks(1:2:3,4) b 25", + "cast ranks(dim1=) b 25", + "cast ranks(dim1=123, dim2=) b 25", + "cast ranks(dim1=123, dim2=(12,34,56) b 25", + "cast ranks(dim1=123, dim2=(,12,34,56) b 25", + "cast ranks(dim1=123, dim2=(12,,34,56) b 25", + "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2 b 25", + "cast ranks(dim1=123,) b 25", + ], +) +async def test_debug_command_parser_invalid_inputs(invalid_input): + assert DebugCommand.parse(invalid_input) is None diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 7092844e..821872e9 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -4,30 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import asyncio import operator -import re import threading import time from types import ModuleType -from unittest.mock import AsyncMock, patch import pytest import torch -from monarch._src.actor.debugger import init_debugging -from monarch._src.actor.proc_mesh import local_proc_mesh from monarch.actor import ( Accumulator, Actor, - ActorError, current_actor_name, current_rank, current_size, endpoint, Future, - MonarchContext, + local_proc_mesh, proc_mesh, ) from monarch.rdma import RDMABuffer @@ -408,146 +404,6 @@ def test_proc_mesh_liveness() -> None: counter.value.call().get() -def _debugee_actor_internal(rank): - if rank == 0: - breakpoint() # noqa - rank += 1 - return rank - elif rank == 1: - breakpoint() # noqa - rank += 2 - return rank - elif rank == 2: - breakpoint() # noqa - rank += 3 - raise ValueError("bad rank") - elif rank == 3: - breakpoint() # noqa - rank += 4 - return rank - - -class DebugeeActor(Actor): - @endpoint - async def to_debug(self): - rank = MonarchContext.get().point.rank - return _debugee_actor_internal(rank) - - -async def test_debug() -> None: - input_mock = AsyncMock() - input_mock.side_effect = [ - "attach 1", - "n", - "n", - "n", - "n", - "detach", - "attach 1", - "detach", - "quit", - "cast 0,3 n", - "cast 0,3 n", - # Attaching to 0 and 3 ensures that when we call "list" - # the next time, their function/lineno info will be - # up-to-date. - "attach 0", - "detach", - "attach 3", - "detach", - "quit", - "attach 2", - "c", - "quit", - "continue", - ] - - outputs = [] - - def _patch_output(msg): - nonlocal outputs - outputs.append(msg) - - with patch( - "monarch._src.actor.debugger._debugger_input", side_effect=input_mock - ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): - proc = await proc_mesh(hosts=2, gpus=2) - debugee = await proc.spawn("debugee", DebugeeActor) - debug_client = await init_debugging(debugee) - - fut = debugee.to_debug.call() - await debug_client.wait_pending_session.call_one() - breakpoints = [] - for i in range(10): - breakpoints = await debug_client.list.call_one() - if len(breakpoints) == 4: - break - await asyncio.sleep(1) - if i == 9: - raise RuntimeError("timed out waiting for breakpoints") - - initial_linenos = {} - for i in range(len(breakpoints)): - rank, coords, _, _, function, lineno = breakpoints[i] - initial_linenos[rank] = lineno - assert rank == i - assert coords == {"hosts": rank // 2, "gpus": rank % 2} - assert function == "test_python_actors._debugee_actor_internal" - assert lineno == breakpoints[0][5] + 4 * rank - - await debug_client.enter.call_one() - - # Check that when detaching and re-attaching to a session, the last portion of the output is repeated - expected_last_output = [ - r"--Return--", - r"\n", - r"> (/.*/)+test_python_actors.py\(\d+\)to_debug\(\)->3\n-> return _debugee_actor_internal\(rank\)", - r"\n", - r"\(Pdb\) ", - ] - output_len = len(expected_last_output) - assert outputs[-2 * output_len : -output_len] == outputs[-output_len:] - for real_output, expected_output in zip( - outputs[-output_len:], expected_last_output - ): - assert re.match(expected_output, real_output) is not None - - breakpoints = await debug_client.list.call_one() - for i in range(len(breakpoints)): - if i == 1: - assert breakpoints[i][4] == "test_python_actors.to_debug" - else: - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] - - await debug_client.enter.call_one() - - breakpoints = await debug_client.list.call_one() - for i in range(len(breakpoints)): - if i == 1: - assert breakpoints[i][4] == "test_python_actors.to_debug" - elif i in (0, 3): - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] + 2 - else: - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] - - await debug_client.enter.call_one() - - breakpoints = await debug_client.list.call_one() - assert len(breakpoints) == 3 - for i, rank in enumerate((0, 1, 3)): - assert breakpoints[i][0] == rank - - await debug_client.enter.call_one() - breakpoints = await debug_client.list.call_one() - assert len(breakpoints) == 0 - - with pytest.raises(ActorError, match="ValueError: bad rank"): - await fut - - class TLSActor(Actor): """An actor that manages thread-local state.""" diff --git a/requirements.txt b/requirements.txt index 4235044e..41a2a230 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ numpy pyre-extensions cloudpickle torchx-nightly +lark