From a06f15f7e84aad826b467c985491b6647980e65f Mon Sep 17 00:00:00 2001 From: slurye Date: Thu, 10 Jul 2025 14:44:01 -0700 Subject: [PATCH 1/2] Big debugging update (#456) Summary: Pull Request resolved: https://github.com/pytorch-labs/monarch/pull/456 This diff contains several updates to the monarch actor mesh debugging experience. - Improved debug input parsing, with the new `cast` command supporting more sophisticated rank selection grammar: - `cast ranks(3) pdb_command`: send `pdb_command` to rank 3. - `cast ranks(1,3,5) pdb_command`: send `pdb_command` to ranks 1, 3 and 5. - `cast ranks(1:10:2) pdb_command`: send `pdb_command` to the ranks in `range(start=1, stop=10, step=2)`. - `cast ranks(pp=2, dp=(1,3), tp=2:8) pdb_command`: send `pdb_command` to ranks with `pp` dim 2, `dp` dim 1 or 3, and `tp` dim in `range(2,8)`. - The debug client is now automatically registered with an actor mesh when that actor mesh is spawned. This means calling `init_debugging(actor_mesh)` is no longer necessary. - Debugging now works with MAST jobs, by enforcing that breakpoints aren't set in `__main__`, and that the file containing the breakpoint exists on the remote host. - The first requirement is due to how `cloudpickle` works -- if an actor endpoint is defined inside `__main__`, `cloudpickle` will serialize it by value instead of by reference. When the code then runs on the remote host, it thinks the location of the code is the user's local `__main__` file, which confuses pdb, because the file doesn't exist at the same path (or may not exist at all) on the remote host. - The second requirement is due to important parts of `pdb`'s implementation relying on the ability to search for the file being debugged on the remote host's file system. - A debugging session for a specific rank is now forcefully exited once the endpoint finishes execution. This contains the debugging experience within user-authored code. It is also necessary for preventing hangs, because if pdb is allowed to continue indefinitely, then control flow will eventually bubble back up to the main asyncio event loop on the worker, at which point everything breaks. - Hitting a breakpoint now automatically enables post-mortem debugging, so any rank that encounters an exception after hitting a breakpoint will automatically stop at the exception. Attaching the debugger to that rank should then provide an experience like `pdb.post_mortem()`. ## Next steps/gaps I'm aware of (reviewers please read): - Indexing debug sessions by rank isn't sustainable, because two actor meshes may simultaneously hit breakpoints on the same rank and cause a collision inside the debug client. - Entering the debug client should happen automatically, rather than requiring the user to do `await debug_client().enter.call_one()`. - Casting pdb commands should ideally leverage `MeshTrait` rather than reimplementing the selection logic. - If a mesh was reshaped/renamed so that its dimension names aren't `hosts` and `gpus` anymore, the debugger should reflect the new shape/names. - The user should be able to enable post-mortem debugging without having to hit a separate breakpoint first. Differential Revision: D77568423 Reviewed By: zdevito --- python/monarch/_src/actor/actor_mesh.py | 67 +++- python/monarch/_src/actor/bootstrap_main.py | 4 + python/monarch/_src/actor/debugger.py | 395 ++++++++++++++----- python/monarch/_src/actor/pdb_wrapper.py | 59 ++- python/monarch/_src/actor/proc_mesh.py | 53 ++- python/monarch/actor/__init__.py | 9 +- python/tests/test_debugger.py | 415 ++++++++++++++++++++ python/tests/test_python_actors.py | 148 +------ requirements.txt | 1 + 9 files changed, 846 insertions(+), 305 deletions(-) create mode 100644 python/tests/test_debugger.py diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index f22c9646..49884121 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -12,7 +12,6 @@ import inspect import logging import random -import sys import traceback from dataclasses import dataclass @@ -53,15 +52,12 @@ from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.future import Future -from monarch._src.actor.pdb_wrapper import remote_breakpointhook +from monarch._src.actor.pdb_wrapper import PdbWrapper from monarch._src.actor.pickle import flatten, unpickle from monarch._src.actor.shape import MeshTrait, NDSlice -if TYPE_CHECKING: - from monarch._src.actor.debugger import DebugClient - logger: logging.Logger = logging.getLogger(__name__) Allocator = ProcessAllocator | LocalAllocator @@ -97,6 +93,23 @@ def get() -> "MonarchContext": ) +@dataclass +class DebugContext: + pdb_wrapper: Optional[PdbWrapper] = None + + @staticmethod + def get() -> "DebugContext": + return _debug_context.get() + + @staticmethod + def set(debug_context: "DebugContext") -> None: + _debug_context.set(debug_context) + + +_debug_context: contextvars.ContextVar[DebugContext] = contextvars.ContextVar( + "monarch.actor_mesh._debug_context" +) + T = TypeVar("T") P = ParamSpec("P") R = TypeVar("R") @@ -538,6 +551,8 @@ async def handle_cast( ) _context.set(ctx) + DebugContext.set(DebugContext()) + args, kwargs = unpickle(message.message, mailbox) if message.method == "__init__": @@ -574,9 +589,10 @@ async def instrumented(): ) try: result = await the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() except Exception as e: logging.critical( - "Unahndled exception in actor endpoint", + "Unhandled exception in actor endpoint", exc_info=e, ) raise e @@ -589,11 +605,13 @@ async def instrumented(): the_method.__module__, message.method, str(ctx.mailbox.actor_id) ) result = the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() exit_span() if port is not None: port.send("result", result) except Exception as e: + self._post_mortem_debug(e.__traceback__) traceback.print_exc() s = ActorError(e) @@ -604,6 +622,7 @@ async def instrumented(): else: raise s from None except BaseException as e: + self._post_mortem_debug(e.__traceback__) # A BaseException can be thrown in the case of a Rust panic. # In this case, we need a way to signal the panic to the Rust side. # See [Panics in async endpoints] @@ -614,6 +633,29 @@ async def instrumented(): pass raise + def _maybe_exit_debugger(self, do_continue=True) -> None: + if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: + if do_continue: + pdb_wrapper.clear_all_breaks() + pdb_wrapper.do_continue("") + pdb_wrapper.end_debug_session() + DebugContext.set(DebugContext()) + + def _post_mortem_debug(self, exc_tb) -> None: + from monarch._src.actor.debugger import DebugManager + + if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: + ctx = MonarchContext.get() + pdb_wrapper = PdbWrapper( + ctx.point.rank, + ctx.point.shape.coordinates(ctx.point.rank), + ctx.mailbox.actor_id, + DebugManager.ref().get_debug_client.call_one().get(), + ) + DebugContext.set(DebugContext(pdb_wrapper)) + pdb_wrapper.post_mortem(exc_tb) + self._maybe_exit_debugger(do_continue=False) + def _is_mailbox(x: object) -> bool: return isinstance(x, Mailbox) @@ -648,19 +690,6 @@ def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": "actor implementations are not meshes, but we can't convince the typechecker of it..." ) - @endpoint # pyre-ignore - def _set_debug_client(self, client: "DebugClient") -> None: - point = MonarchContext.get().point - # For some reason, using a lambda instead of functools.partial - # confuses the pdb wrapper implementation. - sys.breakpointhook = functools.partial( # pyre-ignore - remote_breakpointhook, - point.rank, - point.shape.coordinates(point.rank), - MonarchContext.get().mailbox.actor_id, - client, - ) - class ActorMeshRef(MeshTrait, Generic[T]): def __init__( diff --git a/python/monarch/_src/actor/bootstrap_main.py b/python/monarch/_src/actor/bootstrap_main.py index 5b377ac2..b94a8a96 100644 --- a/python/monarch/_src/actor/bootstrap_main.py +++ b/python/monarch/_src/actor/bootstrap_main.py @@ -54,6 +54,10 @@ def invoke_main(): except Exception as e: logging.warning(f"Failed to set up py-spy: {e}") + from monarch._src.actor.debugger import remote_breakpointhook + + sys.breakpointhook = remote_breakpointhook + # Start an event loop for PythonActors to use. asyncio.run(main()) diff --git a/python/monarch/_src/actor/debugger.py b/python/monarch/_src/actor/debugger.py index 2906c63f..ed66601e 100644 --- a/python/monarch/_src/actor/debugger.py +++ b/python/monarch/_src/actor/debugger.py @@ -4,23 +4,32 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import asyncio +import functools +import inspect import logging +import os import sys from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from typing import cast, Dict, Generator, List, Tuple, Union from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._src.actor.actor_mesh import Actor, ActorMeshRef, endpoint -from monarch._src.actor.pdb_wrapper import DebuggerWrite -from monarch._src.actor.proc_mesh import local_proc_mesh +from monarch._src.actor.actor_mesh import ( + _ActorMeshRefImpl, + Actor, + ActorMeshRef, + DebugContext, + endpoint, + MonarchContext, +) +from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper from tabulate import tabulate logger = logging.getLogger(__name__) - -CANCEL_TOKEN = object() +_DEBUG_MANAGER_ACTOR_NAME = "debug_manager" async def _debugger_input(prompt=""): @@ -148,93 +157,170 @@ async def debugger_write(self, write: DebuggerWrite) -> None: await self._message_queue.put(("write", write)) +RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]] + + +_debug_input_parser = None + + +# Wrap the parser in a function so that jobs don't have to import lark +# unless they want to use the debugger. +def _get_debug_input_parser(): + global _debug_input_parser + if _debug_input_parser is None: + from lark import Lark + + _debug_input_parser = Lark( + """ + rank_list: INT "," INT ("," INT)* + start: INT? + stop: INT? + step: INT? + rank_range: start ":" stop (":" step)? + dim: CNAME "=" (rank_range | "(" rank_list ")" | INT) + dims: dim ("," dim)* + ranks: "ranks(" (dims | rank_range | rank_list | INT) ")" + pdb_command: /\\w+.*/ + cast: "cast" ranks pdb_command + help: "h" | "help" + attach: ("a" | "attach") INT + cont: "c" | "continue" + quit: "q" | "quit" + list: "l" | "list" + command: attach | list | cast | help | cont | quit + + %import common.INT + %import common.CNAME + %import common.WS + %ignore WS + """, + start="command", + ) + return _debug_input_parser + + +_debug_input_transformer = None + + +# Wrap the transformer in a function so that jobs don't have to import lark +# unless they want to use the debugger. +def _get_debug_input_transformer(): + global _debug_input_transformer + if _debug_input_transformer is None: + from lark import Transformer + from lark.lexer import Token + + class _IntoDebugCommandTransformer(Transformer): + def rank_list(self, items: List[Token]) -> List[int]: + return [int(item.value) for item in items] + + def start(self, items: List[Token]) -> int: + if len(items) == 0: + return 0 + return int(items[0].value) + + def stop(self, items: List[Token]) -> int: + if len(items) == 0: + return sys.maxsize + return int(items[0].value) + + def step(self, items: List[Token]) -> int: + if len(items) == 0: + return 1 + return int(items[0].value) + + def rank_range(self, items: List[int]) -> range: + return range(*items) + + def dim( + self, items: Tuple[Token, Union[range, List[int], Token]] + ) -> Tuple[str, Union[range, List[int], int]]: + if isinstance(items[1], range): + return (items[0].value, cast(range, items[1])) + elif isinstance(items[1], list): + return (items[0].value, cast(List[int], items[1])) + else: + return (items[0].value, int(cast(Token, items[1]).value)) + + def dims( + self, items: List[Tuple[str, Union[range, List[int], int]]] + ) -> Dict[str, Union[range, List[int], int]]: + return {dim[0]: dim[1] for dim in items} + + def ranks(self, items: List[Union[RanksType, Token]]) -> RanksType: + if isinstance(items[0], Token): + return int(cast(Token, items[0]).value) + return cast(RanksType, items[0]) + + def pdb_command(self, items: List[Token]) -> str: + return items[0].value + + def help(self, _items: List[Token]) -> "Help": + return Help() + + def attach(self, items: List[Token]) -> "Attach": + return Attach(int(items[0].value)) + + def cont(self, _items: List[Token]) -> "Continue": + return Continue() + + def quit(self, _items: List[Token]) -> "Quit": + return Quit() + + def cast(self, items: Tuple[RanksType, str]) -> "Cast": + return Cast(items[0], items[1]) + + def list(self, items: List[Token]) -> "ListCommand": + return ListCommand() + + def command(self, items: List["DebugCommand"]) -> "DebugCommand": + return items[0] + + _debug_input_transformer = _IntoDebugCommandTransformer() + return _debug_input_transformer + + class DebugCommand: @staticmethod def parse(line: str) -> Union["DebugCommand", None]: - parts = line.strip("\n").split(" ") - if len(parts) == 0: + try: + tree = _get_debug_input_parser().parse(line) + return _get_debug_input_transformer().transform(tree) + except Exception as e: + print(f"Error parsing input: {e}") return None - command = parts[0] - match command: - case "attach": - return Attach._parse(parts) - case "list": - return ListCommand() - case "quit": - return Quit() - case "cast": - return Cast._parse(parts) - case "help": - return Help() - case "continue": - return Continue() - case _: - print( - f"Unknown command {command}. Expected: attach | list | quit | cast | continue | help" - ) - return None @dataclass class Attach(DebugCommand): rank: int - @classmethod - def _parse(cls, parts: List[str]) -> "Attach": - if len(parts) != 2: - raise ValueError("Invalid attach command. Expected: attach ") - try: - rank = int(parts[1]) - except ValueError: - raise ValueError(f"Invalid rank {parts[1]}. Expected: int") - return cls(rank) - +@dataclass class ListCommand(DebugCommand): pass +@dataclass class Quit(DebugCommand): pass +@dataclass class Help(DebugCommand): pass +@dataclass class Continue(DebugCommand): pass @dataclass class Cast(DebugCommand): - ranks: List[int] | None + ranks: RanksType command: str - @classmethod - def _parse(cls, parts: List[str]) -> "Cast": - if len(parts) < 3: - raise ValueError( - "Invalid cast command. Expected: cast { | *} " - ) - str_ranks = parts[1] - command = " ".join(parts[2:]) - if str_ranks == "*": - return cls(None, command) - else: - str_ranks = str_ranks.split(",") - if len(str_ranks) == 0: - raise ValueError( - "Invalid rank list for cast. Expected at least one rank." - ) - ranks = [] - for rank in str_ranks: - try: - ranks.append(int(rank)) - except ValueError: - raise ValueError(f"Invalid rank {rank}. Expected: int") - return cls(ranks, command) - class DebugClient(Actor): """ @@ -253,10 +339,10 @@ async def wait_pending_session(self): @endpoint async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]]: - table_data = [] + session_info = [] for _, session in self.sessions.items(): info = session.get_info() - table_data.append( + session_info.append( ( info.rank, info.coords, @@ -266,17 +352,30 @@ async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]] info.lineno, ) ) - table_data = sorted(table_data, key=lambda r: r[0]) - - headers = ["Rank", "Coords", "Hostname", "Actor ID", "Function", "Line No."] - print(tabulate(table_data, headers=headers, tablefmt="grid")) - - return table_data + table_info = sorted(session_info, key=lambda r: r[0]) + print( + tabulate( + table_info, + headers=[ + "Rank", + "Coords", + "Hostname", + "Actor ID", + "Function", + "Line No.", + ], + tablefmt="grid", + ) + ) + return table_info @endpoint async def enter(self) -> None: - # pyre-ignore - await getattr(self, "list")._method(self) # noqa + await asyncio.sleep(0.5) + logger.info("Remote breakpoint hit. Entering monarch debugger...") + print("\n\n************************ MONARCH DEBUGGER ************************") + print("Enter 'help' for a list of commands.") + print("Enter 'list' to show all active breakpoints.\n") while True: try: @@ -288,7 +387,10 @@ async def enter(self) -> None: print("\tlist - list all debug sessions") print("\tquit - exit the debugger, leaving all sessions in place") print( - "\tcast { | *} - send a command to a comma-separated list of ranks, or all ranks" + "\tcast ranks(...) - send a command to a set of ranks.\n" + "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n" + "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n" + "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6)))." ) print( "\tcontinue - tell all ranks to continue execution, then exit the debugger" @@ -300,41 +402,75 @@ async def enter(self) -> None: else: await self.sessions[command.rank].attach() elif isinstance(command, ListCommand): - await getattr(self, "list")._method(self) # noqa + # pyre-ignore + await self.list._method(self) elif isinstance(command, Continue): - # Make sure all ranks have exited their debug sessions. - # If we sent "quit", it would raise BdbQuit, crashing - # the process, which probably isn't what we want. + # Clear all breakpoints and make sure all ranks have + # exited their debug sessions. If we sent "quit", it + # would raise BdbQuit, crashing the process, which + # probably isn't what we want. + await self._cast_input_and_wait("clear") while len(self.sessions) > 0: - tasks = [] - for rank in self.sessions: - tasks.append( - self.sessions[rank].attach("c", suppress_output=True) - ) - await asyncio.gather(*tasks) + await self._cast_input_and_wait("c") return elif isinstance(command, Quit): return elif isinstance(command, Cast): - if command.ranks is None: - ranks = self.sessions.keys() - else: - ranks = command.ranks - tasks = [] - for rank in ranks: - if rank in self.sessions: - tasks.append( - self.sessions[rank].attach( - command.command, - suppress_output=True, - ) - ) - else: - print(f"No debug session for rank {rank}") - await asyncio.gather(*tasks) + await self._cast_input_and_wait(command.command, command.ranks) except Exception as e: print(f"Error processing command: {e}") + async def _cast_input_and_wait( + self, + command: str, + ranks: RanksType | None = None, + ) -> None: + if ranks is None: + ranks = self.sessions.keys() + elif isinstance(ranks, dict): + ranks = self._iter_ranks_dict(ranks) + elif isinstance(ranks, range): + ranks = self._iter_ranks_range(ranks) + elif isinstance(ranks, int): + ranks = [ranks] + tasks = [] + for rank in ranks: + if rank in self.sessions: + tasks.append( + self.sessions[rank].attach( + command, + suppress_output=True, + ) + ) + else: + print(f"No debug session for rank {rank}") + await asyncio.gather(*tasks) + + def _iter_ranks_dict( + self, dims: Dict[str, Union[range, List[int], int]] + ) -> Generator[int, None, None]: + for rank, session in self.sessions.items(): + include_rank = True + for dim, ranks in dims.items(): + if dim not in session.coords: + include_rank = False + break + elif ( + isinstance(ranks, range) or isinstance(ranks, list) + ) and session.coords[dim] not in ranks: + include_rank = False + break + elif isinstance(ranks, int) and session.coords[dim] != ranks: + include_rank = False + break + if include_rank: + yield rank + + def _iter_ranks_range(self, rng: range) -> Generator[int, None, None]: + for rank in self.sessions.keys(): + if rank in rng: + yield rank + ########################################################################## # Debugger APIs # @@ -368,10 +504,57 @@ async def debugger_write(self, rank: int, write: DebuggerWrite) -> None: await session.debugger_write(write) -async def init_debugging( - actor_mesh: ActorMeshRef, -) -> ActorMeshRef[DebugClient]: - debugger_proc_mesh = await local_proc_mesh(gpus=1, hosts=1) - debug_client_mesh = await debugger_proc_mesh.spawn("debug_client", DebugClient) - await actor_mesh._set_debug_client.call(debug_client_mesh) - return debug_client_mesh +class DebugManager(Actor): + @staticmethod + @functools.cache + def ref() -> "DebugManager": + ctx = MonarchContext.get() + return cast( + DebugManager, + ActorMeshRef( + DebugManager, + _ActorMeshRefImpl.from_actor_id( + ctx.mailbox, + ActorId.from_string( + f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]" + ), + ), + ctx.mailbox, + ), + ) + + def __init__(self, debug_client: DebugClient) -> None: + self._debug_client = debug_client + + # pyre-ignore + @endpoint + def get_debug_client(self) -> DebugClient: + return self._debug_client + + +def remote_breakpointhook(): + frame = inspect.currentframe() + assert frame is not None + frame = frame.f_back + assert frame is not None + file = frame.f_code.co_filename + line = frame.f_lineno + module = frame.f_globals.get("__name__", "__main__") + if module == "__main__" and not os.path.exists(file): + raise NotImplementedError( + f"Remote debugging not supported for breakpoint at {file}:{line} because " + f"it is defined inside __main__, and the file does not exist on the host. " + "In this case, cloudpickle serialization does not interact nicely with pdb. " + "To debug your code, move it out of __main__ and into a module that " + "exists on both your client and worker processes." + ) + + ctx = MonarchContext.get() + pdb_wrapper = PdbWrapper( + ctx.point.rank, + ctx.point.shape.coordinates(ctx.point.rank), + ctx.mailbox.actor_id, + DebugManager.ref().get_debug_client.call_one().get(), + ) + DebugContext.set(DebugContext(pdb_wrapper)) + pdb_wrapper.set_trace(frame) diff --git a/python/monarch/_src/actor/pdb_wrapper.py b/python/monarch/_src/actor/pdb_wrapper.py index 87031dd4..a9cf56b7 100644 --- a/python/monarch/_src/actor/pdb_wrapper.py +++ b/python/monarch/_src/actor/pdb_wrapper.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import bdb import inspect import io @@ -45,35 +46,38 @@ def __init__( super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self)) self._first = True - def setup(self, *args, **kwargs): - r = super().setup(*args, **kwargs) - if self._first: - self._first = False - # when we enter the debugger, we want to present the user's stack frame - # not the nested one inside session.run. This means that the local - # variables are what gets printed, etc. To do this - # we first execute up 2 to get to that frame. - self.do_up(2) - return r - - def set_continue(self) -> None: - r = super().set_continue() - if not self.breaks: - # no more breakpoints so this debugger will not - # be used again, and we detach from the controller io. - self.client_ref.debugger_session_end.call_one(self.rank).get() - # break cycle with itself before we exit - self.stdin = sys.stdin - self.stdout = sys.stdout - return r - - def set_trace(self): + def set_trace(self, frame): self.client_ref.debugger_session_start.call_one( self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id ).get() if self.header: self.message(self.header) - super().set_trace() + super().set_trace(frame) + + def do_clear(self, arg): + if not arg: + # Sending `clear` without any argument specified will + # request confirmation from the user using the `input` function, + # which bypasses our ReadWrapper and causes a hang on the client. + # To avoid this, we just clear all breakpoints instead without + # confirmation. + super().clear_all_breaks() + else: + super().do_clear(arg) + + def end_debug_session(self): + self.client_ref.debugger_session_end.call_one(self.rank).get() + # Once the debug client actor is notified of the session being over, + # we need to prevent any additional requests being sent for the session + # by redirecting stdin and stdout. + self.stdin = sys.stdin + self.stdout = sys.stdout + + def post_mortem(self, exc_tb): + self._first = False + # See builtin implementation of pdb.post_mortem() for reference. + self.reset() + self.interaction(None, exc_tb) class ReadWrapper(io.RawIOBase): @@ -126,10 +130,3 @@ def write(self, s: str): def flush(self): pass - - -def remote_breakpointhook( - rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient" -): - ds = PdbWrapper(rank, coords, actor_id, client_ref) - ds.set_trace() diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 483ddf00..0df67df9 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -38,6 +38,11 @@ from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.code_sync import RsyncMeshClient, WorkspaceLocation from monarch._src.actor.code_sync.auto_reload import AutoReloadActor +from monarch._src.actor.debugger import ( + _DEBUG_MANAGER_ACTOR_NAME, + DebugClient, + DebugManager, +) from monarch._src.actor.device_utils import _local_device_count from monarch._src.actor.future import Future @@ -83,11 +88,13 @@ def __init__( hy_proc_mesh: HyProcMesh, _mock_shape: Optional[Shape] = None, _device_mesh: Optional["DeviceMesh"] = None, + _is_initializing_debugger: bool = False, ) -> None: self._proc_mesh = hy_proc_mesh self._mock_shape: Optional[Shape] = _mock_shape # type: ignore[21] self._rdma_manager: Optional["RDMAManager"] = None + self._debug_manager: Optional[DebugManager] = None self._mailbox: Mailbox = self._proc_mesh.client self._rsync_mesh_client: Optional[RsyncMeshClient] = None self._auto_reload_actor: Optional[AutoReloadActor] = None @@ -96,6 +103,10 @@ def __init__( if _mock_shape is None and HAS_TENSOR_ENGINE: # type: ignore[21] self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager) + if not _is_initializing_debugger: + self._debug_manager = self._spawn_blocking( + _DEBUG_MANAGER_ACTOR_NAME, DebugManager, debug_client() + ) @property def _shape(self) -> Shape: @@ -296,13 +307,21 @@ async def local_proc_mesh_nonblocking( return await ProcMesh.from_alloc(alloc) -def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh: +def local_proc_mesh_blocking( + *, + gpus: Optional[int] = None, + hosts: int = 1, + _is_initializing_debugger: bool = False, +) -> ProcMesh: if gpus is None: gpus = _local_device_count() spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) allocator = LocalAllocator() alloc = allocator.allocate(spec).get() - return ProcMesh.from_alloc(alloc).get() + return ProcMesh( + HyProcMesh.allocate_blocking(alloc), + _is_initializing_debugger=_is_initializing_debugger, + ) def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]: @@ -371,3 +390,33 @@ def proc_mesh( lambda: proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env), lambda: proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env), ) + + +_debug_proc_mesh: Optional["ProcMesh"] = None + + +# Lazy init of the debug proc mesh so that importing monarch.proc_mesh +# doesn't trigger the debug client to spawn, which could cause confusing +# logs. This is defined in proc_mesh.py instead of debugger.py for +# circular import reasons. +def _get_debug_proc_mesh() -> "ProcMesh": + global _debug_proc_mesh + if _debug_proc_mesh is None: + _debug_proc_mesh = local_proc_mesh_blocking( + gpus=1, hosts=1, _is_initializing_debugger=True + ) + return _debug_proc_mesh + + +_debug_client_mesh: Optional[ActorMeshRef[DebugClient]] = None + + +# Lazy init for the same reason as above. This is defined in proc_mesh.py +# instead of debugger.py for circular import reasons. +def debug_client() -> ActorMeshRef[DebugClient]: + global _debug_client_mesh + if _debug_client_mesh is None: + _debug_client_mesh = ( + _get_debug_proc_mesh().spawn("debug_client", DebugClient).get() + ) + return _debug_client_mesh diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py index a2720198..0a108a2b 100644 --- a/python/monarch/actor/__init__.py +++ b/python/monarch/actor/__init__.py @@ -23,7 +23,13 @@ ValueMesh, ) from monarch._src.actor.future import Future -from monarch._src.actor.proc_mesh import local_proc_mesh, proc_mesh, ProcMesh +from monarch._src.actor.proc_mesh import ( + debug_client, + local_proc_mesh, + proc_mesh, + ProcMesh, +) + __all__ = [ "Accumulator", @@ -42,4 +48,5 @@ "ProcMesh", "send", "ValueMesh", + "debug_client", ] diff --git a/python/tests/test_debugger.py b/python/tests/test_debugger.py new file mode 100644 index 00000000..a4750d70 --- /dev/null +++ b/python/tests/test_debugger.py @@ -0,0 +1,415 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import asyncio +import re +import sys +from unittest.mock import AsyncMock, MagicMock, patch + +import monarch +import monarch.actor as actor + +import pytest + +import torch + +from monarch._src.actor.actor_mesh import Actor, endpoint, MonarchContext +from monarch._src.actor.debugger import ( + Attach, + Cast, + Continue, + DebugClient, + DebugCommand, + DebugSession, + Help, + ListCommand, + Quit, +) + +from monarch._src.actor.proc_mesh import proc_mesh + +needs_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) + + +def _bad_rank(): + raise ValueError("bad rank") + + +def _debugee_actor_internal(rank): + if rank == 0: + breakpoint() # noqa + rank += 1 + rank += 1 + return rank + elif rank == 1: + breakpoint() # noqa + rank += 2 + rank += 2 + return rank + elif rank == 2: + breakpoint() # noqa + rank += 3 + rank += 3 + _bad_rank() + elif rank == 3: + breakpoint() # noqa + rank += 4 + rank += 4 + return rank + + +class DebugeeActor(Actor): + @endpoint + async def to_debug(self): + rank = MonarchContext.get().point.rank + return _debugee_actor_internal(rank) + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +async def test_debug() -> None: + input_mock = AsyncMock() + input_mock.side_effect = [ + "attach 1", + "n", + "n", + "n", + "n", + "detach", + "attach 1", + "detach", + "quit", + "cast ranks(0,3) n", + "cast ranks(0,3) n", + # Attaching to 0 and 3 ensures that when we call "list" + # the next time, their function/lineno info will be + # up-to-date. + "attach 0", + "detach", + "attach 3", + "detach", + "quit", + "attach 2", + "c", + "detach", + "quit", + "attach 2", + "bt", + "c", + "quit", + "continue", + ] + + outputs = [] + + def _patch_output(msg): + nonlocal outputs + outputs.append(msg) + + with patch( + "monarch._src.actor.debugger._debugger_input", side_effect=input_mock + ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): + proc = await proc_mesh(hosts=2, gpus=2) + debugee = await proc.spawn("debugee", DebugeeActor) + debug_client = actor.debug_client() + + fut = debugee.to_debug.call() + await debug_client.wait_pending_session.call_one() + breakpoints = [] + for i in range(10): + breakpoints = await debug_client.list.call_one() + if len(breakpoints) == 4: + break + await asyncio.sleep(1) + if i == 9: + raise RuntimeError("timed out waiting for breakpoints") + + initial_linenos = {} + for i in range(len(breakpoints)): + rank, coords, _, _, function, lineno = breakpoints[i] + initial_linenos[rank] = lineno + assert rank == i + assert coords == {"hosts": rank // 2, "gpus": rank % 2} + assert function == "test_debugger._debugee_actor_internal" + assert lineno == breakpoints[0][5] + 5 * rank + + await debug_client.enter.call_one() + + # Check that when detaching and re-attaching to a session, the last portion of the output is repeated + expected_last_output = [ + r"--Return--", + r"\n", + r"> (/.*/)+test_debugger.py\(\d+\)to_debug\(\)->5\n-> return _debugee_actor_internal\(rank\)", + r"\n", + r"\(Pdb\) ", + ] + output_len = len(expected_last_output) + assert outputs[-2 * output_len : -output_len] == outputs[-output_len:] + for real_output, expected_output in zip( + outputs[-output_len:], expected_last_output + ): + assert re.match(expected_output, real_output) is not None + + breakpoints = await debug_client.list.call_one() + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i][4] == "test_debugger.to_debug" + else: + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + + await debug_client.enter.call_one() + + breakpoints = await debug_client.list.call_one() + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i][4] == "test_debugger.to_debug" + elif i in (0, 3): + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + 2 + else: + assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" + assert breakpoints[i][5] == initial_linenos[i] + + await debug_client.enter.call_one() + + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 4 + # Expect post-mortem debugging for rank 2 + assert breakpoints[2][4] == "test_debugger._bad_rank" + + await debug_client.enter.call_one() + + expected_last_output = [ + r"\s*(/.*/)+test_debugger.py\(\d+\)_debugee_actor_internal\(\)\n-> _bad_rank\(\)", + r"\n", + r'> (/.*/)+test_debugger.py\(\d+\)_bad_rank\(\)\n-> raise ValueError\("bad rank"\)', + r"\n", + r"\(Pdb\) ", + ] + + for output, expected_output in zip( + outputs[-len(expected_last_output) :], expected_last_output + ): + assert re.match(expected_output, output) is not None + + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 3 + for i, rank in enumerate((0, 1, 3)): + assert breakpoints[i][0] == rank + + await debug_client.enter.call_one() + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 0 + + with pytest.raises( + monarch._src.actor.actor_mesh.ActorError, match="ValueError: bad rank" + ): + await fut + + +async def test_cast_input_and_wait() -> None: + debug_client = DebugClient() + + mock_sessions = {} + for host in range(3): + for gpu in range(8): + rank = host * 8 + gpu + mock_session = MagicMock(spec=DebugSession) + mock_session.attach = AsyncMock() + mock_session.rank = rank + mock_session.coords = {"hosts": host, "gpus": gpu} + mock_sessions[rank] = mock_session + + debug_client.sessions = mock_sessions + + # Cast to a single rank + await debug_client._cast_input_and_wait("n", 2) + mock_sessions[2].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank != 2: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to a list of ranks + ranks = [1, 3, 5] + await debug_client._cast_input_and_wait("n", ranks) + for rank in ranks: + mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank not in ranks: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to a range of ranks + ranks = range(2, 24, 3) + await debug_client._cast_input_and_wait("n", ranks) + for rank in ranks: + mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True) + for rank, session in mock_sessions.items(): + if rank not in ranks: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast to all ranks + await debug_client._cast_input_and_wait("n", None) + for session in mock_sessions.values(): + session.attach.assert_called_once_with("n", suppress_output=True) + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a single value + await debug_client._cast_input_and_wait("n", {"hosts": 1}) + for session in mock_sessions.values(): + if session.coords["hosts"] == 1: + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a list + await debug_client._cast_input_and_wait("n", {"hosts": [0, 2]}) + for _rank, session in mock_sessions.items(): + if session.coords["hosts"] in [0, 2]: + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using dimension filtering with a range + await debug_client._cast_input_and_wait("n", {"gpus": range(5, 8)}) + for session in mock_sessions.values(): + if session.coords["gpus"] in range(5, 8): + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast using multiple dimension filters + await debug_client._cast_input_and_wait( + "n", {"hosts": [1, 3], "gpus": range(0, sys.maxsize, 3)} + ) + for session in mock_sessions.values(): + if session.coords["hosts"] in [1, 3] and session.coords["gpus"] in range( + 0, sys.maxsize, 3 + ): + session.attach.assert_called_once_with("n", suppress_output=True) + else: + session.attach.assert_not_called() + + for session in mock_sessions.values(): + session.attach.reset_mock() + + # Cast with non-existent dimension + await debug_client._cast_input_and_wait("n", {"hosts": 0, "gpus": 0, "foo": 0}) + for session in mock_sessions.values(): + session.attach.assert_not_called() + + +@pytest.mark.parametrize( + ["user_input", "expected_output"], + [ + ("attach 1", Attach(1)), + ("a 100", Attach(100)), + ("list", ListCommand()), + ("l", ListCommand()), + ("help", Help()), + ("h", Help()), + ("quit", Quit()), + ("q", Quit()), + ("continue", Continue()), + ("c", Continue()), + ("cast ranks(123) b 25", Cast(ranks=123, command="b 25")), + ("cast ranks(12,34,56) b 25", Cast(ranks=[12, 34, 56], command="b 25")), + ("cast ranks(:) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")), + ("cast ranks(:123) b 25", Cast(ranks=range(0, 123), command="b 25")), + ("cast ranks(123:) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")), + ("cast ranks(123:456) b 25", Cast(ranks=range(123, 456), command="b 25")), + ("cast ranks(::) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")), + ( + "cast ranks(::123) b 25", + Cast(ranks=range(0, sys.maxsize, 123), command="b 25"), + ), + ("cast ranks(123::) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")), + ("cast ranks(:123:) b 25", Cast(ranks=range(0, 123), command="b 25")), + ("cast ranks(:456:123) b 25", Cast(ranks=range(0, 456, 123), command="b 25")), + ( + "cast ranks(456::123) b 25", + Cast(ranks=range(456, sys.maxsize, 123), command="b 25"), + ), + ("cast ranks(123:456:) b 25", Cast(ranks=range(123, 456), command="b 25")), + ( + "cast ranks(456:789:123) b 25", + Cast(ranks=range(456, 789, 123), command="b 25"), + ), + ("cast ranks(dim1=123) up 2", Cast(ranks={"dim1": 123}, command="up 2")), + ( + "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2) up 2", + Cast( + ranks={ + "dim1": 123, + "dim2": [12, 34, 56], + "dim3": range(15, sys.maxsize, 2), + }, + command="up 2", + ), + ), + ], +) +async def test_debug_command_parser_valid_inputs(user_input, expected_output): + assert DebugCommand.parse(user_input) == expected_output + + +@pytest.mark.parametrize( + "invalid_input", + [ + "", + "attch 1", + "attach", + "cast rnks(123) b 25", + "cast ranks() b 25", + "cast ranks(1ab) b 25", + "cast ranks(1,a,3) b 25", + "cast ranks(a:2:4) b 25", + "cast ranks(1,2,3", + "cast ranks(1,2,3)) b 25", + "cast ranks(1,) b 25", + "cast ranks(1,2,) b 25", + "cast ranks(,1,2) b 25", + "cast ranks(1,,2) b 25", + "cast ranks(:::) b 25", + "cast ranks(:123::) b 25", + "cast ranks(1:2:3,4) b 25", + "cast ranks(dim1=) b 25", + "cast ranks(dim1=123, dim2=) b 25", + "cast ranks(dim1=123, dim2=(12,34,56) b 25", + "cast ranks(dim1=123, dim2=(,12,34,56) b 25", + "cast ranks(dim1=123, dim2=(12,,34,56) b 25", + "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2 b 25", + "cast ranks(dim1=123,) b 25", + ], +) +async def test_debug_command_parser_invalid_inputs(invalid_input): + assert DebugCommand.parse(invalid_input) is None diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 7092844e..821872e9 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -4,30 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import asyncio import operator -import re import threading import time from types import ModuleType -from unittest.mock import AsyncMock, patch import pytest import torch -from monarch._src.actor.debugger import init_debugging -from monarch._src.actor.proc_mesh import local_proc_mesh from monarch.actor import ( Accumulator, Actor, - ActorError, current_actor_name, current_rank, current_size, endpoint, Future, - MonarchContext, + local_proc_mesh, proc_mesh, ) from monarch.rdma import RDMABuffer @@ -408,146 +404,6 @@ def test_proc_mesh_liveness() -> None: counter.value.call().get() -def _debugee_actor_internal(rank): - if rank == 0: - breakpoint() # noqa - rank += 1 - return rank - elif rank == 1: - breakpoint() # noqa - rank += 2 - return rank - elif rank == 2: - breakpoint() # noqa - rank += 3 - raise ValueError("bad rank") - elif rank == 3: - breakpoint() # noqa - rank += 4 - return rank - - -class DebugeeActor(Actor): - @endpoint - async def to_debug(self): - rank = MonarchContext.get().point.rank - return _debugee_actor_internal(rank) - - -async def test_debug() -> None: - input_mock = AsyncMock() - input_mock.side_effect = [ - "attach 1", - "n", - "n", - "n", - "n", - "detach", - "attach 1", - "detach", - "quit", - "cast 0,3 n", - "cast 0,3 n", - # Attaching to 0 and 3 ensures that when we call "list" - # the next time, their function/lineno info will be - # up-to-date. - "attach 0", - "detach", - "attach 3", - "detach", - "quit", - "attach 2", - "c", - "quit", - "continue", - ] - - outputs = [] - - def _patch_output(msg): - nonlocal outputs - outputs.append(msg) - - with patch( - "monarch._src.actor.debugger._debugger_input", side_effect=input_mock - ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): - proc = await proc_mesh(hosts=2, gpus=2) - debugee = await proc.spawn("debugee", DebugeeActor) - debug_client = await init_debugging(debugee) - - fut = debugee.to_debug.call() - await debug_client.wait_pending_session.call_one() - breakpoints = [] - for i in range(10): - breakpoints = await debug_client.list.call_one() - if len(breakpoints) == 4: - break - await asyncio.sleep(1) - if i == 9: - raise RuntimeError("timed out waiting for breakpoints") - - initial_linenos = {} - for i in range(len(breakpoints)): - rank, coords, _, _, function, lineno = breakpoints[i] - initial_linenos[rank] = lineno - assert rank == i - assert coords == {"hosts": rank // 2, "gpus": rank % 2} - assert function == "test_python_actors._debugee_actor_internal" - assert lineno == breakpoints[0][5] + 4 * rank - - await debug_client.enter.call_one() - - # Check that when detaching and re-attaching to a session, the last portion of the output is repeated - expected_last_output = [ - r"--Return--", - r"\n", - r"> (/.*/)+test_python_actors.py\(\d+\)to_debug\(\)->3\n-> return _debugee_actor_internal\(rank\)", - r"\n", - r"\(Pdb\) ", - ] - output_len = len(expected_last_output) - assert outputs[-2 * output_len : -output_len] == outputs[-output_len:] - for real_output, expected_output in zip( - outputs[-output_len:], expected_last_output - ): - assert re.match(expected_output, real_output) is not None - - breakpoints = await debug_client.list.call_one() - for i in range(len(breakpoints)): - if i == 1: - assert breakpoints[i][4] == "test_python_actors.to_debug" - else: - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] - - await debug_client.enter.call_one() - - breakpoints = await debug_client.list.call_one() - for i in range(len(breakpoints)): - if i == 1: - assert breakpoints[i][4] == "test_python_actors.to_debug" - elif i in (0, 3): - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] + 2 - else: - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] - - await debug_client.enter.call_one() - - breakpoints = await debug_client.list.call_one() - assert len(breakpoints) == 3 - for i, rank in enumerate((0, 1, 3)): - assert breakpoints[i][0] == rank - - await debug_client.enter.call_one() - breakpoints = await debug_client.list.call_one() - assert len(breakpoints) == 0 - - with pytest.raises(ActorError, match="ValueError: bad rank"): - await fut - - class TLSActor(Actor): """An actor that manages thread-local state.""" diff --git a/requirements.txt b/requirements.txt index 4235044e..41a2a230 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ numpy pyre-extensions cloudpickle torchx-nightly +lark From 79cf8693594ee45b215236c1bade4dcebd75fb98 Mon Sep 17 00:00:00 2001 From: Sam Lurye Date: Thu, 10 Jul 2025 15:23:54 -0700 Subject: [PATCH 2/2] Support multiple actors on the same rank (#494) Summary: Pull Request resolved: https://github.com/pytorch-labs/monarch/pull/494 Previously, debug sessions were keyed by rank, but this doesn't allow for multiple actors on the same rank. This diff updates the debugging logic so that it is keyed on actor name + rank instead of just rank. The `attach` and `cast` commands now require the user to specify actor name in addition to rank. The `list` command now sorts breakpoints by grouping all actors with the same name (i.e., from the same actor mesh) together, and then sorts by rank within each group. Differential Revision: D78047960 --- python/monarch/_src/actor/debugger.py | 252 ++++++---- python/monarch/_src/actor/pdb_wrapper.py | 14 +- python/tests/test_debugger.py | 584 ++++++++++++++++------- 3 files changed, 574 insertions(+), 276 deletions(-) diff --git a/python/monarch/_src/actor/debugger.py b/python/monarch/_src/actor/debugger.py index ed66601e..292ad6d2 100644 --- a/python/monarch/_src/actor/debugger.py +++ b/python/monarch/_src/actor/debugger.py @@ -12,7 +12,7 @@ import os import sys from dataclasses import dataclass -from typing import cast, Dict, Generator, List, Tuple, Union +from typing import cast, Dict, Generator, List, Optional, Tuple, Union from monarch._rust_bindings.monarch_hyperactor.proc import ActorId from monarch._src.actor.actor_mesh import ( @@ -43,24 +43,32 @@ def _debugger_output(msg): @dataclass class DebugSessionInfo: + actor_name: str rank: int coords: Dict[str, int] hostname: str - actor_id: ActorId function: str | None lineno: int | None + def __lt__(self, other): + if self.actor_name < other.actor_name: + return True + elif self.actor_name == other.actor_name: + return self.rank < other.rank + else: + return False + class DebugSession: """Represents a single session with a remote debugger.""" def __init__( - self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId + self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str ): self.rank = rank self.coords = coords self.hostname = hostname - self.actor_id = actor_id + self.actor_name = actor_name self._active = False self._message_queue = asyncio.Queue() self._task = None @@ -127,7 +135,7 @@ def get_info(self): if self._function_lineno is not None: function, lineno = self._function_lineno return DebugSessionInfo( - self.rank, self.coords, self.hostname, self.actor_id, function, lineno + self.actor_name, self.rank, self.coords, self.hostname, function, lineno ) async def attach(self, line=None, suppress_output=False): @@ -160,6 +168,97 @@ async def debugger_write(self, write: DebuggerWrite) -> None: RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]] +class DebugSessions: + def __init__(self): + self._sessions: Dict[str, Dict[int, DebugSession]] = {} + + def insert(self, session: DebugSession) -> None: + if session.actor_name not in self._sessions: + self._sessions[session.actor_name] = {session.rank: session} + elif session.rank not in self._sessions[session.actor_name]: + self._sessions[session.actor_name][session.rank] = session + else: + raise ValueError( + f"Debug session for rank {session.rank} already exists for actor {session.actor_name}" + ) + + def remove(self, actor_name: str, rank: int) -> DebugSession: + if actor_name not in self._sessions: + raise ValueError(f"No debug sessions for actor {actor_name}") + elif rank not in self._sessions[actor_name]: + raise ValueError(f"No debug session for rank {rank} for actor {actor_name}") + session = self._sessions[actor_name].pop(rank) + if len(self._sessions[actor_name]) == 0: + del self._sessions[actor_name] + return session + + def get(self, actor_name: str, rank: int) -> DebugSession: + if actor_name not in self._sessions: + raise ValueError(f"No debug sessions for actor {actor_name}") + elif rank not in self._sessions[actor_name]: + raise ValueError(f"No debug session for rank {rank} for actor {actor_name}") + return self._sessions[actor_name][rank] + + def iter( + self, selection: Optional[Tuple[str, Optional[RanksType]]] + ) -> Generator[DebugSession, None, None]: + if selection is None: + for sessions in self._sessions.values(): + for session in sessions.values(): + yield session + return + actor_name, ranks = selection + if actor_name not in self._sessions: + return + sessions = self._sessions[actor_name] + if ranks is None: + for session in sessions.values(): + yield session + elif isinstance(ranks, int): + if ranks in sessions: + yield sessions[ranks] + elif isinstance(ranks, list): + for rank in ranks: + if rank in sessions: + yield sessions[rank] + elif isinstance(ranks, dict): + dims = ranks + for session in sessions.values(): + include_rank = True + for dim, ranks in dims.items(): + if dim not in session.coords: + include_rank = False + break + elif ( + isinstance(ranks, range) or isinstance(ranks, list) + ) and session.coords[dim] not in ranks: + include_rank = False + break + elif isinstance(ranks, int) and session.coords[dim] != ranks: + include_rank = False + break + if include_rank: + yield session + elif isinstance(ranks, range): + for rank, session in sessions.items(): + if rank in ranks: + yield session + + def info(self) -> List[DebugSessionInfo]: + session_info = [] + for sessions in self._sessions.values(): + for session in sessions.values(): + session_info.append(session.get_info()) + return session_info + + def __len__(self) -> int: + return sum(len(sessions) for sessions in self._sessions.values()) + + def __contains__(self, item: Tuple[str, int]) -> bool: + actor_name, rank = item + return actor_name in self._sessions and rank in self._sessions[actor_name] + + _debug_input_parser = None @@ -181,14 +280,17 @@ def _get_debug_input_parser(): dims: dim ("," dim)* ranks: "ranks(" (dims | rank_range | rank_list | INT) ")" pdb_command: /\\w+.*/ - cast: "cast" ranks pdb_command + actor_name: /\\w+/ + cast: "cast" _WS actor_name ranks pdb_command help: "h" | "help" - attach: ("a" | "attach") INT + attach: ("a" | "attach") _WS actor_name INT cont: "c" | "continue" quit: "q" | "quit" list: "l" | "list" command: attach | list | cast | help | cont | quit + _WS: WS+ + %import common.INT %import common.CNAME %import common.WS @@ -255,11 +357,14 @@ def ranks(self, items: List[Union[RanksType, Token]]) -> RanksType: def pdb_command(self, items: List[Token]) -> str: return items[0].value + def actor_name(self, items: List[Token]) -> str: + return items[0].value + def help(self, _items: List[Token]) -> "Help": return Help() - def attach(self, items: List[Token]) -> "Attach": - return Attach(int(items[0].value)) + def attach(self, items: Tuple[str, Token]) -> "Attach": + return Attach(items[0], int(items[1].value)) def cont(self, _items: List[Token]) -> "Continue": return Continue() @@ -267,8 +372,8 @@ def cont(self, _items: List[Token]) -> "Continue": def quit(self, _items: List[Token]) -> "Quit": return Quit() - def cast(self, items: Tuple[RanksType, str]) -> "Cast": - return Cast(items[0], items[1]) + def cast(self, items: Tuple[str, RanksType, str]) -> "Cast": + return Cast(*items) def list(self, items: List[Token]) -> "ListCommand": return ListCommand() @@ -293,6 +398,7 @@ def parse(line: str) -> Union["DebugCommand", None]: @dataclass class Attach(DebugCommand): + actor_name: str rank: int @@ -318,6 +424,7 @@ class Continue(DebugCommand): @dataclass class Cast(DebugCommand): + actor_name: str ranks: RanksType command: str @@ -330,7 +437,7 @@ class DebugClient(Actor): """ def __init__(self) -> None: - self.sessions = {} # rank -> DebugSession + self.sessions = DebugSessions() @endpoint async def wait_pending_session(self): @@ -338,36 +445,33 @@ async def wait_pending_session(self): await asyncio.sleep(1) @endpoint - async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]]: - session_info = [] - for _, session in self.sessions.items(): - info = session.get_info() - session_info.append( - ( - info.rank, - info.coords, - info.hostname, - info.actor_id, - info.function, - info.lineno, - ) - ) - table_info = sorted(session_info, key=lambda r: r[0]) + async def list(self) -> List[DebugSessionInfo]: + session_info = sorted(self.sessions.info()) print( tabulate( - table_info, + ( + ( + info.actor_name, + info.rank, + info.coords, + info.hostname, + info.function, + info.lineno, + ) + for info in session_info + ), headers=[ + "Actor Name", "Rank", "Coords", "Hostname", - "Actor ID", "Function", "Line No.", ], tablefmt="grid", ) ) - return table_info + return session_info @endpoint async def enter(self) -> None: @@ -380,14 +484,16 @@ async def enter(self) -> None: while True: try: user_input = await _debugger_input("monarch_dbg> ") + if not user_input.strip(): + continue command = DebugCommand.parse(user_input) if isinstance(command, Help): print("monarch_dbg commands:") - print("\tattach - attach to a debug session") + print("\tattach - attach to a debug session") print("\tlist - list all debug sessions") print("\tquit - exit the debugger, leaving all sessions in place") print( - "\tcast ranks(...) - send a command to a set of ranks.\n" + "\tcast ranks(...) - send a command to a set of ranks on the specified actor mesh.\n" "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n" "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n" "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6)))." @@ -397,10 +503,7 @@ async def enter(self) -> None: ) print("\thelp - print this help message") elif isinstance(command, Attach): - if command.rank not in self.sessions: - print(f"No debug session for rank {command.rank}") - else: - await self.sessions[command.rank].attach() + await self.sessions.get(command.actor_name, command.rank).attach() elif isinstance(command, ListCommand): # pyre-ignore await self.list._method(self) @@ -416,61 +519,22 @@ async def enter(self) -> None: elif isinstance(command, Quit): return elif isinstance(command, Cast): - await self._cast_input_and_wait(command.command, command.ranks) + await self._cast_input_and_wait( + command.command, (command.actor_name, command.ranks) + ) except Exception as e: print(f"Error processing command: {e}") async def _cast_input_and_wait( self, command: str, - ranks: RanksType | None = None, + selection: Optional[Tuple[str, Optional[RanksType]]] = None, ) -> None: - if ranks is None: - ranks = self.sessions.keys() - elif isinstance(ranks, dict): - ranks = self._iter_ranks_dict(ranks) - elif isinstance(ranks, range): - ranks = self._iter_ranks_range(ranks) - elif isinstance(ranks, int): - ranks = [ranks] tasks = [] - for rank in ranks: - if rank in self.sessions: - tasks.append( - self.sessions[rank].attach( - command, - suppress_output=True, - ) - ) - else: - print(f"No debug session for rank {rank}") + for session in self.sessions.iter(selection): + tasks.append(session.attach(command, suppress_output=True)) await asyncio.gather(*tasks) - def _iter_ranks_dict( - self, dims: Dict[str, Union[range, List[int], int]] - ) -> Generator[int, None, None]: - for rank, session in self.sessions.items(): - include_rank = True - for dim, ranks in dims.items(): - if dim not in session.coords: - include_rank = False - break - elif ( - isinstance(ranks, range) or isinstance(ranks, list) - ) and session.coords[dim] not in ranks: - include_rank = False - break - elif isinstance(ranks, int) and session.coords[dim] != ranks: - include_rank = False - break - if include_rank: - yield rank - - def _iter_ranks_range(self, rng: range) -> Generator[int, None, None]: - for rank in self.sessions.keys(): - if rank in rng: - yield rank - ########################################################################## # Debugger APIs # @@ -478,30 +542,30 @@ def _iter_ranks_range(self, rng: range) -> Generator[int, None, None]: # and communicate with them. @endpoint async def debugger_session_start( - self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId + self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str ) -> None: # Create a session if it doesn't exist - if rank not in self.sessions: - self.sessions[rank] = DebugSession(rank, coords, hostname, actor_id) + if (actor_name, rank) not in self.sessions: + self.sessions.insert(DebugSession(rank, coords, hostname, actor_name)) @endpoint - async def debugger_session_end(self, rank: int) -> None: + async def debugger_session_end(self, actor_name: str, rank: int) -> None: """Detach from the current debug session.""" - session = self.sessions.pop(rank) - await session.detach() + await self.sessions.remove(actor_name, rank).detach() @endpoint - async def debugger_read(self, rank: int, size: int) -> DebuggerWrite | str: + async def debugger_read( + self, actor_name: str, rank: int, size: int + ) -> DebuggerWrite | str: """Read from the debug session for the given rank.""" - session = self.sessions[rank] - - return await session.debugger_read(size) + return await self.sessions.get(actor_name, rank).debugger_read(size) @endpoint - async def debugger_write(self, rank: int, write: DebuggerWrite) -> None: + async def debugger_write( + self, actor_name: str, rank: int, write: DebuggerWrite + ) -> None: """Write to the debug session for the given rank.""" - session = self.sessions[rank] - await session.debugger_write(write) + await self.sessions.get(actor_name, rank).debugger_write(write) class DebugManager(Actor): diff --git a/python/monarch/_src/actor/pdb_wrapper.py b/python/monarch/_src/actor/pdb_wrapper.py index a9cf56b7..a0d2015c 100644 --- a/python/monarch/_src/actor/pdb_wrapper.py +++ b/python/monarch/_src/actor/pdb_wrapper.py @@ -46,9 +46,12 @@ def __init__( super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self)) self._first = True - def set_trace(self, frame): + def set_trace(self, frame=None): self.client_ref.debugger_session_start.call_one( - self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id + self.rank, + self.coords, + socket.getfqdn(socket.gethostname()), + self.actor_id.actor_name, ).get() if self.header: self.message(self.header) @@ -66,7 +69,9 @@ def do_clear(self, arg): super().do_clear(arg) def end_debug_session(self): - self.client_ref.debugger_session_end.call_one(self.rank).get() + self.client_ref.debugger_session_end.call_one( + self.actor_id.actor_name, self.rank + ).get() # Once the debug client actor is notified of the session being over, # we need to prevent any additional requests being sent for the session # by redirecting stdin and stdout. @@ -86,7 +91,7 @@ def __init__(self, session: "PdbWrapper"): def readinto(self, b): response = self.session.client_ref.debugger_read.call_one( - self.session.rank, len(b) + self.session.actor_id.actor_name, self.session.rank, len(b) ).get() if response == "detach": # this gets injected by the worker event loop to @@ -120,6 +125,7 @@ def write(self, s: str): # pyre-ignore lineno = self.session.curframe.f_lineno self.session.client_ref.debugger_write.call_one( + self.session.actor_id.actor_name, self.session.rank, DebuggerWrite( s.encode(), diff --git a/python/tests/test_debugger.py b/python/tests/test_debugger.py index a4750d70..7cd12e47 100644 --- a/python/tests/test_debugger.py +++ b/python/tests/test_debugger.py @@ -8,7 +8,8 @@ import asyncio import re import sys -from unittest.mock import AsyncMock, MagicMock, patch +from typing import cast, List +from unittest.mock import AsyncMock, patch import monarch import monarch.actor as actor @@ -17,14 +18,15 @@ import torch -from monarch._src.actor.actor_mesh import Actor, endpoint, MonarchContext +from monarch._src.actor.actor_mesh import Actor, ActorError, endpoint, MonarchContext from monarch._src.actor.debugger import ( Attach, Cast, Continue, - DebugClient, DebugCommand, DebugSession, + DebugSessionInfo, + DebugSessions, Help, ListCommand, Quit, @@ -72,6 +74,18 @@ async def to_debug(self): return _debugee_actor_internal(rank) +async def _wait_for_breakpoints(debug_client, n_breakpoints) -> List[DebugSessionInfo]: + breakpoints: List[DebugSessionInfo] = [] + for i in range(10): + breakpoints = await debug_client.list.call_one() + if len(breakpoints) == n_breakpoints: + break + await asyncio.sleep(1) + if i == 9: + raise RuntimeError("timed out waiting for breakpoints") + return breakpoints + + @pytest.mark.skipif( torch.cuda.device_count() < 2, reason="Not enough GPUs, this test requires at least 2 GPUs", @@ -79,30 +93,30 @@ async def to_debug(self): async def test_debug() -> None: input_mock = AsyncMock() input_mock.side_effect = [ - "attach 1", + "attach debugee 1", "n", "n", "n", "n", "detach", - "attach 1", + "attach debugee 1", "detach", "quit", - "cast ranks(0,3) n", - "cast ranks(0,3) n", + "cast debugee ranks(0,3) n", + "cast debugee ranks(0,3) n", # Attaching to 0 and 3 ensures that when we call "list" # the next time, their function/lineno info will be # up-to-date. - "attach 0", + "attach debugee 0", "detach", - "attach 3", + "attach debugee 3", "detach", "quit", - "attach 2", + "attach debugee 2", "c", "detach", "quit", - "attach 2", + "attach debugee 2", "bt", "c", "quit", @@ -124,23 +138,16 @@ def _patch_output(msg): fut = debugee.to_debug.call() await debug_client.wait_pending_session.call_one() - breakpoints = [] - for i in range(10): - breakpoints = await debug_client.list.call_one() - if len(breakpoints) == 4: - break - await asyncio.sleep(1) - if i == 9: - raise RuntimeError("timed out waiting for breakpoints") + breakpoints = await _wait_for_breakpoints(debug_client, 4) initial_linenos = {} for i in range(len(breakpoints)): - rank, coords, _, _, function, lineno = breakpoints[i] - initial_linenos[rank] = lineno - assert rank == i - assert coords == {"hosts": rank // 2, "gpus": rank % 2} - assert function == "test_debugger._debugee_actor_internal" - assert lineno == breakpoints[0][5] + 5 * rank + info = breakpoints[i] + initial_linenos[info.rank] = info.lineno + assert info.rank == i + assert info.coords == {"hosts": info.rank // 2, "gpus": info.rank % 2} + assert info.function == "test_debugger._debugee_actor_internal" + assert info.lineno == cast(int, breakpoints[0].lineno) + 5 * info.rank await debug_client.enter.call_one() @@ -162,30 +169,36 @@ def _patch_output(msg): breakpoints = await debug_client.list.call_one() for i in range(len(breakpoints)): if i == 1: - assert breakpoints[i][4] == "test_debugger.to_debug" + assert breakpoints[i].function == "test_debugger.to_debug" else: - assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] + assert ( + breakpoints[i].function == "test_debugger._debugee_actor_internal" + ) + assert breakpoints[i].lineno == initial_linenos[i] await debug_client.enter.call_one() breakpoints = await debug_client.list.call_one() for i in range(len(breakpoints)): if i == 1: - assert breakpoints[i][4] == "test_debugger.to_debug" + assert breakpoints[i].function == "test_debugger.to_debug" elif i in (0, 3): - assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] + 2 + assert ( + breakpoints[i].function == "test_debugger._debugee_actor_internal" + ) + assert breakpoints[i].lineno == initial_linenos[i] + 2 else: - assert breakpoints[i][4] == "test_debugger._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] + assert ( + breakpoints[i].function == "test_debugger._debugee_actor_internal" + ) + assert breakpoints[i].lineno == initial_linenos[i] await debug_client.enter.call_one() breakpoints = await debug_client.list.call_one() assert len(breakpoints) == 4 # Expect post-mortem debugging for rank 2 - assert breakpoints[2][4] == "test_debugger._bad_rank" + assert breakpoints[2].function == "test_debugger._bad_rank" await debug_client.enter.call_one() @@ -205,7 +218,7 @@ def _patch_output(msg): breakpoints = await debug_client.list.call_one() assert len(breakpoints) == 3 for i, rank in enumerate((0, 1, 3)): - assert breakpoints[i][0] == rank + assert breakpoints[i].rank == rank await debug_client.enter.call_one() breakpoints = await debug_client.list.call_one() @@ -217,122 +230,288 @@ def _patch_output(msg): await fut -async def test_cast_input_and_wait() -> None: - debug_client = DebugClient() +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +async def test_debug_multi_actor() -> None: + input_mock = AsyncMock() + input_mock.side_effect = [ + "attach debugee_2 2", + "n", + "detach", + "attach debugee_1 1", + "n", + "detach", + "quit", + "cast debugee_1 ranks(:) c", + "cast debugee_2 ranks(:) c", + "attach debugee_2 2", + "c", + "quit", + "continue", + ] + + with patch("monarch._src.actor.debugger._debugger_input", side_effect=input_mock): + proc = await proc_mesh(hosts=2, gpus=2) + debugee_1 = await proc.spawn("debugee_1", DebugeeActor) + debugee_2 = await proc.spawn("debugee_2", DebugeeActor) + debug_client = actor.debug_client() - mock_sessions = {} - for host in range(3): - for gpu in range(8): - rank = host * 8 + gpu - mock_session = MagicMock(spec=DebugSession) - mock_session.attach = AsyncMock() - mock_session.rank = rank - mock_session.coords = {"hosts": host, "gpus": gpu} - mock_sessions[rank] = mock_session + fut_1 = debugee_1.to_debug.call() + fut_2 = debugee_2.to_debug.call() + await debug_client.wait_pending_session.call_one() - debug_client.sessions = mock_sessions + breakpoints = await _wait_for_breakpoints(debug_client, 8) - # Cast to a single rank - await debug_client._cast_input_and_wait("n", 2) - mock_sessions[2].attach.assert_called_once_with("n", suppress_output=True) - for rank, session in mock_sessions.items(): - if rank != 2: - session.attach.assert_not_called() + initial_linenos = {} + for i in range(len(breakpoints)): + info = breakpoints[i] + initial_linenos[info.rank] = info.lineno + assert info.rank == i % 4 + assert info.actor_name == "debugee_1" if i < 4 else "debugee_2" + assert info.coords == {"hosts": info.rank // 2, "gpus": info.rank % 2} + assert info.function == "test_debugger._debugee_actor_internal" + assert info.lineno == cast(int, breakpoints[0].lineno) + 5 * info.rank - for session in mock_sessions.values(): - session.attach.reset_mock() + await debug_client.enter.call_one() - # Cast to a list of ranks - ranks = [1, 3, 5] - await debug_client._cast_input_and_wait("n", ranks) - for rank in ranks: - mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True) - for rank, session in mock_sessions.items(): - if rank not in ranks: - session.attach.assert_not_called() + breakpoints = await _wait_for_breakpoints(debug_client, 8) + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i].actor_name == "debugee_1" + assert breakpoints[i].rank == 1 + assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] + 1 + elif i == 6: + assert breakpoints[i].actor_name == "debugee_2" + assert breakpoints[i].rank == 2 + assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] + 1 + else: + assert ( + breakpoints[i].actor_name == "debugee_1" if i < 4 else "debugee_2" + ) + assert breakpoints[i].rank == i % 4 + assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] - for session in mock_sessions.values(): - session.attach.reset_mock() + await debug_client.enter.call_one() - # Cast to a range of ranks - ranks = range(2, 24, 3) - await debug_client._cast_input_and_wait("n", ranks) - for rank in ranks: - mock_sessions[rank].attach.assert_called_once_with("n", suppress_output=True) - for rank, session in mock_sessions.items(): - if rank not in ranks: - session.attach.assert_not_called() - - for session in mock_sessions.values(): - session.attach.reset_mock() - - # Cast to all ranks - await debug_client._cast_input_and_wait("n", None) - for session in mock_sessions.values(): - session.attach.assert_called_once_with("n", suppress_output=True) - - for session in mock_sessions.values(): - session.attach.reset_mock() - - # Cast using dimension filtering with a single value - await debug_client._cast_input_and_wait("n", {"hosts": 1}) - for session in mock_sessions.values(): - if session.coords["hosts"] == 1: - session.attach.assert_called_once_with("n", suppress_output=True) - else: - session.attach.assert_not_called() - - for session in mock_sessions.values(): - session.attach.reset_mock() - - # Cast using dimension filtering with a list - await debug_client._cast_input_and_wait("n", {"hosts": [0, 2]}) - for _rank, session in mock_sessions.items(): - if session.coords["hosts"] in [0, 2]: - session.attach.assert_called_once_with("n", suppress_output=True) - else: - session.attach.assert_not_called() - - for session in mock_sessions.values(): - session.attach.reset_mock() - - # Cast using dimension filtering with a range - await debug_client._cast_input_and_wait("n", {"gpus": range(5, 8)}) - for session in mock_sessions.values(): - if session.coords["gpus"] in range(5, 8): - session.attach.assert_called_once_with("n", suppress_output=True) - else: - session.attach.assert_not_called() - - for session in mock_sessions.values(): - session.attach.reset_mock() - - # Cast using multiple dimension filters - await debug_client._cast_input_and_wait( - "n", {"hosts": [1, 3], "gpus": range(0, sys.maxsize, 3)} - ) - for session in mock_sessions.values(): - if session.coords["hosts"] in [1, 3] and session.coords["gpus"] in range( - 0, sys.maxsize, 3 - ): - session.attach.assert_called_once_with("n", suppress_output=True) - else: - session.attach.assert_not_called() + breakpoints = await _wait_for_breakpoints(debug_client, 1) + with pytest.raises(ActorError, match="ValueError: bad rank"): + await fut_2 + assert breakpoints[0].actor_name == "debugee_1" + assert breakpoints[0].rank == 2 + assert breakpoints[0].function == "test_debugger._bad_rank" - for session in mock_sessions.values(): - session.attach.reset_mock() + await debug_client.enter.call_one() - # Cast with non-existent dimension - await debug_client._cast_input_and_wait("n", {"hosts": 0, "gpus": 0, "foo": 0}) - for session in mock_sessions.values(): - session.attach.assert_not_called() + breakpoints = await _wait_for_breakpoints(debug_client, 0) + with pytest.raises(ActorError, match="ValueError: bad rank"): + await fut_1 + + +async def test_debug_sessions_insert_get_remove() -> None: + mock_sessions = [] + for actor_name in ("actor_a", "actor_b"): + for rank in range(2): + mock_session = DebugSession(rank, {}, "", actor_name) + mock_sessions.append(mock_session) + + debug_sessions = DebugSessions() + + with pytest.raises(ValueError, match="No debug sessions for actor actor_a"): + debug_sessions.get("actor_a", 0) + debug_sessions.insert(mock_sessions[0]) + assert debug_sessions.get("actor_a", 0) is mock_sessions[0] + assert ("actor_a", 0) in debug_sessions + with pytest.raises( + ValueError, match="Debug session for rank 0 already exists for actor actor_a" + ): + debug_sessions.insert(mock_sessions[0]) + + with pytest.raises( + ValueError, match="No debug session for rank 1 for actor actor_a" + ): + debug_sessions.get("actor_a", 1) + debug_sessions.insert(mock_sessions[1]) + assert debug_sessions.get("actor_a", 1) is mock_sessions[1] + assert ("actor_a", 1) in debug_sessions + with pytest.raises( + ValueError, match="Debug session for rank 1 already exists for actor actor_a" + ): + debug_sessions.insert(mock_sessions[1]) + + with pytest.raises(ValueError, match="No debug sessions for actor actor_b"): + debug_sessions.get("actor_b", 0) + debug_sessions.insert(mock_sessions[2]) + assert debug_sessions.get("actor_b", 0) is mock_sessions[2] + assert ("actor_b", 0) in debug_sessions + with pytest.raises( + ValueError, match="Debug session for rank 0 already exists for actor actor_b" + ): + debug_sessions.insert(mock_sessions[2]) + + with pytest.raises( + ValueError, match="No debug session for rank 1 for actor actor_b" + ): + debug_sessions.get("actor_b", 1) + debug_sessions.insert(mock_sessions[3]) + assert debug_sessions.get("actor_b", 1) is mock_sessions[3] + assert ("actor_b", 1) in debug_sessions + with pytest.raises( + ValueError, match="Debug session for rank 1 already exists for actor actor_b" + ): + debug_sessions.insert(mock_sessions[3]) + + assert len(debug_sessions) == 4 + + assert debug_sessions.remove("actor_a", 0) is mock_sessions[0] + assert len(debug_sessions) == 3 + assert ("actor_a", 0) not in debug_sessions + with pytest.raises( + ValueError, match="No debug session for rank 0 for actor actor_a" + ): + debug_sessions.remove("actor_a", 0) + + assert debug_sessions.remove("actor_a", 1) is mock_sessions[1] + assert len(debug_sessions) == 2 + assert ("actor_a", 1) not in debug_sessions + with pytest.raises(ValueError, match="No debug sessions for actor actor_a"): + debug_sessions.remove("actor_a", 1) + + assert debug_sessions.remove("actor_b", 0) is mock_sessions[2] + assert len(debug_sessions) == 1 + assert ("actor_b", 0) not in debug_sessions + with pytest.raises( + ValueError, match="No debug session for rank 0 for actor actor_b" + ): + debug_sessions.remove("actor_b", 0) + + assert debug_sessions.remove("actor_b", 1) is mock_sessions[3] + assert len(debug_sessions) == 0 + assert ("actor_b", 1) not in debug_sessions + with pytest.raises(ValueError, match="No debug sessions for actor actor_b"): + debug_sessions.remove("actor_b", 1) + + +async def test_debug_sessions_iter() -> None: + debug_sessions = DebugSessions() + mock_sessions = [] + + for actor_name in ("actor_a", "actor_b"): + for host in range(3): + for gpu in range(8): + rank = host * 8 + gpu + mock_session = DebugSession( + rank, {"hosts": host, "gpus": gpu}, "", actor_name + ) + mock_sessions.append(mock_session) + debug_sessions.insert(mock_session) + + # Single rank + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = list(debug_sessions.iter((actor_name, 2))) + assert len(sessions) == 1 + assert sessions[0] is mock_sessions[i * 24 + 2] + + # List of ranks + ranks = [1, 3, 5] + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, ranks)), key=lambda s: s.get_info() + ) + assert len(sessions) == 3 + for j in range(3): + assert sessions[j] is mock_sessions[i * 24 + ranks[j]] + + # Range of ranks + ranks = range(2, 24, 3) + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, ranks)), key=lambda s: s.get_info() + ) + ranks = list(ranks) + assert len(sessions) == len(ranks) + for j in range(len(ranks)): + assert sessions[j] is mock_sessions[i * 24 + ranks[j]] + + # All ranks + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, None)), key=lambda s: s.get_info() + ) + assert len(sessions) == 24 + for j in range(24): + assert sessions[j] is mock_sessions[i * 24 + j] + + # All ranks, all actors + sessions = sorted(debug_sessions.iter(None), key=lambda s: s.get_info()) + assert len(sessions) == 48 + for i in range(48): + assert sessions[i] is mock_sessions[i] + + # Dimension filtering with a single value + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, {"hosts": 1})), key=lambda s: s.get_info() + ) + assert len(sessions) == 8 + for j in range(8): + assert sessions[j] is mock_sessions[i * 24 + 8 + j] + + # Dimension filtering with a list + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, {"hosts": [0, 2]})), + key=lambda s: s.get_info(), + ) + assert len(sessions) == 16 + j = 0 + for host in (0, 2): + for gpu in range(8): + assert sessions[j] is mock_sessions[i * 24 + host * 8 + gpu] + j += 1 + + # Dimension filtering with a range + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, {"gpus": range(5, 8)})), + key=lambda s: s.get_info(), + ) + assert len(sessions) == 9 + j = 0 + for host in range(3): + for gpu in range(5, 8): + assert sessions[j] is mock_sessions[i * 24 + host * 8 + gpu] + j += 1 + + # Multiple dimension filters + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter( + (actor_name, {"hosts": [1, 3], "gpus": range(0, sys.maxsize, 3)}) + ), + key=lambda s: s.get_info(), + ) + assert len(sessions) == 3 + j = 0 + for gpu in range(0, 8, 3): + assert sessions[j] is mock_sessions[i * 24 + 8 + gpu] + j += 1 + + # Non-existent dimension + for actor_name in ("actor_a", "actor_b"): + sessions = sorted( + debug_sessions.iter((actor_name, {"hosts": 0, "gpus": 0, "foo": 0})), + key=lambda s: s.get_info(), + ) + assert len(sessions) == 0 @pytest.mark.parametrize( ["user_input", "expected_output"], [ - ("attach 1", Attach(1)), - ("a 100", Attach(100)), + ("attach debugee 1", Attach("debugee", 1)), + ("a my_awesome_actor 100", Attach("my_awesome_actor", 100)), ("list", ListCommand()), ("l", ListCommand()), ("help", Help()), @@ -341,33 +520,74 @@ async def test_cast_input_and_wait() -> None: ("q", Quit()), ("continue", Continue()), ("c", Continue()), - ("cast ranks(123) b 25", Cast(ranks=123, command="b 25")), - ("cast ranks(12,34,56) b 25", Cast(ranks=[12, 34, 56], command="b 25")), - ("cast ranks(:) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")), - ("cast ranks(:123) b 25", Cast(ranks=range(0, 123), command="b 25")), - ("cast ranks(123:) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")), - ("cast ranks(123:456) b 25", Cast(ranks=range(123, 456), command="b 25")), - ("cast ranks(::) b 25", Cast(ranks=range(0, sys.maxsize), command="b 25")), ( - "cast ranks(::123) b 25", - Cast(ranks=range(0, sys.maxsize, 123), command="b 25"), + "cast debugee ranks(123) b 25", + Cast(actor_name="debugee", ranks=123, command="b 25"), + ), + ( + "cast my_awesome_actor ranks(12,34,56) b 25", + Cast(actor_name="my_awesome_actor", ranks=[12, 34, 56], command="b 25"), + ), + ( + "cast debugee ranks(:) b 25", + Cast(actor_name="debugee", ranks=range(0, sys.maxsize), command="b 25"), + ), + ( + "cast debugee ranks(:123) b 25", + Cast(actor_name="debugee", ranks=range(0, 123), command="b 25"), + ), + ( + "cast debugee ranks(123:) b 25", + Cast(actor_name="debugee", ranks=range(123, sys.maxsize), command="b 25"), + ), + ( + "cast debugee ranks(123:456) b 25", + Cast(actor_name="debugee", ranks=range(123, 456), command="b 25"), + ), + ( + "cast debugee ranks(::) b 25", + Cast(actor_name="debugee", ranks=range(0, sys.maxsize), command="b 25"), + ), + ( + "cast debugee ranks(::123) b 25", + Cast( + actor_name="debugee", ranks=range(0, sys.maxsize, 123), command="b 25" + ), + ), + ( + "cast debugee ranks(123::) b 25", + Cast(actor_name="debugee", ranks=range(123, sys.maxsize), command="b 25"), + ), + ( + "cast debugee ranks(:123:) b 25", + Cast(actor_name="debugee", ranks=range(0, 123), command="b 25"), + ), + ( + "cast debugee ranks(:456:123) b 25", + Cast(actor_name="debugee", ranks=range(0, 456, 123), command="b 25"), + ), + ( + "cast debugee ranks(456::123) b 25", + Cast( + actor_name="debugee", ranks=range(456, sys.maxsize, 123), command="b 25" + ), + ), + ( + "cast debugee ranks(123:456:) b 25", + Cast(actor_name="debugee", ranks=range(123, 456), command="b 25"), ), - ("cast ranks(123::) b 25", Cast(ranks=range(123, sys.maxsize), command="b 25")), - ("cast ranks(:123:) b 25", Cast(ranks=range(0, 123), command="b 25")), - ("cast ranks(:456:123) b 25", Cast(ranks=range(0, 456, 123), command="b 25")), ( - "cast ranks(456::123) b 25", - Cast(ranks=range(456, sys.maxsize, 123), command="b 25"), + "cast debugee ranks(456:789:123) b 25", + Cast(actor_name="debugee", ranks=range(456, 789, 123), command="b 25"), ), - ("cast ranks(123:456:) b 25", Cast(ranks=range(123, 456), command="b 25")), ( - "cast ranks(456:789:123) b 25", - Cast(ranks=range(456, 789, 123), command="b 25"), + "cast debugee ranks(dim1=123) up 2", + Cast(actor_name="debugee", ranks={"dim1": 123}, command="up 2"), ), - ("cast ranks(dim1=123) up 2", Cast(ranks={"dim1": 123}, command="up 2")), ( - "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2) up 2", + "cast debugee ranks(dim1=123, dim2=(12,34,56), dim3=15::2) up 2", Cast( + actor_name="debugee", ranks={ "dim1": 123, "dim2": [12, 34, 56], @@ -386,29 +606,37 @@ async def test_debug_command_parser_valid_inputs(user_input, expected_output): "invalid_input", [ "", - "attch 1", + "a", "attach", - "cast rnks(123) b 25", - "cast ranks() b 25", - "cast ranks(1ab) b 25", - "cast ranks(1,a,3) b 25", - "cast ranks(a:2:4) b 25", - "cast ranks(1,2,3", - "cast ranks(1,2,3)) b 25", - "cast ranks(1,) b 25", - "cast ranks(1,2,) b 25", - "cast ranks(,1,2) b 25", - "cast ranks(1,,2) b 25", - "cast ranks(:::) b 25", - "cast ranks(:123::) b 25", - "cast ranks(1:2:3,4) b 25", - "cast ranks(dim1=) b 25", - "cast ranks(dim1=123, dim2=) b 25", - "cast ranks(dim1=123, dim2=(12,34,56) b 25", - "cast ranks(dim1=123, dim2=(,12,34,56) b 25", - "cast ranks(dim1=123, dim2=(12,,34,56) b 25", - "cast ranks(dim1=123, dim2=(12,34,56), dim3=15::2 b 25", - "cast ranks(dim1=123,) b 25", + "a actor", + "attach actor", + "attacha actor 1" "attch actor 1", + "attach actor 1abc", + "attach actor 1 a", + "cast ranks(123) b 25", + "cast ranks(123) b 25", + "castactor ranks(123) b 25", + "cast actor rnks(123) b 25", + "cast actor ranks() b 25", + "cast actor ranks(1ab) b 25", + "cast actor ranks(1,a,3) b 25", + "cast actor ranks(a:2:4) b 25", + "cast actor ranks(1,2,3", + "cast actor ranks(1,2,3)) b 25", + "cast actor ranks(1,) b 25", + "cast actor ranks(1,2,) b 25", + "cast actor ranks(,1,2) b 25", + "cast actor ranks(1,,2) b 25", + "cast actor ranks(:::) b 25", + "cast actor ranks(:123::) b 25", + "cast actor ranks(1:2:3,4) b 25", + "cast actor ranks(dim1=) b 25", + "cast actor ranks(dim1=123, dim2=) b 25", + "cast actor ranks(dim1=123, dim2=(12,34,56) b 25", + "cast actor ranks(dim1=123, dim2=(,12,34,56) b 25", + "cast actor ranks(dim1=123, dim2=(12,,34,56) b 25", + "cast actor ranks(dim1=123, dim2=(12,34,56), dim3=15::2 b 25", + "cast actor ranks(dim1=123,) b 25", ], ) async def test_debug_command_parser_invalid_inputs(invalid_input):