diff --git a/tests/test_connection.py b/tests/test_connection.py index 6cccefa..0837aa5 100644 --- a/tests/test_connection.py +++ b/tests/test_connection.py @@ -31,6 +31,7 @@ ''' from __future__ import annotations +import copy from functools import partial, wraps import re import ssl @@ -452,7 +453,6 @@ async def test_open_websocket_internal_ki(nursery, monkeypatch, autojump_clock): Make sure that KI is delivered, and the user exception is in the __cause__ exceptiongroup """ async def ki_raising_ping_handler(*args, **kwargs) -> None: - print("raising ki") raise KeyboardInterrupt monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", ki_raising_ping_handler) async def handler(request): @@ -474,11 +474,14 @@ async def handler(request): async def test_open_websocket_internal_exc(nursery, monkeypatch, autojump_clock): """_reader_task._handle_ping_event triggers ValueError. user code also raises exception. - internal exception is in __cause__ exceptiongroup and user exc is delivered + internal exception is in __context__ exceptiongroup and user exc is delivered """ - my_value_error = ValueError() + internal_error = ValueError() + internal_error.__context__ = TypeError() + user_error = NameError() + user_error_context = KeyError() async def raising_ping_event(*args, **kwargs) -> None: - raise my_value_error + raise internal_error monkeypatch.setattr(WebSocketConnection, "_handle_ping_event", raising_ping_event) async def handler(request): @@ -486,15 +489,17 @@ async def handler(request): await server_ws.ping(b"a") server = await nursery.start(serve_websocket, handler, HOST, 0, None) - with pytest.raises(trio.TooSlowError) as exc_info: + with pytest.raises(type(user_error)) as exc_info: async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): - with trio.fail_after(1) as cs: - cs.shield = True - await trio.sleep(2) + await trio.lowlevel.checkpoint() + user_error.__context__ = user_error_context + raise user_error - e_cause = exc_info.value.__cause__ - assert isinstance(e_cause, _TRIO_EXC_GROUP_TYPE) - assert my_value_error in e_cause.exceptions + assert exc_info.value is user_error + e_context = exc_info.value.__context__ + assert isinstance(e_context, BaseExceptionGroup) # pylint: disable=possibly-used-before-assignment + assert internal_error in e_context.exceptions + assert user_error_context in e_context.exceptions @fail_after(5) async def test_open_websocket_cancellations(nursery, monkeypatch, autojump_clock): @@ -513,6 +518,8 @@ async def handler(request): server_ws = await request.accept() await server_ws.ping(b"a") user_cancelled = None + user_cancelled_cause = None + user_cancelled_context = None server = await nursery.start(serve_websocket, handler, HOST, 0, None) with trio.move_on_after(2): @@ -522,8 +529,13 @@ async def handler(request): await trio.sleep_forever() except trio.Cancelled as e: user_cancelled = e + user_cancelled_cause = e.__cause__ + user_cancelled_context = e.__context__ raise + assert exc_info.value is user_cancelled + assert exc_info.value.__cause__ is user_cancelled_cause + assert exc_info.value.__context__ is user_cancelled_context def _trio_default_non_strict_exception_groups() -> bool: assert re.match(r'^0\.\d\d\.', trio.__version__), "unexpected trio versioning scheme" @@ -560,6 +572,24 @@ async def handler(request): RaisesGroup(ValueError)))).matches(exc.value) +async def test_user_exception_cause(nursery) -> None: + async def handler(request): + await request.accept() + server = await nursery.start(serve_websocket, handler, HOST, 0, None) + e_context = TypeError("foo") + e_primary = ValueError("bar") + e_cause = RuntimeError("zee") + with pytest.raises(ValueError) as exc_info: + async with open_websocket(HOST, server.port, RESOURCE, use_ssl=False): + try: + raise e_context + except TypeError: + raise e_primary from e_cause + e = exc_info.value + assert e is e_primary + assert e.__cause__ is e_cause + assert e.__context__ is e_context + @fail_after(1) async def test_reject_handshake(nursery): async def handler(request): @@ -1176,3 +1206,16 @@ async def server(): async with trio.open_nursery() as nursery: nursery.start_soon(server) nursery.start_soon(client) + + +def test_copy_exceptions(): + # test that exceptions are copy- and pickleable + copy.copy(HandshakeError()) + copy.copy(ConnectionTimeout()) + copy.copy(DisconnectionTimeout()) + assert copy.copy(ConnectionClosed("foo")).reason == "foo" + + rej_copy = copy.copy(ConnectionRejected(404, (("a", "b"),), b"c")) + assert rej_copy.status_code == 404 + assert rej_copy.headers == (("a", "b"),) + assert rej_copy.body == b"c" diff --git a/trio_websocket/_impl.py b/trio_websocket/_impl.py index a71e0be..5f3a9d4 100644 --- a/trio_websocket/_impl.py +++ b/trio_websocket/_impl.py @@ -11,7 +11,7 @@ import ssl import struct import urllib.parse -from typing import Iterable, List, Optional, Union +from typing import Iterable, List, NoReturn, Optional, Union import outcome import trio @@ -151,14 +151,14 @@ async def open_websocket( # yield to user code. If only one of those raise a non-cancelled exception # we will raise that non-cancelled exception. # If we get multiple cancelled, we raise the user's cancelled. - # If both raise exceptions, we raise the user code's exception with the entire - # exception group as the __cause__. + # If both raise exceptions, we raise the user code's exception with __context__ + # set to a group containing internal exception(s) + any user exception __context__ # If we somehow get multiple exceptions, but no user exception, then we raise # TrioWebsocketInternalError. # If closing the connection fails, then that will be raised as the top # exception in the last `finally`. If we encountered exceptions in user code - # or in reader task then they will be set as the `__cause__`. + # or in reader task then they will be set as the `__context__`. async def _open_connection(nursery: trio.Nursery) -> WebSocketConnection: @@ -181,10 +181,27 @@ async def _close_connection(connection: WebSocketConnection) -> None: except trio.TooSlowError: raise DisconnectionTimeout from None + def _raise(exc: BaseException) -> NoReturn: + """This helper allows re-raising an exception without __context__ being set.""" + # cause does not need special handlng, we simply avoid using `raise .. from ..` + __tracebackhide__ = True + context = exc.__context__ + try: + raise exc + finally: + exc.__context__ = context + del exc, context + connection: WebSocketConnection|None=None close_result: outcome.Maybe[None] | None = None user_error = None + # Unwrapping exception groups has a lot of pitfalls, one of them stemming from + # the exception we raise also being inside the group that's set as the context. + # This leads to loss of info unless properly handled. + # See https://github.com/python-trio/flake8-async/issues/298 + # We therefore avoid having the exceptiongroup included as either cause or context + try: async with trio.open_nursery() as new_nursery: result = await outcome.acapture(_open_connection, new_nursery) @@ -205,7 +222,7 @@ async def _close_connection(connection: WebSocketConnection) -> None: except _TRIO_EXC_GROUP_TYPE as e: # user_error, or exception bubbling up from _reader_task if len(e.exceptions) == 1: - raise e.exceptions[0] + _raise(e.exceptions[0]) # contains at most 1 non-cancelled exceptions exception_to_raise: BaseException|None = None @@ -218,25 +235,40 @@ async def _close_connection(connection: WebSocketConnection) -> None: else: if exception_to_raise is None: # all exceptions are cancelled - # prefer raising the one from the user, for traceback reasons + # we reraise the user exception and throw out internal if user_error is not None: - # no reason to raise from e, just to include a bunch of extra - # cancelleds. - raise user_error # pylint: disable=raise-missing-from + _raise(user_error) # multiple internal Cancelled is not possible afaik - raise e.exceptions[0] # pragma: no cover # pylint: disable=raise-missing-from - raise exception_to_raise + # but if so we just raise one of them + _raise(e.exceptions[0]) # pragma: no cover + # raise the non-cancelled exception + _raise(exception_to_raise) - # if we have any KeyboardInterrupt in the group, make sure to raise it. + # if we have any KeyboardInterrupt in the group, raise a new KeyboardInterrupt + # with the group as cause & context for sub_exc in e.exceptions: if isinstance(sub_exc, KeyboardInterrupt): - raise sub_exc from e + raise KeyboardInterrupt from e # Both user code and internal code raised non-cancelled exceptions. - # We "hide" the internal exception(s) in the __cause__ and surface - # the user_error. + # We set the context to be an exception group containing internal exceptions + # and, if not None, `user_error.__context__` if user_error is not None: - raise user_error from e + exceptions = [subexc for subexc in e.exceptions if subexc is not user_error] + eg_substr = '' + # there's technically loss of info here, with __suppress_context__=True you + # still have original __context__ available, just not printed. But we delete + # it completely because we can't partially suppress the group + if user_error.__context__ is not None and not user_error.__suppress_context__: + exceptions.append(user_error.__context__) + eg_substr = ' and the context for the user exception' + eg_str = ( + "Both internal and user exceptions encountered. This group contains " + "the internal exception(s)" + eg_substr + "." + ) + user_error.__context__ = BaseExceptionGroup(eg_str, exceptions) + user_error.__suppress_context__ = False + _raise(user_error) raise TrioWebsocketInternalError( "The trio-websocket API is not expected to raise multiple exceptions. " @@ -576,7 +608,7 @@ def __init__(self, reason): :param reason: :type reason: CloseReason ''' - super().__init__() + super().__init__(reason) self.reason = reason def __repr__(self): @@ -596,7 +628,7 @@ def __init__(self, status_code, headers, body): :param reason: :type reason: CloseReason ''' - super().__init__() + super().__init__(status_code, headers, body) #: a 3 digit HTTP status code self.status_code = status_code #: a tuple of 2-tuples containing header key/value pairs