From fae262de53b90c7457ed57c12fe73fc4d8153510 Mon Sep 17 00:00:00 2001 From: Sam Lurye Date: Thu, 10 Jul 2025 14:44:24 -0700 Subject: [PATCH] Big debugging update (#456) Summary: This diff contains several updates to the monarch actor mesh debugging experience. - Improved debug input parsing, with the new `cast` command supporting more sophisticated rank selection grammar: - `cast ranks(3) pdb_command`: send `pdb_command` to rank 3. - `cast ranks(1,3,5) pdb_command`: send `pdb_command` to ranks 1, 3 and 5. - `cast ranks(1:10:2) pdb_command`: send `pdb_command` to the ranks in `range(start=1, stop=10, step=2)`. - `cast ranks(pp=2, dp=(1,3), tp=2:8) pdb_command`: send `pdb_command` to ranks with `pp` dim 2, `dp` dim 1 or 3, and `tp` dim in `range(2,8)`. - The debug client is now automatically registered with an actor mesh when that actor mesh is spawned. This means calling `init_debugging(actor_mesh)` is no longer necessary. - Debugging now works with MAST jobs, by enforcing that breakpoints aren't set in `__main__`, and that the file containing the breakpoint exists on the remote host. - The first requirement is due to how `cloudpickle` works -- if an actor endpoint is defined inside `__main__`, `cloudpickle` will serialize it by value instead of by reference. When the code then runs on the remote host, it thinks the location of the code is the user's local `__main__` file, which confuses pdb, because the file doesn't exist at the same path (or may not exist at all) on the remote host. - The second requirement is due to important parts of `pdb`'s implementation relying on the ability to search for the file being debugged on the remote host's file system. - A debugging session for a specific rank is now forcefully exited once the endpoint finishes execution. This contains the debugging experience within user-authored code. It is also necessary for preventing hangs, because if pdb is allowed to continue indefinitely, then control flow will eventually bubble back up to the main asyncio event loop on the worker, at which point everything breaks. - Hitting a breakpoint now automatically enables post-mortem debugging, so any rank that encounters an exception after hitting a breakpoint will automatically stop at the exception. Attaching the debugger to that rank should then provide an experience like `pdb.post_mortem()`. ## Next steps/gaps I'm aware of (reviewers please read): - Indexing debug sessions by rank isn't sustainable, because two actor meshes may simultaneously hit breakpoints on the same rank and cause a collision inside the debug client. - Entering the debug client should happen automatically, rather than requiring the user to do `await debug_client().enter.call_one()`. - Casting pdb commands should ideally leverage `MeshTrait` rather than reimplementing the selection logic. - If a mesh was reshaped/renamed so that its dimension names aren't `hosts` and `gpus` anymore, the debugger should reflect the new shape/names. - The user should be able to enable post-mortem debugging without having to hit a separate breakpoint first. Reviewed By: zdevito Differential Revision: D77568423 --- python/monarch/_src/actor/actor_mesh.py | 67 +++- python/monarch/_src/actor/bootstrap_main.py | 4 + python/monarch/_src/actor/debugger.py | 395 ++++++++++++++----- python/monarch/_src/actor/pdb_wrapper.py | 59 ++- python/monarch/_src/actor/proc_mesh.py | 53 ++- python/monarch/actor/__init__.py | 9 +- python/tests/test_debugger.py | 415 ++++++++++++++++++++ python/tests/test_python_actors.py | 148 +------ requirements.txt | 1 + 9 files changed, 846 insertions(+), 305 deletions(-) create mode 100644 python/tests/test_debugger.py diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index f22c9646..49884121 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -12,7 +12,6 @@ import inspect import logging import random -import sys import traceback from dataclasses import dataclass @@ -53,15 +52,12 @@ from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.future import Future -from monarch._src.actor.pdb_wrapper import remote_breakpointhook +from monarch._src.actor.pdb_wrapper import PdbWrapper from monarch._src.actor.pickle import flatten, unpickle from monarch._src.actor.shape import MeshTrait, NDSlice -if TYPE_CHECKING: - from monarch._src.actor.debugger import DebugClient - logger: logging.Logger = logging.getLogger(__name__) Allocator = ProcessAllocator | LocalAllocator @@ -97,6 +93,23 @@ def get() -> "MonarchContext": ) +@dataclass +class DebugContext: + pdb_wrapper: Optional[PdbWrapper] = None + + @staticmethod + def get() -> "DebugContext": + return _debug_context.get() + + @staticmethod + def set(debug_context: "DebugContext") -> None: + _debug_context.set(debug_context) + + +_debug_context: contextvars.ContextVar[DebugContext] = contextvars.ContextVar( + "monarch.actor_mesh._debug_context" +) + T = TypeVar("T") P = ParamSpec("P") R = TypeVar("R") @@ -538,6 +551,8 @@ async def handle_cast( ) _context.set(ctx) + DebugContext.set(DebugContext()) + args, kwargs = unpickle(message.message, mailbox) if message.method == "__init__": @@ -574,9 +589,10 @@ async def instrumented(): ) try: result = await the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() except Exception as e: logging.critical( - "Unahndled exception in actor endpoint", + "Unhandled exception in actor endpoint", exc_info=e, ) raise e @@ -589,11 +605,13 @@ async def instrumented(): the_method.__module__, message.method, str(ctx.mailbox.actor_id) ) result = the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() exit_span() if port is not None: port.send("result", result) except Exception as e: + self._post_mortem_debug(e.__traceback__) traceback.print_exc() s = ActorError(e) @@ -604,6 +622,7 @@ async def instrumented(): else: raise s from None except BaseException as e: + self._post_mortem_debug(e.__traceback__) # A BaseException can be thrown in the case of a Rust panic. # In this case, we need a way to signal the panic to the Rust side. # See [Panics in async endpoints] @@ -614,6 +633,29 @@ async def instrumented(): pass raise + def _maybe_exit_debugger(self, do_continue=True) -> None: + if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: + if do_continue: + pdb_wrapper.clear_all_breaks() + pdb_wrapper.do_continue("") + pdb_wrapper.end_debug_session() + DebugContext.set(DebugContext()) + + def _post_mortem_debug(self, exc_tb) -> None: + from monarch._src.actor.debugger import DebugManager + + if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: + ctx = MonarchContext.get() + pdb_wrapper = PdbWrapper( + ctx.point.rank, + ctx.point.shape.coordinates(ctx.point.rank), + ctx.mailbox.actor_id, + DebugManager.ref().get_debug_client.call_one().get(), + ) + DebugContext.set(DebugContext(pdb_wrapper)) + pdb_wrapper.post_mortem(exc_tb) + self._maybe_exit_debugger(do_continue=False) + def _is_mailbox(x: object) -> bool: return isinstance(x, Mailbox) @@ -648,19 +690,6 @@ def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": "actor implementations are not meshes, but we can't convince the typechecker of it..." ) - @endpoint # pyre-ignore - def _set_debug_client(self, client: "DebugClient") -> None: - point = MonarchContext.get().point - # For some reason, using a lambda instead of functools.partial - # confuses the pdb wrapper implementation. - sys.breakpointhook = functools.partial( # pyre-ignore - remote_breakpointhook, - point.rank, - point.shape.coordinates(point.rank), - MonarchContext.get().mailbox.actor_id, - client, - ) - class ActorMeshRef(MeshTrait, Generic[T]): def __init__( diff --git a/python/monarch/_src/actor/bootstrap_main.py b/python/monarch/_src/actor/bootstrap_main.py index 5b377ac2..b94a8a96 100644 --- a/python/monarch/_src/actor/bootstrap_main.py +++ b/python/monarch/_src/actor/bootstrap_main.py @@ -54,6 +54,10 @@ def invoke_main(): except Exception as e: logging.warning(f"Failed to set up py-spy: {e}") + from monarch._src.actor.debugger import remote_breakpointhook + + sys.breakpointhook = remote_breakpointhook + # Start an event loop for PythonActors to use. asyncio.run(main()) diff --git a/python/monarch/_src/actor/debugger.py b/python/monarch/_src/actor/debugger.py index 2906c63f..ed66601e 100644 --- a/python/monarch/_src/actor/debugger.py +++ b/python/monarch/_src/actor/debugger.py @@ -4,23 +4,32 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import asyncio +import functools +import inspect import logging +import os import sys from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from typing import cast, Dict, Generator, List, Tuple, Union from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._src.actor.actor_mesh import Actor, ActorMeshRef, endpoint -from monarch._src.actor.pdb_wrapper import DebuggerWrite -from monarch._src.actor.proc_mesh import local_proc_mesh +from monarch._src.actor.actor_mesh import ( + _ActorMeshRefImpl, + Actor, + ActorMeshRef, + DebugContext, + endpoint, + MonarchContext, +) +from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper from tabulate import tabulate logger = logging.getLogger(__name__) - -CANCEL_TOKEN = object() +_DEBUG_MANAGER_ACTOR_NAME = "debug_manager" async def _debugger_input(prompt=""): @@ -148,93 +157,170 @@ async def debugger_write(self, write: DebuggerWrite) -> None: await self._message_queue.put(("write", write)) +RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]] + + +_debug_input_parser = None + + +# Wrap the parser in a function so that jobs don't have to import lark +# unless they want to use the debugger. +def _get_debug_input_parser(): + global _debug_input_parser + if _debug_input_parser is None: + from lark import Lark + + _debug_input_parser = Lark( + """ + rank_list: INT "," INT ("," INT)* + start: INT? + stop: INT? + step: INT? + rank_range: start ":" stop (":" step)? + dim: CNAME "=" (rank_range | "(" rank_list ")" | INT) + dims: dim ("," dim)* + ranks: "ranks(" (dims | rank_range | rank_list | INT) ")" + pdb_command: /\\w+.*/ + cast: "cast" ranks pdb_command + help: "h" | "help" + attach: ("a" | "attach") INT + cont: "c" | "continue" + quit: "q" | "quit" + list: "l" | "list" + command: attach | list | cast | help | cont | quit + + %import common.INT + %import common.CNAME + %import common.WS + %ignore WS + """, + start="command", + ) + return _debug_input_parser + + +_debug_input_transformer = None + + +# Wrap the transformer in a function so that jobs don't have to import lark +# unless they want to use the debugger. +def _get_debug_input_transformer(): + global _debug_input_transformer + if _debug_input_transformer is None: + from lark import Transformer + from lark.lexer import Token + + class _IntoDebugCommandTransformer(Transformer): + def rank_list(self, items: List[Token]) -> List[int]: + return [int(item.value) for item in items] + + def start(self, items: List[Token]) -> int: + if len(items) == 0: + return 0 + return int(items[0].value) + + def stop(self, items: List[Token]) -> int: + if len(items) == 0: + return sys.maxsize + return int(items[0].value) + + def step(self, items: List[Token]) -> int: + if len(items) == 0: + return 1 + return int(items[0].value) + + def rank_range(self, items: List[int]) -> range: + return range(*items) + + def dim( + self, items: Tuple[Token, Union[range, List[int], Token]] + ) -> Tuple[str, Union[range, List[int], int]]: + if isinstance(items[1], range): + return (items[0].value, cast(range, items[1])) + elif isinstance(items[1], list): + return (items[0].value, cast(List[int], items[1])) + else: + return (items[0].value, int(cast(Token, items[1]).value)) + + def dims( + self, items: List[Tuple[str, Union[range, List[int], int]]] + ) -> Dict[str, Union[range, List[int], int]]: + return {dim[0]: dim[1] for dim in items} + + def ranks(self, items: List[Union[RanksType, Token]]) -> RanksType: + if isinstance(items[0], Token): + return int(cast(Token, items[0]).value) + return cast(RanksType, items[0]) + + def pdb_command(self, items: List[Token]) -> str: + return items[0].value + + def help(self, _items: List[Token]) -> "Help": + return Help() + + def attach(self, items: List[Token]) -> "Attach": + return Attach(int(items[0].value)) + + def cont(self, _items: List[Token]) -> "Continue": + return Continue() + + def quit(self, _items: List[Token]) -> "Quit": + return Quit() + + def cast(self, items: Tuple[RanksType, str]) -> "Cast": + return Cast(items[0], items[1]) + + def list(self, items: List[Token]) -> "ListCommand": + return ListCommand() + + def command(self, items: List["DebugCommand"]) -> "DebugCommand": + return items[0] + + _debug_input_transformer = _IntoDebugCommandTransformer() + return _debug_input_transformer + + class DebugCommand: @staticmethod def parse(line: str) -> Union["DebugCommand", None]: - parts = line.strip("\n").split(" ") - if len(parts) == 0: + try: + tree = _get_debug_input_parser().parse(line) + return _get_debug_input_transformer().transform(tree) + except Exception as e: + print(f"Error parsing input: {e}") return None - command = parts[0] - match command: - case "attach": - return Attach._parse(parts) - case "list": - return ListCommand() - case "quit": - return Quit() - case "cast": - return Cast._parse(parts) - case "help": - return Help() - case "continue": - return Continue() - case _: - print( - f"Unknown command {command}. Expected: attach | list | quit | cast | continue | help" - ) - return None @dataclass class Attach(DebugCommand): rank: int - @classmethod - def _parse(cls, parts: List[str]) -> "Attach": - if len(parts) != 2: - raise ValueError("Invalid attach command. Expected: attach ") - try: - rank = int(parts[1]) - except ValueError: - raise ValueError(f"Invalid rank {parts[1]}. Expected: int") - return cls(rank) - +@dataclass class ListCommand(DebugCommand): pass +@dataclass class Quit(DebugCommand): pass +@dataclass class Help(DebugCommand): pass +@dataclass class Continue(DebugCommand): pass @dataclass class Cast(DebugCommand): - ranks: List[int] | None + ranks: RanksType command: str - @classmethod - def _parse(cls, parts: List[str]) -> "Cast": - if len(parts) < 3: - raise ValueError( - "Invalid cast command. Expected: cast { | *} " - ) - str_ranks = parts[1] - command = " ".join(parts[2:]) - if str_ranks == "*": - return cls(None, command) - else: - str_ranks = str_ranks.split(",") - if len(str_ranks) == 0: - raise ValueError( - "Invalid rank list for cast. Expected at least one rank." - ) - ranks = [] - for rank in str_ranks: - try: - ranks.append(int(rank)) - except ValueError: - raise ValueError(f"Invalid rank {rank}. Expected: int") - return cls(ranks, command) - class DebugClient(Actor): """ @@ -253,10 +339,10 @@ async def wait_pending_session(self): @endpoint async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]]: - table_data = [] + session_info = [] for _, session in self.sessions.items(): info = session.get_info() - table_data.append( + session_info.append( ( info.rank, info.coords, @@ -266,17 +352,30 @@ async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]] info.lineno, ) ) - table_data = sorted(table_data, key=lambda r: r[0]) - - headers = ["Rank", "Coords", "Hostname", "Actor ID", "Function", "Line No."] - print(tabulate(table_data, headers=headers, tablefmt="grid")) - - return table_data + table_info = sorted(session_info, key=lambda r: r[0]) + print( + tabulate( + table_info, + headers=[ + "Rank", + "Coords", + "Hostname", + "Actor ID", + "Function", + "Line No.", + ], + tablefmt="grid", + ) + ) + return table_info @endpoint async def enter(self) -> None: - # pyre-ignore - await getattr(self, "list")._method(self) # noqa + await asyncio.sleep(0.5) + logger.info("Remote breakpoint hit. Entering monarch debugger...") + print("\n\n************************ MONARCH DEBUGGER ************************") + print("Enter 'help' for a list of commands.") + print("Enter 'list' to show all active breakpoints.\n") while True: try: @@ -288,7 +387,10 @@ async def enter(self) -> None: print("\tlist - list all debug sessions") print("\tquit - exit the debugger, leaving all sessions in place") print( - "\tcast { | *} - send a command to a comma-separated list of ranks, or all ranks" + "\tcast ranks(...) - send a command to a set of ranks.\n" + "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n" + "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n" + "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6)))." ) print( "\tcontinue - tell all ranks to continue execution, then exit the debugger" @@ -300,41 +402,75 @@ async def enter(self) -> None: else: await self.sessions[command.rank].attach() elif isinstance(command, ListCommand): - await getattr(self, "list")._method(self) # noqa + # pyre-ignore + await self.list._method(self) elif isinstance(command, Continue): - # Make sure all ranks have exited their debug sessions. - # If we sent "quit", it would raise BdbQuit, crashing - # the process, which probably isn't what we want. + # Clear all breakpoints and make sure all ranks have + # exited their debug sessions. If we sent "quit", it + # would raise BdbQuit, crashing the process, which + # probably isn't what we want. + await self._cast_input_and_wait("clear") while len(self.sessions) > 0: - tasks = [] - for rank in self.sessions: - tasks.append( - self.sessions[rank].attach("c", suppress_output=True) - ) - await asyncio.gather(*tasks) + await self._cast_input_and_wait("c") return elif isinstance(command, Quit): return elif isinstance(command, Cast): - if command.ranks is None: - ranks = self.sessions.keys() - else: - ranks = command.ranks - tasks = [] - for rank in ranks: - if rank in self.sessions: - tasks.append( - self.sessions[rank].attach( - command.command, - suppress_output=True, - ) - ) - else: - print(f"No debug session for rank {rank}") - await asyncio.gather(*tasks) + await self._cast_input_and_wait(command.command, command.ranks) except Exception as e: print(f"Error processing command: {e}") + async def _cast_input_and_wait( + self, + command: str, + ranks: RanksType | None = None, + ) -> None: + if ranks is None: + ranks = self.sessions.keys() + elif isinstance(ranks, dict): + ranks = self._iter_ranks_dict(ranks) + elif isinstance(ranks, range): + ranks = self._iter_ranks_range(ranks) + elif isinstance(ranks, int): + ranks = [ranks] + tasks = [] + for rank in ranks: + if rank in self.sessions: + tasks.append( + self.sessions[rank].attach( + command, + suppress_output=True, + ) + ) + else: + print(f"No debug session for rank {rank}") + await asyncio.gather(*tasks) + + def _iter_ranks_dict( + self, dims: Dict[str, Union[range, List[int], int]] + ) -> Generator[int, None, None]: + for rank, session in self.sessions.items(): + include_rank = True + for dim, ranks in dims.items(): + if dim not in session.coords: + include_rank = False + break + elif ( + isinstance(ranks, range) or isinstance(ranks, list) + ) and session.coords[dim] not in ranks: + include_rank = False + break + elif isinstance(ranks, int) and session.coords[dim] != ranks: + include_rank = False + break + if include_rank: + yield rank + + def _iter_ranks_range(self, rng: range) -> Generator[int, None, None]: + for rank in self.sessions.keys(): + if rank in rng: + yield rank + ########################################################################## # Debugger APIs # @@ -368,10 +504,57 @@ async def debugger_write(self, rank: int, write: DebuggerWrite) -> None: await session.debugger_write(write) -async def init_debugging( - actor_mesh: ActorMeshRef, -) -> ActorMeshRef[DebugClient]: - debugger_proc_mesh = await local_proc_mesh(gpus=1, hosts=1) - debug_client_mesh = await debugger_proc_mesh.spawn("debug_client", DebugClient) - await actor_mesh._set_debug_client.call(debug_client_mesh) - return debug_client_mesh +class DebugManager(Actor): + @staticmethod + @functools.cache + def ref() -> "DebugManager": + ctx = MonarchContext.get() + return cast( + DebugManager, + ActorMeshRef( + DebugManager, + _ActorMeshRefImpl.from_actor_id( + ctx.mailbox, + ActorId.from_string( + f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]" + ), + ), + ctx.mailbox, + ), + ) + + def __init__(self, debug_client: DebugClient) -> None: + self._debug_client = debug_client + + # pyre-ignore + @endpoint + def get_debug_client(self) -> DebugClient: + return self._debug_client + + +def remote_breakpointhook(): + frame = inspect.currentframe() + assert frame is not None + frame = frame.f_back + assert frame is not None + file = frame.f_code.co_filename + line = frame.f_lineno + module = frame.f_globals.get("__name__", "__main__") + if module == "__main__" and not os.path.exists(file): + raise NotImplementedError( + f"Remote debugging not supported for breakpoint at {file}:{line} because " + f"it is defined inside __main__, and the file does not exist on the host. " + "In this case, cloudpickle serialization does not interact nicely with pdb. " + "To debug your code, move it out of __main__ and into a module that " + "exists on both your client and worker processes." + ) + + ctx = MonarchContext.get() + pdb_wrapper = PdbWrapper( + ctx.point.rank, + ctx.point.shape.coordinates(ctx.point.rank), + ctx.mailbox.actor_id, + DebugManager.ref().get_debug_client.call_one().get(), + ) + DebugContext.set(DebugContext(pdb_wrapper)) + pdb_wrapper.set_trace(frame) diff --git a/python/monarch/_src/actor/pdb_wrapper.py b/python/monarch/_src/actor/pdb_wrapper.py index 87031dd4..a9cf56b7 100644 --- a/python/monarch/_src/actor/pdb_wrapper.py +++ b/python/monarch/_src/actor/pdb_wrapper.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import bdb import inspect import io @@ -45,35 +46,38 @@ def __init__( super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self)) self._first = True - def setup(self, *args, **kwargs): - r = super().setup(*args, **kwargs) - if self._first: - self._first = False - # when we enter the debugger, we want to present the user's stack frame - # not the nested one inside session.run. This means that the local - # variables are what gets printed, etc. To do this - # we first execute up 2 to get to that frame. - self.do_up(2) - return r - - def set_continue(self) -> None: - r = super().set_continue() - if not self.breaks: - # no more breakpoints so this debugger will not - # be used again, and we detach from the controller io. - self.client_ref.debugger_session_end.call_one(self.rank).get() - # break cycle with itself before we exit - self.stdin = sys.stdin - self.stdout = sys.stdout - return r - - def set_trace(self): + def set_trace(self, frame): self.client_ref.debugger_session_start.call_one( self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id ).get() if self.header: self.message(self.header) - super().set_trace() + super().set_trace(frame) + + def do_clear(self, arg): + if not arg: + # Sending `clear` without any argument specified will + # request confirmation from the user using the `input` function, + # which bypasses our ReadWrapper and causes a hang on the client. + # To avoid this, we just clear all breakpoints instead without + # confirmation. + super().clear_all_breaks() + else: + super().do_clear(arg) + + def end_debug_session(self): + self.client_ref.debugger_session_end.call_one(self.rank).get() + # Once the debug client actor is notified of the session being over, + # we need to prevent any additional requests being sent for the session + # by redirecting stdin and stdout. + self.stdin = sys.stdin + self.stdout = sys.stdout + + def post_mortem(self, exc_tb): + self._first = False + # See builtin implementation of pdb.post_mortem() for reference. + self.reset() + self.interaction(None, exc_tb) class ReadWrapper(io.RawIOBase): @@ -126,10 +130,3 @@ def write(self, s: str): def flush(self): pass - - -def remote_breakpointhook( - rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient" -): - ds = PdbWrapper(rank, coords, actor_id, client_ref) - ds.set_trace() diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 483ddf00..0df67df9 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -38,6 +38,11 @@ from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.code_sync import RsyncMeshClient, WorkspaceLocation from monarch._src.actor.code_sync.auto_reload import AutoReloadActor +from monarch._src.actor.debugger import ( + _DEBUG_MANAGER_ACTOR_NAME, + DebugClient, + DebugManager, +) from monarch._src.actor.device_utils import _local_device_count from monarch._src.actor.future import Future @@ -83,11 +88,13 @@ def __init__( hy_proc_mesh: HyProcMesh, _mock_shape: Optional[Shape] = None, _device_mesh: Optional["DeviceMesh"] = None, + _is_initializing_debugger: bool = False, ) -> None: self._proc_mesh = hy_proc_mesh self._mock_shape: Optional[Shape] = _mock_shape # type: ignore[21] self._rdma_manager: Optional["RDMAManager"] = None + self._debug_manager: Optional[DebugManager] = None self._mailbox: Mailbox = self._proc_mesh.client self._rsync_mesh_client: Optional[RsyncMeshClient] = None self._auto_reload_actor: Optional[AutoReloadActor] = None @@ -96,6 +103,10 @@ def __init__( if _mock_shape is None and HAS_TENSOR_ENGINE: # type: ignore[21] self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager) + if not _is_initializing_debugger: + self._debug_manager = self._spawn_blocking( + _DEBUG_MANAGER_ACTOR_NAME, DebugManager, debug_client() + ) @property def _shape(self) -> Shape: @@ -296,13 +307,21 @@ async def local_proc_mesh_nonblocking( return await ProcMesh.from_alloc(alloc) -def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh: +def local_proc_mesh_blocking( + *, + gpus: Optional[int] = None, + hosts: int = 1, + _is_initializing_debugger: bool = False, +) -> ProcMesh: if gpus is None: gpus = _local_device_count() spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) allocator = LocalAllocator() alloc = allocator.allocate(spec).get() - return ProcMesh.from_alloc(alloc).get() + return ProcMesh( + HyProcMesh.allocate_blocking(alloc), + _is_initializing_debugger=_is_initializing_debugger, + ) def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]: @@ -371,3 +390,33 @@ def proc_mesh( lambda: proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env), lambda: proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env), ) + + +_debug_proc_mesh: Optional["ProcMesh"] = None + + +# Lazy init of the debug proc mesh so that importing monarch.proc_mesh +# doesn't trigger the debug client to spawn, which could cause confusing +# logs. This is defined in proc_mesh.py instead of debugger.py for +# circular import reasons. +def _get_debug_proc_mesh() -> "ProcMesh": + global _debug_proc_mesh + if _debug_proc_mesh is None: + _debug_proc_mesh = local_proc_mesh_blocking( + gpus=1, hosts=1, _is_initializing_debugger=True + ) + return _debug_proc_mesh + + +_debug_client_mesh: Optional[ActorMeshRef[DebugClient]] = None + + +# Lazy init for the same reason as above. This is defined in proc_mesh.py +# instead of debugger.py for circular import reasons. +def debug_client() -> ActorMeshRef[DebugClient]: + global _debug_client_mesh + if _debug_client_mesh is None: + _debug_client_mesh = ( + _get_debug_proc_mesh().spawn("debug_client", DebugClient).get() + ) + return _debug_client_mesh diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py index a2720198..0a108a2b 100644 --- a/python/monarch/actor/__init__.py +++ b/python/monarch/actor/__init__.py @@ -23,7 +23,13 @@ ValueMesh, ) from monarch._src.actor.future import Future -from monarch._src.actor.proc_mesh import local_proc_mesh, proc_mesh, ProcMesh +from monarch._src.actor.proc_mesh import ( + debug_client, + local_proc_mesh, + proc_mesh, + ProcMesh, +) + __all__ = [ "Accumulator", @@ -42,4 +48,5 @@ "ProcMesh", "send", "ValueMesh", + "debug_client", ] diff --git a/python/tests/test_debugger.py b/python/tests/test_debugger.py new file mode 100644 index 00000000..a4750d70 --- /dev/null +++ b/python/tests/test_debugger.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import asyncio +import re +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import monarch +import monarch.actor as actor + +import pytest + +import torch + +from monarch._src.actor.actor_mesh import Actor, endpoint, MonarchContext +from monarch._src.actor.debugger import ( + Attach, + Cast, + Continue, + DebugClient, + DebugCommand, + DebugSession, + Help, + ListCommand, + Quit, +) + +from monarch._src.actor.proc_mesh import proc_mesh + +needs_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) + + +def _bad_rank(): + raise ValueError("bad rank") + + +def _debugee_actor_internal(rank): + if rank == 0: + breakpoint() # noqa + rank += 1 + rank += 1 + return rank + elif rank == 1: + breakpoint() # noqa + rank += 2 + rank += 2 + return rank + elif rank == 2: + breakpoint() # noqa + rank += 3 + rank += 3 + _bad_rank() + elif rank == 3: + breakpoint() # noqa + rank += 4 + rank += 4 + return rank + + +class DebugeeActor(Actor): + @endpoint + async def to_debug(self): + rank = MonarchContext.get().point.rank + return _debugee_actor_internal(rank) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +async def test_debug() -> None: + input_mock = AsyncMock() + input_mock.side_effect = [ + "attach 1", + "n", + "n", + "n", + "n", + "detach", + "attach 1", + "detach", + "quit", + "cast ranks(0,3) n", + "cast ranks(0,3) n", + # Attaching to 0 and 3 ensures that when we call "list" + # the next time, their function/lineno info will be + # up-to-date. + "attach 0", + "detach", + "attach 3", + "detach", + "quit", + "attach 2", + "c", + "detach", + "quit", + "attach 2", + "bt", + "c", + "quit", + "continue", + ] + + outputs = [] + + def _patch_output(msg): + nonlocal outputs + outputs.append(msg) + + with patch( + "monarch._src.actor.debugger._debugger_input", side_effect=input_mock + ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): + proc = await proc_mesh(hosts=2, gpus=2) + debugee = await proc.spawn("debugee", DebugeeActor) + debug_client = actor.debug_client() + + fut = debugee.to_debug.call() + await debug_client.wait_pending_session.call_one() + breakpoints = [] + for i in range(10): + breakpoints = await debug_client.list.call_one() + if len(breakpoints) == 4: + break + await asyncio.sleep(1) + if i == 9: + raise RuntimeError("timed out waiting for breakpoints") + + initial_linenos = {} + for i in range(len(breakpoints)): + rank, coords, _, _, function, lineno = breakpoints[i] + initial_linenos[rank] = lineno + assert rank == i + assert coords == {"hosts": rank // 2, "gpus": rank % 2} + assert function == "test_debugger._debugee_actor_internal" + assert lineno == breakpoints[0][5] + 5 * rank + + await debug_client.enter.call_one() + + # Check that when detaching and re-attaching to a session, the last portion of the output is repeated + expected_last_output = [ + r"--Return--", + r"\n", + r"> (/.*/)+test_debugger.py\(\d+\)to_debug\(\)->5\n-> return _debugee_actor_internal\(rank\)", + r"\n", + r"\(Pdb\) ", + ] + output_len = len(expected_last_output) + assert outputs[-2 * output_len : -output_len] == outputs[-output_len:] + for real_output, expected_output in zip( + outputs[-output_len:], expected_last_output + ): + assert re.match(expected_output, real_output) is not None + + breakpoints = await debug_client.list.call_one() + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i][4] == "test_debugger.to_debug" + else: + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + + await debug_client.enter.call_one() + + breakpoints = await debug_client.list.call_one() + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i][4] == "test_debugger.to_debug" + elif i in (0, 3): + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + 2 + else: + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + + await debug_client.enter.call_one() + + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 4 + # Expect post-mortem debugging for rank 2 + assert breakpoints[2][4] == "test_debugger._bad_rank" + + await debug_client.enter.call_one() + + expected_last_output = [ + r"\s*(/.*/)+test_debugger.py\(\d+\)_debugee_actor_internal\(\)\n-> _bad_rank\(\)", + r"\n", + r'> (/.*/)+test_debugger.py\(\d+\)_bad_rank\(\)\n-> raise ValueError\("bad rank"\)', + r"\n", + r"\(Pdb\) ", + ] + + for output, expected_output in zip( + outputs[-len(expected_last_output) :], expected_last_output + ): + assert re.match(expected_output, output) is not None + + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 3 + for i, rank in enumerate((0, 1, 3)): + assert breakpoints[i][0] == rank + + await debug_client.enter.call_one() + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 0 + + with pytest.raises( + monarch._src.actor.actor_mesh.ActorError, match="ValueError: bad rank" + ): + await fut + + +async def test_cast_input_and_wait() -> None: + debug_client = DebugClient() + + mock_sessions = {} + for host in range(3): + for gpu in range(8): + rank = host * 8 + gpu + mock_session = MagicMock(spec=DebugSession) + mock_session.attach = AsyncMock() + mock_session.rank = rank + mock_session.coords = {"hosts": host, "gpus": gpu} + mock_sessions[rank] = mock_session + + debug_client.sessions = mock_sessions + + # Cast to a single rank + await debug_client._cast_input_and_wait("n", 2) + mock_sessions[2].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank != 2: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to a list of ranks + ranks = [1, 3, 5] + await debug_client._cast_input_and_wait("n", ranks) + for rank in ranks: + mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank not in ranks: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to a range of ranks + ranks = range(2, 24, 3) + await debug_client._cast_input_and_wait("n", ranks) + for rank in ranks: + mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank not in ranks: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to all ranks + await debug_client._cast_input_and_wait("n", None) + for session in mock_sessions.values(): + session.attach.assert_called_once_with("n", suppress_output=True) + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a single value + await debug_client._cast_input_and_wait("n", {"hosts": 1}) + for session in mock_sessions.values(): + if session.coords["hosts"] == 1: + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a list + await debug_client._cast_input_and_wait("n", {"hosts": [0, 2]}) + for _rank, session in mock_sessions.items(): + if session.coords["hosts"] in [0, 2]: + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a range + await debug_client._cast_input_and_wait("n", {"gpus": range(5, 8)}) + for session in mock_sessions.values(): + if session.coords["gpus"] in range(5, 8): + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using multiple dimension filters + await debug_client._cast_input_and_wait( + "n", {"hosts": [1, 3], "gpus": range(0, sys.maxsize, 3)} + ) + for session in mock_sessions.values(): + if session.coords["hosts"] in [1, 3] and session.coords["gpus"] in range( + 0, sys.maxsize, 3 + ): + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast with non-existent dimension + await debug_client._cast_input_and_wait("n", {"hosts": 0, "gpus": 0, "foo": 0}) + for session in mock_sessions.values(): + session.attach.assert_not_called() + + +@pytest.mark.parametrize( + ["user_input", "expected_output"], + [ + ("attach 1", Attach(1)), + ("a 100", Attach(100)), + ("list", ListCommand()), + ("l", ListCommand()), + ("help", Help()), + ("h", Help()), + ("quit", Quit()), + ("q", Quit()), + ("continue", Continue()), + ("c", Continue()), + ("cast ranks(123) b 25", Cast(ranks=123, command="b 25")), + ("cast ranks(12,34,56) b 25", Cast(ranks=[12, 34, 56], command="b 25")), + ("cast ranks(:) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")), + ("cast ranks(:123) b 25", Cast(ranks=range(0, 123), command="b 25")), + ("cast ranks(123:) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")), + ("cast ranks(123:456) b 25", Cast(ranks=range(123, 456), command="b 25")), + ("cast ranks(::) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")), + ( + "cast ranks(::123) b 25", + Cast(ranks=range(0, sys.maxsize, 123), command="b 25"), + ), + ("cast ranks(123::) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")), + ("cast ranks(:123:) b 25", Cast(ranks=range(0, 123), command="b 25")), + ("cast ranks(:456:123) b 25", Cast(ranks=range(0, 456, 123), command="b 25")), + ( + "cast ranks(456::123) b 25", + Cast(ranks=range(456, sys.maxsize, 123), command="b 25"), + ), + ("cast ranks(123:456:) b 25", Cast(ranks=range(123, 456), command="b 25")), + ( + "cast ranks(456:789:123) b 25", + Cast(ranks=range(456, 789, 123), command="b 25"), + ), + ("cast ranks(dim1=123) up 2", Cast(ranks={"dim1": 123}, command="up 2")), + ( + "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2) up 2", + Cast( + ranks={ + "dim1": 123, + "dim2": [12, 34, 56], + "dim3": range(15, sys.maxsize, 2), + }, + command="up 2", + ), + ), + ], +) +async def test_debug_command_parser_valid_inputs(user_input, expected_output): + assert DebugCommand.parse(user_input) == expected_output + + +@pytest.mark.parametrize( + "invalid_input", + [ + "", + "attch 1", + "attach", + "cast rnks(123) b 25", + "cast ranks() b 25", + "cast ranks(1ab) b 25", + "cast ranks(1,a,3) b 25", + "cast ranks(a:2:4) b 25", + "cast ranks(1,2,3", + "cast ranks(1,2,3)) b 25", + "cast ranks(1,) b 25", + "cast ranks(1,2,) b 25", + "cast ranks(,1,2) b 25", + "cast ranks(1,,2) b 25", + "cast ranks(:::) b 25", + "cast ranks(:123::) b 25", + "cast ranks(1:2:3,4) b 25", + "cast ranks(dim1=) b 25", + "cast ranks(dim1=123, dim2=) b 25", + "cast ranks(dim1=123, dim2=(12,34,56) b 25", + "cast ranks(dim1=123, dim2=(,12,34,56) b 25", + "cast ranks(dim1=123, dim2=(12,,34,56) b 25", + "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2 b 25", + "cast ranks(dim1=123,) b 25", + ], +) +async def test_debug_command_parser_invalid_inputs(invalid_input): + assert DebugCommand.parse(invalid_input) is None diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 7092844e..821872e9 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -4,30 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import asyncio import operator -import re import threading import time from types import ModuleType -from unittest.mock import AsyncMock, patch import pytest import torch -from monarch._src.actor.debugger import init_debugging -from monarch._src.actor.proc_mesh import local_proc_mesh from monarch.actor import ( Accumulator, Actor, - ActorError, current_actor_name, current_rank, current_size, endpoint, Future, - MonarchContext, + local_proc_mesh, proc_mesh, ) from monarch.rdma import RDMABuffer @@ -408,146 +404,6 @@ def test_proc_mesh_liveness() -> None: counter.value.call().get() -def _debugee_actor_internal(rank): - if rank == 0: - breakpoint() # noqa - rank += 1 - return rank - elif rank == 1: - breakpoint() # noqa - rank += 2 - return rank - elif rank == 2: - breakpoint() # noqa - rank += 3 - raise ValueError("bad rank") - elif rank == 3: - breakpoint() # noqa - rank += 4 - return rank - - -class DebugeeActor(Actor): - @endpoint - async def to_debug(self): - rank = MonarchContext.get().point.rank - return _debugee_actor_internal(rank) - - -async def test_debug() -> None: - input_mock = AsyncMock() - input_mock.side_effect = [ - "attach 1", - "n", - "n", - "n", - "n", - "detach", - "attach 1", - "detach", - "quit", - "cast 0,3 n", - "cast 0,3 n", - # Attaching to 0 and 3 ensures that when we call "list" - # the next time, their function/lineno info will be - # up-to-date. - "attach 0", - "detach", - "attach 3", - "detach", - "quit", - "attach 2", - "c", - "quit", - "continue", - ] - - outputs = [] - - def _patch_output(msg): - nonlocal outputs - outputs.append(msg) - - with patch( - "monarch._src.actor.debugger._debugger_input", side_effect=input_mock - ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): - proc = await proc_mesh(hosts=2, gpus=2) - debugee = await proc.spawn("debugee", DebugeeActor) - debug_client = await init_debugging(debugee) - - fut = debugee.to_debug.call() - await debug_client.wait_pending_session.call_one() - breakpoints = [] - for i in range(10): - breakpoints = await debug_client.list.call_one() - if len(breakpoints) == 4: - break - await asyncio.sleep(1) - if i == 9: - raise RuntimeError("timed out waiting for breakpoints") - - initial_linenos = {} - for i in range(len(breakpoints)): - rank, coords, _, _, function, lineno = breakpoints[i] - initial_linenos[rank] = lineno - assert rank == i - assert coords == {"hosts": rank // 2, "gpus": rank % 2} - assert function == "test_python_actors._debugee_actor_internal" - assert lineno == breakpoints[0][5] + 4 * rank - - await debug_client.enter.call_one() - - # Check that when detaching and re-attaching to a session, the last portion of the output is repeated - expected_last_output = [ - r"--Return--", - r"\n", - r"> (/.*/)+test_python_actors.py\(\d+\)to_debug\(\)->3\n-> return _debugee_actor_internal\(rank\)", - r"\n", - r"\(Pdb\) ", - ] - output_len = len(expected_last_output) - assert outputs[-2 * output_len : -output_len] == outputs[-output_len:] - for real_output, expected_output in zip( - outputs[-output_len:], expected_last_output - ): - assert re.match(expected_output, real_output) is not None - - breakpoints = await debug_client.list.call_one() - for i in range(len(breakpoints)): - if i == 1: - assert breakpoints[i][4] == "test_python_actors.to_debug" - else: - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] - - await debug_client.enter.call_one() - - breakpoints = await debug_client.list.call_one() - for i in range(len(breakpoints)): - if i == 1: - assert breakpoints[i][4] == "test_python_actors.to_debug" - elif i in (0, 3): - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] + 2 - else: - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] - - await debug_client.enter.call_one() - - breakpoints = await debug_client.list.call_one() - assert len(breakpoints) == 3 - for i, rank in enumerate((0, 1, 3)): - assert breakpoints[i][0] == rank - - await debug_client.enter.call_one() - breakpoints = await debug_client.list.call_one() - assert len(breakpoints) == 0 - - with pytest.raises(ActorError, match="ValueError: bad rank"): - await fut - - class TLSActor(Actor): """An actor that manages thread-local state.""" diff --git a/requirements.txt b/requirements.txt index 4235044e..41a2a230 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ numpy pyre-extensions cloudpickle torchx-nightly +lark