diff --git a/python/monarch/_src/actor/actor_mesh.py b/python/monarch/_src/actor/actor_mesh.py index f22c9646..49884121 100644 --- a/python/monarch/_src/actor/actor_mesh.py +++ b/python/monarch/_src/actor/actor_mesh.py @@ -12,7 +12,6 @@ import inspect import logging import random -import sys import traceback from dataclasses import dataclass @@ -53,15 +52,12 @@ from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.future import Future -from monarch._src.actor.pdb_wrapper import remote_breakpointhook +from monarch._src.actor.pdb_wrapper import PdbWrapper from monarch._src.actor.pickle import flatten, unpickle from monarch._src.actor.shape import MeshTrait, NDSlice -if TYPE_CHECKING: - from monarch._src.actor.debugger import DebugClient - logger: logging.Logger = logging.getLogger(__name__) Allocator = ProcessAllocator | LocalAllocator @@ -97,6 +93,23 @@ def get() -> "MonarchContext": ) +@dataclass +class DebugContext: + pdb_wrapper: Optional[PdbWrapper] = None + + @staticmethod + def get() -> "DebugContext": + return _debug_context.get() + + @staticmethod + def set(debug_context: "DebugContext") -> None: + _debug_context.set(debug_context) + + +_debug_context: contextvars.ContextVar[DebugContext] = contextvars.ContextVar( + "monarch.actor_mesh._debug_context" +) + T = TypeVar("T") P = ParamSpec("P") R = TypeVar("R") @@ -538,6 +551,8 @@ async def handle_cast( ) _context.set(ctx) + DebugContext.set(DebugContext()) + args, kwargs = unpickle(message.message, mailbox) if message.method == "__init__": @@ -574,9 +589,10 @@ async def instrumented(): ) try: result = await the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() except Exception as e: logging.critical( - "Unahndled exception in actor endpoint", + "Unhandled exception in actor endpoint", exc_info=e, ) raise e @@ -589,11 +605,13 @@ async def instrumented(): the_method.__module__, message.method, str(ctx.mailbox.actor_id) ) result = the_method(self.instance, *args, **kwargs) + self._maybe_exit_debugger() exit_span() if port is not None: port.send("result", result) except Exception as e: + self._post_mortem_debug(e.__traceback__) traceback.print_exc() s = ActorError(e) @@ -604,6 +622,7 @@ async def instrumented(): else: raise s from None except BaseException as e: + self._post_mortem_debug(e.__traceback__) # A BaseException can be thrown in the case of a Rust panic. # In this case, we need a way to signal the panic to the Rust side. # See [Panics in async endpoints] @@ -614,6 +633,29 @@ async def instrumented(): pass raise + def _maybe_exit_debugger(self, do_continue=True) -> None: + if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: + if do_continue: + pdb_wrapper.clear_all_breaks() + pdb_wrapper.do_continue("") + pdb_wrapper.end_debug_session() + DebugContext.set(DebugContext()) + + def _post_mortem_debug(self, exc_tb) -> None: + from monarch._src.actor.debugger import DebugManager + + if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None: + ctx = MonarchContext.get() + pdb_wrapper = PdbWrapper( + ctx.point.rank, + ctx.point.shape.coordinates(ctx.point.rank), + ctx.mailbox.actor_id, + DebugManager.ref().get_debug_client.call_one().get(), + ) + DebugContext.set(DebugContext(pdb_wrapper)) + pdb_wrapper.post_mortem(exc_tb) + self._maybe_exit_debugger(do_continue=False) + def _is_mailbox(x: object) -> bool: return isinstance(x, Mailbox) @@ -648,19 +690,6 @@ def _new_with_shape(self, shape: Shape) -> "ActorMeshRef": "actor implementations are not meshes, but we can't convince the typechecker of it..." ) - @endpoint # pyre-ignore - def _set_debug_client(self, client: "DebugClient") -> None: - point = MonarchContext.get().point - # For some reason, using a lambda instead of functools.partial - # confuses the pdb wrapper implementation. - sys.breakpointhook = functools.partial( # pyre-ignore - remote_breakpointhook, - point.rank, - point.shape.coordinates(point.rank), - MonarchContext.get().mailbox.actor_id, - client, - ) - class ActorMeshRef(MeshTrait, Generic[T]): def __init__( diff --git a/python/monarch/_src/actor/bootstrap_main.py b/python/monarch/_src/actor/bootstrap_main.py index 5b377ac2..b94a8a96 100644 --- a/python/monarch/_src/actor/bootstrap_main.py +++ b/python/monarch/_src/actor/bootstrap_main.py @@ -54,6 +54,10 @@ def invoke_main(): except Exception as e: logging.warning(f"Failed to set up py-spy: {e}") + from monarch._src.actor.debugger import remote_breakpointhook + + sys.breakpointhook = remote_breakpointhook + # Start an event loop for PythonActors to use. asyncio.run(main()) diff --git a/python/monarch/_src/actor/debugger.py b/python/monarch/_src/actor/debugger.py index 2906c63f..292ad6d2 100644 --- a/python/monarch/_src/actor/debugger.py +++ b/python/monarch/_src/actor/debugger.py @@ -4,23 +4,32 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import asyncio +import functools +import inspect import logging +import os import sys from dataclasses import dataclass -from typing import Dict, List, Tuple, Union +from typing import cast, Dict, Generator, List, Optional, Tuple, Union from monarch._rust_bindings.monarch_hyperactor.proc import ActorId -from monarch._src.actor.actor_mesh import Actor, ActorMeshRef, endpoint -from monarch._src.actor.pdb_wrapper import DebuggerWrite -from monarch._src.actor.proc_mesh import local_proc_mesh +from monarch._src.actor.actor_mesh import ( + _ActorMeshRefImpl, + Actor, + ActorMeshRef, + DebugContext, + endpoint, + MonarchContext, +) +from monarch._src.actor.pdb_wrapper import DebuggerWrite, PdbWrapper from tabulate import tabulate logger = logging.getLogger(__name__) - -CANCEL_TOKEN = object() +_DEBUG_MANAGER_ACTOR_NAME = "debug_manager" async def _debugger_input(prompt=""): @@ -34,24 +43,32 @@ def _debugger_output(msg): @dataclass class DebugSessionInfo: + actor_name: str rank: int coords: Dict[str, int] hostname: str - actor_id: ActorId function: str | None lineno: int | None + def __lt__(self, other): + if self.actor_name < other.actor_name: + return True + elif self.actor_name == other.actor_name: + return self.rank < other.rank + else: + return False + class DebugSession: """Represents a single session with a remote debugger.""" def __init__( - self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId + self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str ): self.rank = rank self.coords = coords self.hostname = hostname - self.actor_id = actor_id + self.actor_name = actor_name self._active = False self._message_queue = asyncio.Queue() self._task = None @@ -118,7 +135,7 @@ def get_info(self): if self._function_lineno is not None: function, lineno = self._function_lineno return DebugSessionInfo( - self.rank, self.coords, self.hostname, self.actor_id, function, lineno + self.actor_name, self.rank, self.coords, self.hostname, function, lineno ) async def attach(self, line=None, suppress_output=False): @@ -148,93 +165,269 @@ async def debugger_write(self, write: DebuggerWrite) -> None: await self._message_queue.put(("write", write)) +RanksType = Union[int, List[int], range, Dict[str, Union[range, List[int], int]]] + + +class DebugSessions: + def __init__(self): + self._sessions: Dict[str, Dict[int, DebugSession]] = {} + + def insert(self, session: DebugSession) -> None: + if session.actor_name not in self._sessions: + self._sessions[session.actor_name] = {session.rank: session} + elif session.rank not in self._sessions[session.actor_name]: + self._sessions[session.actor_name][session.rank] = session + else: + raise ValueError( + f"Debug session for rank {session.rank} already exists for actor {session.actor_name}" + ) + + def remove(self, actor_name: str, rank: int) -> DebugSession: + if actor_name not in self._sessions: + raise ValueError(f"No debug sessions for actor {actor_name}") + elif rank not in self._sessions[actor_name]: + raise ValueError(f"No debug session for rank {rank} for actor {actor_name}") + session = self._sessions[actor_name].pop(rank) + if len(self._sessions[actor_name]) == 0: + del self._sessions[actor_name] + return session + + def get(self, actor_name: str, rank: int) -> DebugSession: + if actor_name not in self._sessions: + raise ValueError(f"No debug sessions for actor {actor_name}") + elif rank not in self._sessions[actor_name]: + raise ValueError(f"No debug session for rank {rank} for actor {actor_name}") + return self._sessions[actor_name][rank] + + def iter( + self, selection: Optional[Tuple[str, Optional[RanksType]]] + ) -> Generator[DebugSession, None, None]: + if selection is None: + for sessions in self._sessions.values(): + for session in sessions.values(): + yield session + return + actor_name, ranks = selection + if actor_name not in self._sessions: + return + sessions = self._sessions[actor_name] + if ranks is None: + for session in sessions.values(): + yield session + elif isinstance(ranks, int): + if ranks in sessions: + yield sessions[ranks] + elif isinstance(ranks, list): + for rank in ranks: + if rank in sessions: + yield sessions[rank] + elif isinstance(ranks, dict): + dims = ranks + for session in sessions.values(): + include_rank = True + for dim, ranks in dims.items(): + if dim not in session.coords: + include_rank = False + break + elif ( + isinstance(ranks, range) or isinstance(ranks, list) + ) and session.coords[dim] not in ranks: + include_rank = False + break + elif isinstance(ranks, int) and session.coords[dim] != ranks: + include_rank = False + break + if include_rank: + yield session + elif isinstance(ranks, range): + for rank, session in sessions.items(): + if rank in ranks: + yield session + + def info(self) -> List[DebugSessionInfo]: + session_info = [] + for sessions in self._sessions.values(): + for session in sessions.values(): + session_info.append(session.get_info()) + return session_info + + def __len__(self) -> int: + return sum(len(sessions) for sessions in self._sessions.values()) + + def __contains__(self, item: Tuple[str, int]) -> bool: + actor_name, rank = item + return actor_name in self._sessions and rank in self._sessions[actor_name] + + +_debug_input_parser = None + + +# Wrap the parser in a function so that jobs don't have to import lark +# unless they want to use the debugger. +def _get_debug_input_parser(): + global _debug_input_parser + if _debug_input_parser is None: + from lark import Lark + + _debug_input_parser = Lark( + """ + rank_list: INT "," INT ("," INT)* + start: INT? + stop: INT? + step: INT? + rank_range: start ":" stop (":" step)? + dim: CNAME "=" (rank_range | "(" rank_list ")" | INT) + dims: dim ("," dim)* + ranks: "ranks(" (dims | rank_range | rank_list | INT) ")" + pdb_command: /\\w+.*/ + actor_name: /\\w+/ + cast: "cast" _WS actor_name ranks pdb_command + help: "h" | "help" + attach: ("a" | "attach") _WS actor_name INT + cont: "c" | "continue" + quit: "q" | "quit" + list: "l" | "list" + command: attach | list | cast | help | cont | quit + + _WS: WS+ + + %import common.INT + %import common.CNAME + %import common.WS + %ignore WS + """, + start="command", + ) + return _debug_input_parser + + +_debug_input_transformer = None + + +# Wrap the transformer in a function so that jobs don't have to import lark +# unless they want to use the debugger. +def _get_debug_input_transformer(): + global _debug_input_transformer + if _debug_input_transformer is None: + from lark import Transformer + from lark.lexer import Token + + class _IntoDebugCommandTransformer(Transformer): + def rank_list(self, items: List[Token]) -> List[int]: + return [int(item.value) for item in items] + + def start(self, items: List[Token]) -> int: + if len(items) == 0: + return 0 + return int(items[0].value) + + def stop(self, items: List[Token]) -> int: + if len(items) == 0: + return sys.maxsize + return int(items[0].value) + + def step(self, items: List[Token]) -> int: + if len(items) == 0: + return 1 + return int(items[0].value) + + def rank_range(self, items: List[int]) -> range: + return range(*items) + + def dim( + self, items: Tuple[Token, Union[range, List[int], Token]] + ) -> Tuple[str, Union[range, List[int], int]]: + if isinstance(items[1], range): + return (items[0].value, cast(range, items[1])) + elif isinstance(items[1], list): + return (items[0].value, cast(List[int], items[1])) + else: + return (items[0].value, int(cast(Token, items[1]).value)) + + def dims( + self, items: List[Tuple[str, Union[range, List[int], int]]] + ) -> Dict[str, Union[range, List[int], int]]: + return {dim[0]: dim[1] for dim in items} + + def ranks(self, items: List[Union[RanksType, Token]]) -> RanksType: + if isinstance(items[0], Token): + return int(cast(Token, items[0]).value) + return cast(RanksType, items[0]) + + def pdb_command(self, items: List[Token]) -> str: + return items[0].value + + def actor_name(self, items: List[Token]) -> str: + return items[0].value + + def help(self, _items: List[Token]) -> "Help": + return Help() + + def attach(self, items: Tuple[str, Token]) -> "Attach": + return Attach(items[0], int(items[1].value)) + + def cont(self, _items: List[Token]) -> "Continue": + return Continue() + + def quit(self, _items: List[Token]) -> "Quit": + return Quit() + + def cast(self, items: Tuple[str, RanksType, str]) -> "Cast": + return Cast(*items) + + def list(self, items: List[Token]) -> "ListCommand": + return ListCommand() + + def command(self, items: List["DebugCommand"]) -> "DebugCommand": + return items[0] + + _debug_input_transformer = _IntoDebugCommandTransformer() + return _debug_input_transformer + + class DebugCommand: @staticmethod def parse(line: str) -> Union["DebugCommand", None]: - parts = line.strip("\n").split(" ") - if len(parts) == 0: + try: + tree = _get_debug_input_parser().parse(line) + return _get_debug_input_transformer().transform(tree) + except Exception as e: + print(f"Error parsing input: {e}") return None - command = parts[0] - match command: - case "attach": - return Attach._parse(parts) - case "list": - return ListCommand() - case "quit": - return Quit() - case "cast": - return Cast._parse(parts) - case "help": - return Help() - case "continue": - return Continue() - case _: - print( - f"Unknown command {command}. Expected: attach | list | quit | cast | continue | help" - ) - return None @dataclass class Attach(DebugCommand): + actor_name: str rank: int - @classmethod - def _parse(cls, parts: List[str]) -> "Attach": - if len(parts) != 2: - raise ValueError("Invalid attach command. Expected: attach ") - try: - rank = int(parts[1]) - except ValueError: - raise ValueError(f"Invalid rank {parts[1]}. Expected: int") - return cls(rank) - +@dataclass class ListCommand(DebugCommand): pass +@dataclass class Quit(DebugCommand): pass +@dataclass class Help(DebugCommand): pass +@dataclass class Continue(DebugCommand): pass @dataclass class Cast(DebugCommand): - ranks: List[int] | None + actor_name: str + ranks: RanksType command: str - @classmethod - def _parse(cls, parts: List[str]) -> "Cast": - if len(parts) < 3: - raise ValueError( - "Invalid cast command. Expected: cast { | *} " - ) - str_ranks = parts[1] - command = " ".join(parts[2:]) - if str_ranks == "*": - return cls(None, command) - else: - str_ranks = str_ranks.split(",") - if len(str_ranks) == 0: - raise ValueError( - "Invalid rank list for cast. Expected at least one rank." - ) - ranks = [] - for rank in str_ranks: - try: - ranks.append(int(rank)) - except ValueError: - raise ValueError(f"Invalid rank {rank}. Expected: int") - return cls(ranks, command) - class DebugClient(Actor): """ @@ -244,7 +437,7 @@ class DebugClient(Actor): """ def __init__(self) -> None: - self.sessions = {} # rank -> DebugSession + self.sessions = DebugSessions() @endpoint async def wait_pending_session(self): @@ -252,89 +445,96 @@ async def wait_pending_session(self): await asyncio.sleep(1) @endpoint - async def list(self) -> List[Tuple[int, Dict[str, int], str, ActorId, str, int]]: - table_data = [] - for _, session in self.sessions.items(): - info = session.get_info() - table_data.append( + async def list(self) -> List[DebugSessionInfo]: + session_info = sorted(self.sessions.info()) + print( + tabulate( ( - info.rank, - info.coords, - info.hostname, - info.actor_id, - info.function, - info.lineno, - ) + ( + info.actor_name, + info.rank, + info.coords, + info.hostname, + info.function, + info.lineno, + ) + for info in session_info + ), + headers=[ + "Actor Name", + "Rank", + "Coords", + "Hostname", + "Function", + "Line No.", + ], + tablefmt="grid", ) - table_data = sorted(table_data, key=lambda r: r[0]) - - headers = ["Rank", "Coords", "Hostname", "Actor ID", "Function", "Line No."] - print(tabulate(table_data, headers=headers, tablefmt="grid")) - - return table_data + ) + return session_info @endpoint async def enter(self) -> None: - # pyre-ignore - await getattr(self, "list")._method(self) # noqa + await asyncio.sleep(0.5) + logger.info("Remote breakpoint hit. Entering monarch debugger...") + print("\n\n************************ MONARCH DEBUGGER ************************") + print("Enter 'help' for a list of commands.") + print("Enter 'list' to show all active breakpoints.\n") while True: try: user_input = await _debugger_input("monarch_dbg> ") + if not user_input.strip(): + continue command = DebugCommand.parse(user_input) if isinstance(command, Help): print("monarch_dbg commands:") - print("\tattach - attach to a debug session") + print("\tattach - attach to a debug session") print("\tlist - list all debug sessions") print("\tquit - exit the debugger, leaving all sessions in place") print( - "\tcast { | *} - send a command to a comma-separated list of ranks, or all ranks" + "\tcast ranks(...) - send a command to a set of ranks on the specified actor mesh.\n" + "\t\tThe value inside ranks(...) can be a single rank (ranks(1)),\n" + "\t\ta list of ranks (ranks(1,4,6)), a range of ranks (ranks(start?:stop?:step?)),\n" + "\t\tor a dict of dimensions (ranks(dim1=1:5:2,dim2=3, dim4=(3,6)))." ) print( "\tcontinue - tell all ranks to continue execution, then exit the debugger" ) print("\thelp - print this help message") elif isinstance(command, Attach): - if command.rank not in self.sessions: - print(f"No debug session for rank {command.rank}") - else: - await self.sessions[command.rank].attach() + await self.sessions.get(command.actor_name, command.rank).attach() elif isinstance(command, ListCommand): - await getattr(self, "list")._method(self) # noqa + # pyre-ignore + await self.list._method(self) elif isinstance(command, Continue): - # Make sure all ranks have exited their debug sessions. - # If we sent "quit", it would raise BdbQuit, crashing - # the process, which probably isn't what we want. + # Clear all breakpoints and make sure all ranks have + # exited their debug sessions. If we sent "quit", it + # would raise BdbQuit, crashing the process, which + # probably isn't what we want. + await self._cast_input_and_wait("clear") while len(self.sessions) > 0: - tasks = [] - for rank in self.sessions: - tasks.append( - self.sessions[rank].attach("c", suppress_output=True) - ) - await asyncio.gather(*tasks) + await self._cast_input_and_wait("c") return elif isinstance(command, Quit): return elif isinstance(command, Cast): - if command.ranks is None: - ranks = self.sessions.keys() - else: - ranks = command.ranks - tasks = [] - for rank in ranks: - if rank in self.sessions: - tasks.append( - self.sessions[rank].attach( - command.command, - suppress_output=True, - ) - ) - else: - print(f"No debug session for rank {rank}") - await asyncio.gather(*tasks) + await self._cast_input_and_wait( + command.command, (command.actor_name, command.ranks) + ) except Exception as e: print(f"Error processing command: {e}") + async def _cast_input_and_wait( + self, + command: str, + selection: Optional[Tuple[str, Optional[RanksType]]] = None, + ) -> None: + tasks = [] + for session in self.sessions.iter(selection): + tasks.append(session.attach(command, suppress_output=True)) + await asyncio.gather(*tasks) + ########################################################################## # Debugger APIs # @@ -342,36 +542,83 @@ async def enter(self) -> None: # and communicate with them. @endpoint async def debugger_session_start( - self, rank: int, coords: Dict[str, int], hostname: str, actor_id: ActorId + self, rank: int, coords: Dict[str, int], hostname: str, actor_name: str ) -> None: # Create a session if it doesn't exist - if rank not in self.sessions: - self.sessions[rank] = DebugSession(rank, coords, hostname, actor_id) + if (actor_name, rank) not in self.sessions: + self.sessions.insert(DebugSession(rank, coords, hostname, actor_name)) @endpoint - async def debugger_session_end(self, rank: int) -> None: + async def debugger_session_end(self, actor_name: str, rank: int) -> None: """Detach from the current debug session.""" - session = self.sessions.pop(rank) - await session.detach() + await self.sessions.remove(actor_name, rank).detach() @endpoint - async def debugger_read(self, rank: int, size: int) -> DebuggerWrite | str: + async def debugger_read( + self, actor_name: str, rank: int, size: int + ) -> DebuggerWrite | str: """Read from the debug session for the given rank.""" - session = self.sessions[rank] - - return await session.debugger_read(size) + return await self.sessions.get(actor_name, rank).debugger_read(size) @endpoint - async def debugger_write(self, rank: int, write: DebuggerWrite) -> None: + async def debugger_write( + self, actor_name: str, rank: int, write: DebuggerWrite + ) -> None: """Write to the debug session for the given rank.""" - session = self.sessions[rank] - await session.debugger_write(write) + await self.sessions.get(actor_name, rank).debugger_write(write) + +class DebugManager(Actor): + @staticmethod + @functools.cache + def ref() -> "DebugManager": + ctx = MonarchContext.get() + return cast( + DebugManager, + ActorMeshRef( + DebugManager, + _ActorMeshRefImpl.from_actor_id( + ctx.mailbox, + ActorId.from_string( + f"{ctx.proc_id}.{_DEBUG_MANAGER_ACTOR_NAME}[0]" + ), + ), + ctx.mailbox, + ), + ) + + def __init__(self, debug_client: DebugClient) -> None: + self._debug_client = debug_client + + # pyre-ignore + @endpoint + def get_debug_client(self) -> DebugClient: + return self._debug_client + + +def remote_breakpointhook(): + frame = inspect.currentframe() + assert frame is not None + frame = frame.f_back + assert frame is not None + file = frame.f_code.co_filename + line = frame.f_lineno + module = frame.f_globals.get("__name__", "__main__") + if module == "__main__" and not os.path.exists(file): + raise NotImplementedError( + f"Remote debugging not supported for breakpoint at {file}:{line} because " + f"it is defined inside __main__, and the file does not exist on the host. " + "In this case, cloudpickle serialization does not interact nicely with pdb. " + "To debug your code, move it out of __main__ and into a module that " + "exists on both your client and worker processes." + ) -async def init_debugging( - actor_mesh: ActorMeshRef, -) -> ActorMeshRef[DebugClient]: - debugger_proc_mesh = await local_proc_mesh(gpus=1, hosts=1) - debug_client_mesh = await debugger_proc_mesh.spawn("debug_client", DebugClient) - await actor_mesh._set_debug_client.call(debug_client_mesh) - return debug_client_mesh + ctx = MonarchContext.get() + pdb_wrapper = PdbWrapper( + ctx.point.rank, + ctx.point.shape.coordinates(ctx.point.rank), + ctx.mailbox.actor_id, + DebugManager.ref().get_debug_client.call_one().get(), + ) + DebugContext.set(DebugContext(pdb_wrapper)) + pdb_wrapper.set_trace(frame) diff --git a/python/monarch/_src/actor/pdb_wrapper.py b/python/monarch/_src/actor/pdb_wrapper.py index 87031dd4..a0d2015c 100644 --- a/python/monarch/_src/actor/pdb_wrapper.py +++ b/python/monarch/_src/actor/pdb_wrapper.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import bdb import inspect import io @@ -45,35 +46,43 @@ def __init__( super().__init__(stdout=WriteWrapper(self), stdin=ReadWrapper.create(self)) self._first = True - def setup(self, *args, **kwargs): - r = super().setup(*args, **kwargs) - if self._first: - self._first = False - # when we enter the debugger, we want to present the user's stack frame - # not the nested one inside session.run. This means that the local - # variables are what gets printed, etc. To do this - # we first execute up 2 to get to that frame. - self.do_up(2) - return r - - def set_continue(self) -> None: - r = super().set_continue() - if not self.breaks: - # no more breakpoints so this debugger will not - # be used again, and we detach from the controller io. - self.client_ref.debugger_session_end.call_one(self.rank).get() - # break cycle with itself before we exit - self.stdin = sys.stdin - self.stdout = sys.stdout - return r - - def set_trace(self): + def set_trace(self, frame=None): self.client_ref.debugger_session_start.call_one( - self.rank, self.coords, socket.getfqdn(socket.gethostname()), self.actor_id + self.rank, + self.coords, + socket.getfqdn(socket.gethostname()), + self.actor_id.actor_name, ).get() if self.header: self.message(self.header) - super().set_trace() + super().set_trace(frame) + + def do_clear(self, arg): + if not arg: + # Sending `clear` without any argument specified will + # request confirmation from the user using the `input` function, + # which bypasses our ReadWrapper and causes a hang on the client. + # To avoid this, we just clear all breakpoints instead without + # confirmation. + super().clear_all_breaks() + else: + super().do_clear(arg) + + def end_debug_session(self): + self.client_ref.debugger_session_end.call_one( + self.actor_id.actor_name, self.rank + ).get() + # Once the debug client actor is notified of the session being over, + # we need to prevent any additional requests being sent for the session + # by redirecting stdin and stdout. + self.stdin = sys.stdin + self.stdout = sys.stdout + + def post_mortem(self, exc_tb): + self._first = False + # See builtin implementation of pdb.post_mortem() for reference. + self.reset() + self.interaction(None, exc_tb) class ReadWrapper(io.RawIOBase): @@ -82,7 +91,7 @@ def __init__(self, session: "PdbWrapper"): def readinto(self, b): response = self.session.client_ref.debugger_read.call_one( - self.session.rank, len(b) + self.session.actor_id.actor_name, self.session.rank, len(b) ).get() if response == "detach": # this gets injected by the worker event loop to @@ -116,6 +125,7 @@ def write(self, s: str): # pyre-ignore lineno = self.session.curframe.f_lineno self.session.client_ref.debugger_write.call_one( + self.session.actor_id.actor_name, self.session.rank, DebuggerWrite( s.encode(), @@ -126,10 +136,3 @@ def write(self, s: str): def flush(self): pass - - -def remote_breakpointhook( - rank: int, coords: Dict[str, int], actor_id: ActorId, client_ref: "DebugClient" -): - ds = PdbWrapper(rank, coords, actor_id, client_ref) - ds.set_trace() diff --git a/python/monarch/_src/actor/proc_mesh.py b/python/monarch/_src/actor/proc_mesh.py index 483ddf00..0df67df9 100644 --- a/python/monarch/_src/actor/proc_mesh.py +++ b/python/monarch/_src/actor/proc_mesh.py @@ -38,6 +38,11 @@ from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator from monarch._src.actor.code_sync import RsyncMeshClient, WorkspaceLocation from monarch._src.actor.code_sync.auto_reload import AutoReloadActor +from monarch._src.actor.debugger import ( + _DEBUG_MANAGER_ACTOR_NAME, + DebugClient, + DebugManager, +) from monarch._src.actor.device_utils import _local_device_count from monarch._src.actor.future import Future @@ -83,11 +88,13 @@ def __init__( hy_proc_mesh: HyProcMesh, _mock_shape: Optional[Shape] = None, _device_mesh: Optional["DeviceMesh"] = None, + _is_initializing_debugger: bool = False, ) -> None: self._proc_mesh = hy_proc_mesh self._mock_shape: Optional[Shape] = _mock_shape # type: ignore[21] self._rdma_manager: Optional["RDMAManager"] = None + self._debug_manager: Optional[DebugManager] = None self._mailbox: Mailbox = self._proc_mesh.client self._rsync_mesh_client: Optional[RsyncMeshClient] = None self._auto_reload_actor: Optional[AutoReloadActor] = None @@ -96,6 +103,10 @@ def __init__( if _mock_shape is None and HAS_TENSOR_ENGINE: # type: ignore[21] self._rdma_manager = self._spawn_blocking("rdma_manager", RDMAManager) + if not _is_initializing_debugger: + self._debug_manager = self._spawn_blocking( + _DEBUG_MANAGER_ACTOR_NAME, DebugManager, debug_client() + ) @property def _shape(self) -> Shape: @@ -296,13 +307,21 @@ async def local_proc_mesh_nonblocking( return await ProcMesh.from_alloc(alloc) -def local_proc_mesh_blocking(*, gpus: Optional[int] = None, hosts: int = 1) -> ProcMesh: +def local_proc_mesh_blocking( + *, + gpus: Optional[int] = None, + hosts: int = 1, + _is_initializing_debugger: bool = False, +) -> ProcMesh: if gpus is None: gpus = _local_device_count() spec = AllocSpec(AllocConstraints(), gpus=gpus, hosts=hosts) allocator = LocalAllocator() alloc = allocator.allocate(spec).get() - return ProcMesh.from_alloc(alloc).get() + return ProcMesh( + HyProcMesh.allocate_blocking(alloc), + _is_initializing_debugger=_is_initializing_debugger, + ) def local_proc_mesh(*, gpus: Optional[int] = None, hosts: int = 1) -> Future[ProcMesh]: @@ -371,3 +390,33 @@ def proc_mesh( lambda: proc_mesh_nonblocking(gpus=gpus, hosts=hosts, env=env), lambda: proc_mesh_blocking(gpus=gpus, hosts=hosts, env=env), ) + + +_debug_proc_mesh: Optional["ProcMesh"] = None + + +# Lazy init of the debug proc mesh so that importing monarch.proc_mesh +# doesn't trigger the debug client to spawn, which could cause confusing +# logs. This is defined in proc_mesh.py instead of debugger.py for +# circular import reasons. +def _get_debug_proc_mesh() -> "ProcMesh": + global _debug_proc_mesh + if _debug_proc_mesh is None: + _debug_proc_mesh = local_proc_mesh_blocking( + gpus=1, hosts=1, _is_initializing_debugger=True + ) + return _debug_proc_mesh + + +_debug_client_mesh: Optional[ActorMeshRef[DebugClient]] = None + + +# Lazy init for the same reason as above. This is defined in proc_mesh.py +# instead of debugger.py for circular import reasons. +def debug_client() -> ActorMeshRef[DebugClient]: + global _debug_client_mesh + if _debug_client_mesh is None: + _debug_client_mesh = ( + _get_debug_proc_mesh().spawn("debug_client", DebugClient).get() + ) + return _debug_client_mesh diff --git a/python/monarch/actor/__init__.py b/python/monarch/actor/__init__.py index a2720198..0a108a2b 100644 --- a/python/monarch/actor/__init__.py +++ b/python/monarch/actor/__init__.py @@ -23,7 +23,13 @@ ValueMesh, ) from monarch._src.actor.future import Future -from monarch._src.actor.proc_mesh import local_proc_mesh, proc_mesh, ProcMesh +from monarch._src.actor.proc_mesh import ( + debug_client, + local_proc_mesh, + proc_mesh, + ProcMesh, +) + __all__ = [ "Accumulator", @@ -42,4 +48,5 @@ "ProcMesh", "send", "ValueMesh", + "debug_client", ] diff --git a/python/tests/test_debugger.py b/python/tests/test_debugger.py new file mode 100644 index 00000000..7cd12e47 --- /dev/null +++ b/python/tests/test_debugger.py @@ -0,0 +1,643 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe +import asyncio +import re +import sys +from typing import cast, List +from unittest.mock import AsyncMock, patch + +import monarch +import monarch.actor as actor + +import pytest + +import torch + +from monarch._src.actor.actor_mesh import Actor, ActorError, endpoint, MonarchContext +from monarch._src.actor.debugger import ( + Attach, + Cast, + Continue, + DebugCommand, + DebugSession, + DebugSessionInfo, + DebugSessions, + Help, + ListCommand, + Quit, +) + +from monarch._src.actor.proc_mesh import proc_mesh + +needs_cuda = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) + + +def _bad_rank(): + raise ValueError("bad rank") + + +def _debugee_actor_internal(rank): + if rank == 0: + breakpoint() # noqa + rank += 1 + rank += 1 + return rank + elif rank == 1: + breakpoint() # noqa + rank += 2 + rank += 2 + return rank + elif rank == 2: + breakpoint() # noqa + rank += 3 + rank += 3 + _bad_rank() + elif rank == 3: + breakpoint() # noqa + rank += 4 + rank += 4 + return rank + + +class DebugeeActor(Actor): + @endpoint + async def to_debug(self): + rank = MonarchContext.get().point.rank + return _debugee_actor_internal(rank) + + +async def _wait_for_breakpoints(debug_client, n_breakpoints) -> List[DebugSessionInfo]: + breakpoints: List[DebugSessionInfo] = [] + for i in range(10): + breakpoints = await debug_client.list.call_one() + if len(breakpoints) == n_breakpoints: + break + await asyncio.sleep(1) + if i == 9: + raise RuntimeError("timed out waiting for breakpoints") + return breakpoints + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +async def test_debug() -> None: + input_mock = AsyncMock() + input_mock.side_effect = [ + "attach debugee 1", + "n", + "n", + "n", + "n", + "detach", + "attach debugee 1", + "detach", + "quit", + "cast debugee ranks(0,3) n", + "cast debugee ranks(0,3) n", + # Attaching to 0 and 3 ensures that when we call "list" + # the next time, their function/lineno info will be + # up-to-date. + "attach debugee 0", + "detach", + "attach debugee 3", + "detach", + "quit", + "attach debugee 2", + "c", + "detach", + "quit", + "attach debugee 2", + "bt", + "c", + "quit", + "continue", + ] + + outputs = [] + + def _patch_output(msg): + nonlocal outputs + outputs.append(msg) + + with patch( + "monarch._src.actor.debugger._debugger_input", side_effect=input_mock + ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): + proc = await proc_mesh(hosts=2, gpus=2) + debugee = await proc.spawn("debugee", DebugeeActor) + debug_client = actor.debug_client() + + fut = debugee.to_debug.call() + await debug_client.wait_pending_session.call_one() + breakpoints = await _wait_for_breakpoints(debug_client, 4) + + initial_linenos = {} + for i in range(len(breakpoints)): + info = breakpoints[i] + initial_linenos[info.rank] = info.lineno + assert info.rank == i + assert info.coords == {"hosts": info.rank // 2, "gpus": info.rank % 2} + assert info.function == "test_debugger._debugee_actor_internal" + assert info.lineno == cast(int, breakpoints[0].lineno) + 5 * info.rank + + await debug_client.enter.call_one() + + # Check that when detaching and re-attaching to a session, the last portion of the output is repeated + expected_last_output = [ + r"--Return--", + r"\n", + r"> (/.*/)+test_debugger.py\(\d+\)to_debug\(\)->5\n-> return _debugee_actor_internal\(rank\)", + r"\n", + r"\(Pdb\) ", + ] + output_len = len(expected_last_output) + assert outputs[-2 * output_len : -output_len] == outputs[-output_len:] + for real_output, expected_output in zip( + outputs[-output_len:], expected_last_output + ): + assert re.match(expected_output, real_output) is not None + + breakpoints = await debug_client.list.call_one() + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i].function == "test_debugger.to_debug" + else: + assert ( + breakpoints[i].function == "test_debugger._debugee_actor_internal" + ) + assert breakpoints[i].lineno == initial_linenos[i] + + await debug_client.enter.call_one() + + breakpoints = await debug_client.list.call_one() + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i].function == "test_debugger.to_debug" + elif i in (0, 3): + assert ( + breakpoints[i].function == "test_debugger._debugee_actor_internal" + ) + assert breakpoints[i].lineno == initial_linenos[i] + 2 + else: + assert ( + breakpoints[i].function == "test_debugger._debugee_actor_internal" + ) + assert breakpoints[i].lineno == initial_linenos[i] + + await debug_client.enter.call_one() + + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 4 + # Expect post-mortem debugging for rank 2 + assert breakpoints[2].function == "test_debugger._bad_rank" + + await debug_client.enter.call_one() + + expected_last_output = [ + r"\s*(/.*/)+test_debugger.py\(\d+\)_debugee_actor_internal\(\)\n-> _bad_rank\(\)", + r"\n", + r'> (/.*/)+test_debugger.py\(\d+\)_bad_rank\(\)\n-> raise ValueError\("bad rank"\)', + r"\n", + r"\(Pdb\) ", + ] + + for output, expected_output in zip( + outputs[-len(expected_last_output) :], expected_last_output + ): + assert re.match(expected_output, output) is not None + + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 3 + for i, rank in enumerate((0, 1, 3)): + assert breakpoints[i].rank == rank + + await debug_client.enter.call_one() + breakpoints = await debug_client.list.call_one() + assert len(breakpoints) == 0 + + with pytest.raises( + monarch._src.actor.actor_mesh.ActorError, match="ValueError: bad rank" + ): + await fut + + +@pytest.mark.skipif( + torch.cuda.device_count() < 2, + reason="Not enough GPUs, this test requires at least 2 GPUs", +) +async def test_debug_multi_actor() -> None: + input_mock = AsyncMock() + input_mock.side_effect = [ + "attach debugee_2 2", + "n", + "detach", + "attach debugee_1 1", + "n", + "detach", + "quit", + "cast debugee_1 ranks(:) c", + "cast debugee_2 ranks(:) c", + "attach debugee_2 2", + "c", + "quit", + "continue", + ] + + with patch("monarch._src.actor.debugger._debugger_input", side_effect=input_mock): + proc = await proc_mesh(hosts=2, gpus=2) + debugee_1 = await proc.spawn("debugee_1", DebugeeActor) + debugee_2 = await proc.spawn("debugee_2", DebugeeActor) + debug_client = actor.debug_client() + + fut_1 = debugee_1.to_debug.call() + fut_2 = debugee_2.to_debug.call() + await debug_client.wait_pending_session.call_one() + + breakpoints = await _wait_for_breakpoints(debug_client, 8) + + initial_linenos = {} + for i in range(len(breakpoints)): + info = breakpoints[i] + initial_linenos[info.rank] = info.lineno + assert info.rank == i % 4 + assert info.actor_name == "debugee_1" if i < 4 else "debugee_2" + assert info.coords == {"hosts": info.rank // 2, "gpus": info.rank % 2} + assert info.function == "test_debugger._debugee_actor_internal" + assert info.lineno == cast(int, breakpoints[0].lineno) + 5 * info.rank + + await debug_client.enter.call_one() + + breakpoints = await _wait_for_breakpoints(debug_client, 8) + for i in range(len(breakpoints)): + if i == 1: + assert breakpoints[i].actor_name == "debugee_1" + assert breakpoints[i].rank == 1 + assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] + 1 + elif i == 6: + assert breakpoints[i].actor_name == "debugee_2" + assert breakpoints[i].rank == 2 + assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] + 1 + else: + assert ( + breakpoints[i].actor_name == "debugee_1" if i < 4 else "debugee_2" + ) + assert breakpoints[i].rank == i % 4 + assert breakpoints[i].lineno == initial_linenos[breakpoints[i].rank] + + await debug_client.enter.call_one() + + breakpoints = await _wait_for_breakpoints(debug_client, 1) + with pytest.raises(ActorError, match="ValueError: bad rank"): + await fut_2 + assert breakpoints[0].actor_name == "debugee_1" + assert breakpoints[0].rank == 2 + assert breakpoints[0].function == "test_debugger._bad_rank" + + await debug_client.enter.call_one() + + breakpoints = await _wait_for_breakpoints(debug_client, 0) + with pytest.raises(ActorError, match="ValueError: bad rank"): + await fut_1 + + +async def test_debug_sessions_insert_get_remove() -> None: + mock_sessions = [] + for actor_name in ("actor_a", "actor_b"): + for rank in range(2): + mock_session = DebugSession(rank, {}, "", actor_name) + mock_sessions.append(mock_session) + + debug_sessions = DebugSessions() + + with pytest.raises(ValueError, match="No debug sessions for actor actor_a"): + debug_sessions.get("actor_a", 0) + debug_sessions.insert(mock_sessions[0]) + assert debug_sessions.get("actor_a", 0) is mock_sessions[0] + assert ("actor_a", 0) in debug_sessions + with pytest.raises( + ValueError, match="Debug session for rank 0 already exists for actor actor_a" + ): + debug_sessions.insert(mock_sessions[0]) + + with pytest.raises( + ValueError, match="No debug session for rank 1 for actor actor_a" + ): + debug_sessions.get("actor_a", 1) + debug_sessions.insert(mock_sessions[1]) + assert debug_sessions.get("actor_a", 1) is mock_sessions[1] + assert ("actor_a", 1) in debug_sessions + with pytest.raises( + ValueError, match="Debug session for rank 1 already exists for actor actor_a" + ): + debug_sessions.insert(mock_sessions[1]) + + with pytest.raises(ValueError, match="No debug sessions for actor actor_b"): + debug_sessions.get("actor_b", 0) + debug_sessions.insert(mock_sessions[2]) + assert debug_sessions.get("actor_b", 0) is mock_sessions[2] + assert ("actor_b", 0) in debug_sessions + with pytest.raises( + ValueError, match="Debug session for rank 0 already exists for actor actor_b" + ): + debug_sessions.insert(mock_sessions[2]) + + with pytest.raises( + ValueError, match="No debug session for rank 1 for actor actor_b" + ): + debug_sessions.get("actor_b", 1) + debug_sessions.insert(mock_sessions[3]) + assert debug_sessions.get("actor_b", 1) is mock_sessions[3] + assert ("actor_b", 1) in debug_sessions + with pytest.raises( + ValueError, match="Debug session for rank 1 already exists for actor actor_b" + ): + debug_sessions.insert(mock_sessions[3]) + + assert len(debug_sessions) == 4 + + assert debug_sessions.remove("actor_a", 0) is mock_sessions[0] + assert len(debug_sessions) == 3 + assert ("actor_a", 0) not in debug_sessions + with pytest.raises( + ValueError, match="No debug session for rank 0 for actor actor_a" + ): + debug_sessions.remove("actor_a", 0) + + assert debug_sessions.remove("actor_a", 1) is mock_sessions[1] + assert len(debug_sessions) == 2 + assert ("actor_a", 1) not in debug_sessions + with pytest.raises(ValueError, match="No debug sessions for actor actor_a"): + debug_sessions.remove("actor_a", 1) + + assert debug_sessions.remove("actor_b", 0) is mock_sessions[2] + assert len(debug_sessions) == 1 + assert ("actor_b", 0) not in debug_sessions + with pytest.raises( + ValueError, match="No debug session for rank 0 for actor actor_b" + ): + debug_sessions.remove("actor_b", 0) + + assert debug_sessions.remove("actor_b", 1) is mock_sessions[3] + assert len(debug_sessions) == 0 + assert ("actor_b", 1) not in debug_sessions + with pytest.raises(ValueError, match="No debug sessions for actor actor_b"): + debug_sessions.remove("actor_b", 1) + + +async def test_debug_sessions_iter() -> None: + debug_sessions = DebugSessions() + mock_sessions = [] + + for actor_name in ("actor_a", "actor_b"): + for host in range(3): + for gpu in range(8): + rank = host * 8 + gpu + mock_session = DebugSession( + rank, {"hosts": host, "gpus": gpu}, "", actor_name + ) + mock_sessions.append(mock_session) + debug_sessions.insert(mock_session) + + # Single rank + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = list(debug_sessions.iter((actor_name, 2))) + assert len(sessions) == 1 + assert sessions[0] is mock_sessions[i * 24 + 2] + + # List of ranks + ranks = [1, 3, 5] + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, ranks)), key=lambda s: s.get_info() + ) + assert len(sessions) == 3 + for j in range(3): + assert sessions[j] is mock_sessions[i * 24 + ranks[j]] + + # Range of ranks + ranks = range(2, 24, 3) + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, ranks)), key=lambda s: s.get_info() + ) + ranks = list(ranks) + assert len(sessions) == len(ranks) + for j in range(len(ranks)): + assert sessions[j] is mock_sessions[i * 24 + ranks[j]] + + # All ranks + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, None)), key=lambda s: s.get_info() + ) + assert len(sessions) == 24 + for j in range(24): + assert sessions[j] is mock_sessions[i * 24 + j] + + # All ranks, all actors + sessions = sorted(debug_sessions.iter(None), key=lambda s: s.get_info()) + assert len(sessions) == 48 + for i in range(48): + assert sessions[i] is mock_sessions[i] + + # Dimension filtering with a single value + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, {"hosts": 1})), key=lambda s: s.get_info() + ) + assert len(sessions) == 8 + for j in range(8): + assert sessions[j] is mock_sessions[i * 24 + 8 + j] + + # Dimension filtering with a list + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, {"hosts": [0, 2]})), + key=lambda s: s.get_info(), + ) + assert len(sessions) == 16 + j = 0 + for host in (0, 2): + for gpu in range(8): + assert sessions[j] is mock_sessions[i * 24 + host * 8 + gpu] + j += 1 + + # Dimension filtering with a range + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter((actor_name, {"gpus": range(5, 8)})), + key=lambda s: s.get_info(), + ) + assert len(sessions) == 9 + j = 0 + for host in range(3): + for gpu in range(5, 8): + assert sessions[j] is mock_sessions[i * 24 + host * 8 + gpu] + j += 1 + + # Multiple dimension filters + for i, actor_name in enumerate(("actor_a", "actor_b")): + sessions = sorted( + debug_sessions.iter( + (actor_name, {"hosts": [1, 3], "gpus": range(0, sys.maxsize, 3)}) + ), + key=lambda s: s.get_info(), + ) + assert len(sessions) == 3 + j = 0 + for gpu in range(0, 8, 3): + assert sessions[j] is mock_sessions[i * 24 + 8 + gpu] + j += 1 + + # Non-existent dimension + for actor_name in ("actor_a", "actor_b"): + sessions = sorted( + debug_sessions.iter((actor_name, {"hosts": 0, "gpus": 0, "foo": 0})), + key=lambda s: s.get_info(), + ) + assert len(sessions) == 0 + + +@pytest.mark.parametrize( + ["user_input", "expected_output"], + [ + ("attach debugee 1", Attach("debugee", 1)), + ("a my_awesome_actor 100", Attach("my_awesome_actor", 100)), + ("list", ListCommand()), + ("l", ListCommand()), + ("help", Help()), + ("h", Help()), + ("quit", Quit()), + ("q", Quit()), + ("continue", Continue()), + ("c", Continue()), + ( + "cast debugee ranks(123) b 25", + Cast(actor_name="debugee", ranks=123, command="b 25"), + ), + ( + "cast my_awesome_actor ranks(12,34,56) b 25", + Cast(actor_name="my_awesome_actor", ranks=[12, 34, 56], command="b 25"), + ), + ( + "cast debugee ranks(:) b 25", + Cast(actor_name="debugee", ranks=range(0, sys.maxsize), command="b 25"), + ), + ( + "cast debugee ranks(:123) b 25", + Cast(actor_name="debugee", ranks=range(0, 123), command="b 25"), + ), + ( + "cast debugee ranks(123:) b 25", + Cast(actor_name="debugee", ranks=range(123, sys.maxsize), command="b 25"), + ), + ( + "cast debugee ranks(123:456) b 25", + Cast(actor_name="debugee", ranks=range(123, 456), command="b 25"), + ), + ( + "cast debugee ranks(::) b 25", + Cast(actor_name="debugee", ranks=range(0, sys.maxsize), command="b 25"), + ), + ( + "cast debugee ranks(::123) b 25", + Cast( + actor_name="debugee", ranks=range(0, sys.maxsize, 123), command="b 25" + ), + ), + ( + "cast debugee ranks(123::) b 25", + Cast(actor_name="debugee", ranks=range(123, sys.maxsize), command="b 25"), + ), + ( + "cast debugee ranks(:123:) b 25", + Cast(actor_name="debugee", ranks=range(0, 123), command="b 25"), + ), + ( + "cast debugee ranks(:456:123) b 25", + Cast(actor_name="debugee", ranks=range(0, 456, 123), command="b 25"), + ), + ( + "cast debugee ranks(456::123) b 25", + Cast( + actor_name="debugee", ranks=range(456, sys.maxsize, 123), command="b 25" + ), + ), + ( + "cast debugee ranks(123:456:) b 25", + Cast(actor_name="debugee", ranks=range(123, 456), command="b 25"), + ), + ( + "cast debugee ranks(456:789:123) b 25", + Cast(actor_name="debugee", ranks=range(456, 789, 123), command="b 25"), + ), + ( + "cast debugee ranks(dim1=123) up 2", + Cast(actor_name="debugee", ranks={"dim1": 123}, command="up 2"), + ), + ( + "cast debugee ranks(dim1=123, dim2=(12,34,56), dim3=15::2) up 2", + Cast( + actor_name="debugee", + ranks={ + "dim1": 123, + "dim2": [12, 34, 56], + "dim3": range(15, sys.maxsize, 2), + }, + command="up 2", + ), + ), + ], +) +async def test_debug_command_parser_valid_inputs(user_input, expected_output): + assert DebugCommand.parse(user_input) == expected_output + + +@pytest.mark.parametrize( + "invalid_input", + [ + "", + "a", + "attach", + "a actor", + "attach actor", + "attacha actor 1" "attch actor 1", + "attach actor 1abc", + "attach actor 1 a", + "cast ranks(123) b 25", + "cast ranks(123) b 25", + "castactor ranks(123) b 25", + "cast actor rnks(123) b 25", + "cast actor ranks() b 25", + "cast actor ranks(1ab) b 25", + "cast actor ranks(1,a,3) b 25", + "cast actor ranks(a:2:4) b 25", + "cast actor ranks(1,2,3", + "cast actor ranks(1,2,3)) b 25", + "cast actor ranks(1,) b 25", + "cast actor ranks(1,2,) b 25", + "cast actor ranks(,1,2) b 25", + "cast actor ranks(1,,2) b 25", + "cast actor ranks(:::) b 25", + "cast actor ranks(:123::) b 25", + "cast actor ranks(1:2:3,4) b 25", + "cast actor ranks(dim1=) b 25", + "cast actor ranks(dim1=123, dim2=) b 25", + "cast actor ranks(dim1=123, dim2=(12,34,56) b 25", + "cast actor ranks(dim1=123, dim2=(,12,34,56) b 25", + "cast actor ranks(dim1=123, dim2=(12,,34,56) b 25", + "cast actor ranks(dim1=123, dim2=(12,34,56), dim3=15::2 b 25", + "cast actor ranks(dim1=123,) b 25", + ], +) +async def test_debug_command_parser_invalid_inputs(invalid_input): + assert DebugCommand.parse(invalid_input) is None diff --git a/python/tests/test_python_actors.py b/python/tests/test_python_actors.py index 7092844e..821872e9 100644 --- a/python/tests/test_python_actors.py +++ b/python/tests/test_python_actors.py @@ -4,30 +4,26 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe import asyncio import operator -import re import threading import time from types import ModuleType -from unittest.mock import AsyncMock, patch import pytest import torch -from monarch._src.actor.debugger import init_debugging -from monarch._src.actor.proc_mesh import local_proc_mesh from monarch.actor import ( Accumulator, Actor, - ActorError, current_actor_name, current_rank, current_size, endpoint, Future, - MonarchContext, + local_proc_mesh, proc_mesh, ) from monarch.rdma import RDMABuffer @@ -408,146 +404,6 @@ def test_proc_mesh_liveness() -> None: counter.value.call().get() -def _debugee_actor_internal(rank): - if rank == 0: - breakpoint() # noqa - rank += 1 - return rank - elif rank == 1: - breakpoint() # noqa - rank += 2 - return rank - elif rank == 2: - breakpoint() # noqa - rank += 3 - raise ValueError("bad rank") - elif rank == 3: - breakpoint() # noqa - rank += 4 - return rank - - -class DebugeeActor(Actor): - @endpoint - async def to_debug(self): - rank = MonarchContext.get().point.rank - return _debugee_actor_internal(rank) - - -async def test_debug() -> None: - input_mock = AsyncMock() - input_mock.side_effect = [ - "attach 1", - "n", - "n", - "n", - "n", - "detach", - "attach 1", - "detach", - "quit", - "cast 0,3 n", - "cast 0,3 n", - # Attaching to 0 and 3 ensures that when we call "list" - # the next time, their function/lineno info will be - # up-to-date. - "attach 0", - "detach", - "attach 3", - "detach", - "quit", - "attach 2", - "c", - "quit", - "continue", - ] - - outputs = [] - - def _patch_output(msg): - nonlocal outputs - outputs.append(msg) - - with patch( - "monarch._src.actor.debugger._debugger_input", side_effect=input_mock - ), patch("monarch._src.actor.debugger._debugger_output", new=_patch_output): - proc = await proc_mesh(hosts=2, gpus=2) - debugee = await proc.spawn("debugee", DebugeeActor) - debug_client = await init_debugging(debugee) - - fut = debugee.to_debug.call() - await debug_client.wait_pending_session.call_one() - breakpoints = [] - for i in range(10): - breakpoints = await debug_client.list.call_one() - if len(breakpoints) == 4: - break - await asyncio.sleep(1) - if i == 9: - raise RuntimeError("timed out waiting for breakpoints") - - initial_linenos = {} - for i in range(len(breakpoints)): - rank, coords, _, _, function, lineno = breakpoints[i] - initial_linenos[rank] = lineno - assert rank == i - assert coords == {"hosts": rank // 2, "gpus": rank % 2} - assert function == "test_python_actors._debugee_actor_internal" - assert lineno == breakpoints[0][5] + 4 * rank - - await debug_client.enter.call_one() - - # Check that when detaching and re-attaching to a session, the last portion of the output is repeated - expected_last_output = [ - r"--Return--", - r"\n", - r"> (/.*/)+test_python_actors.py\(\d+\)to_debug\(\)->3\n-> return _debugee_actor_internal\(rank\)", - r"\n", - r"\(Pdb\) ", - ] - output_len = len(expected_last_output) - assert outputs[-2 * output_len : -output_len] == outputs[-output_len:] - for real_output, expected_output in zip( - outputs[-output_len:], expected_last_output - ): - assert re.match(expected_output, real_output) is not None - - breakpoints = await debug_client.list.call_one() - for i in range(len(breakpoints)): - if i == 1: - assert breakpoints[i][4] == "test_python_actors.to_debug" - else: - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] - - await debug_client.enter.call_one() - - breakpoints = await debug_client.list.call_one() - for i in range(len(breakpoints)): - if i == 1: - assert breakpoints[i][4] == "test_python_actors.to_debug" - elif i in (0, 3): - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] + 2 - else: - assert breakpoints[i][4] == "test_python_actors._debugee_actor_internal" - assert breakpoints[i][5] == initial_linenos[i] - - await debug_client.enter.call_one() - - breakpoints = await debug_client.list.call_one() - assert len(breakpoints) == 3 - for i, rank in enumerate((0, 1, 3)): - assert breakpoints[i][0] == rank - - await debug_client.enter.call_one() - breakpoints = await debug_client.list.call_one() - assert len(breakpoints) == 0 - - with pytest.raises(ActorError, match="ValueError: bad rank"): - await fut - - class TLSActor(Actor): """An actor that manages thread-local state.""" diff --git a/requirements.txt b/requirements.txt index 4235044e..41a2a230 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,3 +5,4 @@ numpy pyre-extensions cloudpickle torchx-nightly +lark