Skip to content

Big debugging update #456

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 48 additions & 19 deletions python/monarch/_src/actor/actor_mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import inspect
import logging
import random
import sys
import traceback

from dataclasses import dataclass
Expand Down Expand Up @@ -53,15 +52,12 @@
from monarch._rust_bindings.monarch_hyperactor.telemetry import enter_span, exit_span
from monarch._src.actor.allocator import LocalAllocator, ProcessAllocator
from monarch._src.actor.future import Future
from monarch._src.actor.pdb_wrapper import remote_breakpointhook
from monarch._src.actor.pdb_wrapper import PdbWrapper

from monarch._src.actor.pickle import flatten, unpickle

from monarch._src.actor.shape import MeshTrait, NDSlice

if TYPE_CHECKING:
from monarch._src.actor.debugger import DebugClient

logger: logging.Logger = logging.getLogger(__name__)

Allocator = ProcessAllocator | LocalAllocator
Expand Down Expand Up @@ -97,6 +93,23 @@ def get() -> "MonarchContext":
)


@dataclass
class DebugContext:
pdb_wrapper: Optional[PdbWrapper] = None

@staticmethod
def get() -> "DebugContext":
return _debug_context.get()

@staticmethod
def set(debug_context: "DebugContext") -> None:
_debug_context.set(debug_context)


_debug_context: contextvars.ContextVar[DebugContext] = contextvars.ContextVar(
"monarch.actor_mesh._debug_context"
)

T = TypeVar("T")
P = ParamSpec("P")
R = TypeVar("R")
Expand Down Expand Up @@ -538,6 +551,8 @@ async def handle_cast(
)
_context.set(ctx)

DebugContext.set(DebugContext())

args, kwargs = unpickle(message.message, mailbox)

if message.method == "__init__":
Expand Down Expand Up @@ -574,9 +589,10 @@ async def instrumented():
)
try:
result = await the_method(self.instance, *args, **kwargs)
self._maybe_exit_debugger()
except Exception as e:
logging.critical(
"Unahndled exception in actor endpoint",
"Unhandled exception in actor endpoint",
exc_info=e,
)
raise e
Expand All @@ -589,11 +605,13 @@ async def instrumented():
the_method.__module__, message.method, str(ctx.mailbox.actor_id)
)
result = the_method(self.instance, *args, **kwargs)
self._maybe_exit_debugger()
exit_span()

if port is not None:
port.send("result", result)
except Exception as e:
self._post_mortem_debug(e.__traceback__)
traceback.print_exc()
s = ActorError(e)

Expand All @@ -604,6 +622,7 @@ async def instrumented():
else:
raise s from None
except BaseException as e:
self._post_mortem_debug(e.__traceback__)
# A BaseException can be thrown in the case of a Rust panic.
# In this case, we need a way to signal the panic to the Rust side.
# See [Panics in async endpoints]
Expand All @@ -614,6 +633,29 @@ async def instrumented():
pass
raise

def _maybe_exit_debugger(self, do_continue=True) -> None:
if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
if do_continue:
pdb_wrapper.clear_all_breaks()
pdb_wrapper.do_continue("")
pdb_wrapper.end_debug_session()
DebugContext.set(DebugContext())

def _post_mortem_debug(self, exc_tb) -> None:
from monarch._src.actor.debugger import DebugManager

if (pdb_wrapper := DebugContext.get().pdb_wrapper) is not None:
ctx = MonarchContext.get()
pdb_wrapper = PdbWrapper(
ctx.point.rank,
ctx.point.shape.coordinates(ctx.point.rank),
ctx.mailbox.actor_id,
DebugManager.ref().get_debug_client.call_one().get(),
)
DebugContext.set(DebugContext(pdb_wrapper))
pdb_wrapper.post_mortem(exc_tb)
self._maybe_exit_debugger(do_continue=False)


def _is_mailbox(x: object) -> bool:
return isinstance(x, Mailbox)
Expand Down Expand Up @@ -648,19 +690,6 @@ def _new_with_shape(self, shape: Shape) -> "ActorMeshRef":
"actor implementations are not meshes, but we can't convince the typechecker of it..."
)

@endpoint # pyre-ignore
def _set_debug_client(self, client: "DebugClient") -> None:
point = MonarchContext.get().point
# For some reason, using a lambda instead of functools.partial
# confuses the pdb wrapper implementation.
sys.breakpointhook = functools.partial( # pyre-ignore
remote_breakpointhook,
point.rank,
point.shape.coordinates(point.rank),
MonarchContext.get().mailbox.actor_id,
client,
)


class ActorMeshRef(MeshTrait, Generic[T]):
def __init__(
Expand Down
4 changes: 4 additions & 0 deletions python/monarch/_src/actor/bootstrap_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,10 @@ def invoke_main():
except Exception as e:
logging.warning(f"Failed to set up py-spy: {e}")

from monarch._src.actor.debugger import remote_breakpointhook

sys.breakpointhook = remote_breakpointhook

# Start an event loop for PythonActors to use.
asyncio.run(main())

Expand Down
Loading
Loading