Skip to content

Commit a66ee02

Browse files
Sam Lurye authored and facebook-github-bot committed
Big debugging update
Summary: This diff contains several updates to the monarch actor mesh debugging experience.

- Improved debug input parsing, with the new `cast` command supporting more sophisticated rank selection grammar:
  - `cast ranks(3) pdb_command`: send `pdb_command` to rank 3.
  - `cast ranks(1,3,5) pdb_command`: send `pdb_command` to ranks 1, 3 and 5.
  - `cast ranks(1:10:2) pdb_command`: send `pdb_command` to the ranks in `range(start=1, stop=10, step=2)`.
  - `cast ranks(pp=2, dp=(1,3), tp=2:8) pdb_command`: send `pdb_command` to ranks with `pp` dim 2, `dp` dim 1 or 3, and `tp` dim in `range(2,8)`.
- The debug client is now automatically registered with an actor mesh when that actor mesh is spawned. This means calling `init_debugging(actor_mesh)` is no longer necessary.
- Debugging now works with MAST jobs, by enforcing that breakpoints aren't set in `__main__`, and that the file containing the breakpoint exists on the remote host.
  - The first requirement is due to how `cloudpickle` works -- if an actor endpoint is defined inside `__main__`, `cloudpickle` will serialize it by value instead of by reference. When the code then runs on the remote host, it thinks the location of the code is the user's local `__main__` file, which confuses pdb, because the file doesn't exist at the same path (or may not exist at all) on the remote host.
  - The second requirement is due to important parts of `pdb`'s implementation relying on the ability to search for the file being debugged on the remote host's file system.
- A debugging session for a specific rank is now forcefully exited once the endpoint finishes execution. This contains the debugging experience within user-authored code. It is also necessary for preventing hangs, because if pdb is allowed to continue indefinitely, then control flow will eventually bubble back up to the main asyncio event loop on the worker, at which point everything breaks.
- Hitting a breakpoint now automatically enables post-mortem debugging, so any rank that encounters an exception after hitting a breakpoint will automatically stop at the exception. Attaching the debugger to that rank should then provide an experience like `pdb.post_mortem()`.

## Next steps/gaps I'm aware of (reviewers please read):

- Indexing debug sessions by rank isn't sustainable, because two actor meshes may simultaneously hit breakpoints on the same rank and cause a collision inside the debug client.
- Entering the debug client should happen automatically, rather than requiring the user to do `await debug_client().enter.call_one()`.
- Casting pdb commands should ideally leverage `MeshTrait` rather than reimplementing the selection logic.
- If a mesh was reshaped/renamed so that its dimension names aren't `hosts` and `gpus` anymore, the debugger should reflect the new shape/names.
- The user should be able to enable post-mortem debugging without having to hit a separate breakpoint first.

Differential Revision: D77568423
1 parent a007830 commit a66ee02

File tree

7 files changed

+767
-292
lines changed

7 files changed

+767
-292
lines changed

python/monarch/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@
5757
from monarch.future import ActorFuture
5858
from monarch.gradient_generator import grad_function, grad_generator
5959
from monarch.notebook import mast_mesh, reserve_torchx as mast_reserve
60+
from monarch.proc_mesh import debug_client
6061
from monarch.python_local_mesh import python_local_mesh
6162
from monarch.rust_backend_mesh import (
6263
rust_backend_mesh,
@@ -116,6 +117,7 @@
116117
"LocalAllocator": ("monarch.allocator", "LocalAllocator"),
117118
"ActorFuture": ("monarch.future", "ActorFuture"),
118119
"builtins": ("monarch.builtins", "builtins"),
120+
"debug_client": ("monarch.proc_mesh", "debug_client"),
119121
}
120122

121123

@@ -185,5 +187,6 @@ def __getattr__(name):
185187
"LocalAllocator",
186188
"ActorFuture",
187189
"builtins",
190+
"debug_client",
188191
]
189192
assert sorted(__all__) == sorted(_public_api)

python/monarch/actor_mesh.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import io
1414
import itertools
1515
import logging
16+
import os
1617
import random
1718
import sys
1819
import traceback
@@ -61,7 +62,7 @@
6162

6263
from monarch.common.pickle_flatten import flatten, unflatten
6364
from monarch.common.shape import MeshTrait, NDSlice
64-
from monarch.pdb_wrapper import remote_breakpointhook
65+
from monarch.pdb_wrapper import PdbWrapper
6566

6667
if TYPE_CHECKING:
6768
from monarch.debugger import DebugClient
@@ -578,9 +579,10 @@ async def instrumented():
578579
)
579580
try:
580581
result = await the_method(self.instance, *args, **kwargs)
582+
self.instance._maybe_exit_debugger()
581583
except Exception as e:
582584
logging.critical(
583-
"Unahndled exception in actor endpoint",
585+
"Unhandled exception in actor endpoint",
584586
exc_info=e,
585587
)
586588
raise e
@@ -593,11 +595,15 @@ async def instrumented():
593595
the_method.__module__, message.method, str(ctx.mailbox.actor_id)
594596
)
595597
result = the_method(self.instance, *args, **kwargs)
598+
# pyre-ignore
599+
self.instance._maybe_exit_debugger()
596600
exit_span()
597601

598602
if port is not None:
599603
port.send("result", result)
600604
except Exception as e:
605+
# pyre-ignore
606+
self.instance._post_mortem_debug(e.__traceback__)
601607
traceback.print_exc()
602608
s = ActorError(e)
603609

@@ -608,6 +614,7 @@ async def instrumented():
608614
else:
609615
raise s from None
610616
except BaseException as e:
617+
self.instance._post_mortem_debug(e.__traceback__)
611618
# A BaseException can be thrown in the case of a Rust panic.
612619
# In this case, we need a way to signal the panic to the Rust side.
613620
# See [Panics in async endpoints]
@@ -674,18 +681,63 @@ def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
674681
)
675682

676683
@endpoint # pyre-ignore
677-
def _set_debug_client(self, client: "DebugClient") -> None:
684+
def _set_debug_client(self, client: "ActorMeshRef[DebugClient]") -> None:
678685
point = MonarchContext.get().point
679686
# For some reason, using a lambda instead of functools.partial
680687
# confuses the pdb wrapper implementation.
681-
sys.breakpointhook = functools.partial( # pyre-ignore
682-
remote_breakpointhook,
688+
sys.breakpointhook = functools.partial(
689+
self._remote_breakpointhook,
683690
point.rank,
684691
point.shape.coordinates(point.rank),
685692
MonarchContext.get().mailbox.actor_id,
686693
client,
687694
)
688695

696+
def _remote_breakpointhook(
697+
self,
698+
rank: int,
699+
coords: Dict[str, int],
700+
actor_id: ActorId,
701+
client_ref: "DebugClient",
702+
) -> None:
703+
frame = inspect.currentframe()
704+
assert frame is not None
705+
frame = frame.f_back
706+
assert frame is not None
707+
file = frame.f_code.co_filename
708+
line = frame.f_lineno
709+
module = frame.f_globals.get("__name__", "__main__")
710+
if module == "__main__" and not os.path.exists(file):
711+
raise NotImplementedError(
712+
f"Remote debugging not supported for breakpoint at {file}:{line} because "
713+
f"it is defined inside __main__, and the file does not exist on the host. "
714+
"In this case, cloudpickle serialization does not interact nicely with pdb. "
715+
"To debug your code, move it out of __main__ and into a module that "
716+
"exists on both your client and worker processes."
717+
)
718+
719+
# pyre-ignore
720+
self._pdb_wrapper_args = (rank, coords, actor_id, client_ref)
721+
# pyre-ignore
722+
self._pdb_wrapper = PdbWrapper(*self._pdb_wrapper_args)
723+
self._pdb_wrapper.set_trace(frame)
724+
725+
def _post_mortem_debug(self, exc_tb) -> None:
726+
if hasattr(self, "_pdb_wrapper"):
727+
# pyre-ignore
728+
self._pdb_wrapper = PdbWrapper(*self._pdb_wrapper_args)
729+
self._pdb_wrapper.post_mortem(exc_tb)
730+
self._maybe_exit_debugger(do_continue=False)
731+
732+
def _maybe_exit_debugger(self, do_continue=True) -> None:
733+
if hasattr(self, "_pdb_wrapper"):
734+
if do_continue:
735+
# pyre-ignore
736+
self._pdb_wrapper.clear_all_breaks()
737+
self._pdb_wrapper.do_continue(None)
738+
self._pdb_wrapper.end_debug_session()
739+
del self._pdb_wrapper
740+
689741

690742
class ActorMeshRef(MeshTrait, Generic[T]):
691743
def __init__(

0 commit comments

Comments
 (0)