Skip to content

Commit 58e03c9

Browse files
committed
handle multiple cancellations
1 parent ba9c1d2 commit 58e03c9

File tree

4 files changed

+83
-57
lines changed

4 files changed

+83
-57
lines changed

src/trio/_channel.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,11 @@
1919
from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T
2020
from ._core import Abort, RaiseCancelT, Task, enable_ki_protection
2121
from ._util import (
22+
MultipleExceptionError,
2223
NoPublicConstructor,
2324
final,
2425
generic_function,
25-
raise_saving_context,
26+
raise_single_exception_from_group,
2627
)
2728

2829
if sys.version_info < (3, 11):
@@ -548,15 +549,15 @@ async def context_manager(
548549
# abandoned generator if it's still alive.
549550
nursery.cancel_scope.cancel()
550551
except BaseExceptionGroup as eg:
551-
first, *rest = eg.exceptions
552-
if rest:
552+
try:
553+
raise_single_exception_from_group(eg)
554+
except MultipleExceptionError:
553555
# In case user has except* we make it possible for them to handle the
554556
# exceptions.
555557
raise BaseExceptionGroup(
556558
"Encountered exception during cleanup of generator object, as well as exception in the contextmanager body - unable to unwrap.",
557559
[eg],
558560
) from None
559-
raise_saving_context(first)
560561

561562
async def _move_elems_to_channel(
562563
agen: AsyncGenerator[T, None],

src/trio/_tests/test_channel.py

Lines changed: 20 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -506,19 +506,18 @@ async def agen() -> AsyncGenerator[int]:
506506

507507

508508
async def test_as_safe_channel_genexit_finally() -> None:
509-
events: list[str] = []
510-
511509
@as_safe_channel
512-
async def agen(stuff: list[str]) -> AsyncGenerator[int]:
510+
async def agen(events: list[str]) -> AsyncGenerator[int]:
513511
try:
514512
yield 1
515513
except BaseException as e:
516-
stuff.append(repr(e))
514+
events.append(repr(e))
517515
raise
518516
finally:
519-
stuff.append("finally")
517+
events.append("finally")
520518
raise ValueError("agen")
521519

520+
events: list[str] = []
522521
with RaisesGroup(
523522
RaisesGroup(
524523
Matcher(ValueError, match="^agen$"),
@@ -569,7 +568,7 @@ async def test_as_safe_channel_dont_unwrap_user_exceptiongroup() -> None:
569568
@as_safe_channel
570569
async def agen() -> AsyncGenerator[None]:
571570
raise NotImplementedError("not entered")
572-
yield
571+
yield # pragma: no cover
573572

574573
with RaisesGroup(Matcher(ValueError, match="bar"), match="foo"):
575574
async with agen() as _:
@@ -602,22 +601,27 @@ async def handle_value(
602601

603602
async def test_as_safe_channel_multi_cancel() -> None:
604603
@as_safe_channel
605-
async def agen() -> AsyncGenerator[None]:
604+
async def agen(events: list[str]) -> AsyncGenerator[None]:
606605
try:
607606
yield
608607
finally:
609608
# this will give a warning of ASYNC120, although it's not technically a
610609
# problem of swallowing existing exceptions
611-
await trio.lowlevel.checkpoint()
610+
try:
611+
await trio.lowlevel.checkpoint()
612+
except trio.Cancelled:
613+
events.append("agen cancel")
614+
raise
612615

616+
events: list[str] = []
613617
with trio.CancelScope() as cs:
614-
with RaisesGroup(
615-
RaisesGroup(
616-
trio.Cancelled, trio.Cancelled, match="^Exceptions from Trio nursery$"
617-
),
618-
match=r"^Encountered exception during cleanup of generator object, as well as exception in the contextmanager body - unable to unwrap.$",
619-
):
620-
async with agen() as recv_chan:
618+
with pytest.raises(trio.Cancelled):
619+
async with agen(events) as recv_chan:
621620
async for _ in recv_chan: # pragma: no branch
622621
cs.cancel()
623-
await trio.lowlevel.checkpoint()
622+
try:
623+
await trio.lowlevel.checkpoint()
624+
except trio.Cancelled:
625+
events.append("body cancel")
626+
raise
627+
assert events == ["body cancel", "agen cancel"]

src/trio/_tests/test_util.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
)
1919
from .._util import (
2020
ConflictDetector,
21+
MultipleExceptionError,
2122
NoPublicConstructor,
2223
coroutine_or_error,
2324
final,
@@ -288,26 +289,28 @@ async def test_raise_single_exception_from_group() -> None:
288289
assert excinfo.value.__cause__ == cause
289290
assert excinfo.value.__context__ == context
290291

291-
with pytest.raises(ValueError, match="foo") as excinfo:
292-
raise_single_exception_from_group(
293-
ExceptionGroup("", [ExceptionGroup("", [exc])])
294-
)
295-
assert excinfo.value.__cause__ == cause
296-
assert excinfo.value.__context__ == context
292+
# only unwraps one layer of exceptiongroup
293+
inner_eg = ExceptionGroup("inner eg", [exc])
294+
inner_cause = SyntaxError("inner eg cause")
295+
inner_context = TypeError("inner eg context")
296+
inner_eg.__cause__ = inner_cause
297+
inner_eg.__context__ = inner_context
298+
with RaisesGroup(Matcher(ValueError, match="^foo$"), match="^inner eg$") as eginfo:
299+
raise_single_exception_from_group(ExceptionGroup("", [inner_eg]))
300+
assert eginfo.value.__cause__ == inner_cause
301+
assert eginfo.value.__context__ == inner_context
297302

298303
with pytest.raises(ValueError, match="foo") as excinfo:
299304
raise_single_exception_from_group(
300-
BaseExceptionGroup(
301-
"", [cancelled, BaseExceptionGroup("", [cancelled, exc])]
302-
)
305+
BaseExceptionGroup("", [cancelled, cancelled, exc])
303306
)
304307
assert excinfo.value.__cause__ == cause
305308
assert excinfo.value.__context__ == context
306309

307310
# multiple non-cancelled
308311
eg = ExceptionGroup("", [ValueError("foo"), ValueError("bar")])
309312
with pytest.raises(
310-
AssertionError,
313+
MultipleExceptionError,
311314
match=r"^Attempted to unwrap exceptiongroup with multiple non-cancelled exceptions. This is often caused by a bug in the caller.$",
312315
) as excinfo:
313316
raise_single_exception_from_group(eg)
@@ -328,6 +331,20 @@ async def test_raise_single_exception_from_group() -> None:
328331
assert excinfo.value.__cause__ is eg_ki
329332
assert excinfo.value.__context__ is None
330333

334+
# and same for SystemExit
335+
systemexit_ki = BaseExceptionGroup(
336+
"",
337+
[
338+
ValueError("foo"),
339+
ValueError("bar"),
340+
SystemExit("this exc doesn't get reraised"),
341+
],
342+
)
343+
with pytest.raises(SystemExit, match=r"^$") as excinfo:
344+
raise_single_exception_from_group(systemexit_ki)
345+
assert excinfo.value.__cause__ is systemexit_ki
346+
assert excinfo.value.__context__ is None
347+
331348
# if we only got cancelled, first one is reraised
332349
with pytest.raises(trio.Cancelled, match=r"^Cancelled$") as excinfo:
333350
raise_single_exception_from_group(

src/trio/_util.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import collections.abc
55
import inspect
66
import signal
7-
import sys
87
from abc import ABCMeta
98
from collections.abc import Awaitable, Callable, Sequence
109
from functools import update_wrapper
@@ -21,19 +20,20 @@
2120

2221
import trio
2322

24-
if sys.version_info < (3, 11):
25-
from exceptiongroup import BaseExceptionGroup
26-
2723
# Explicit "Any" is not allowed
2824
CallT = TypeVar("CallT", bound=Callable[..., Any]) # type: ignore[explicit-any]
2925
T = TypeVar("T")
3026
RetT = TypeVar("RetT")
3127

3228
if TYPE_CHECKING:
29+
import sys
3330
from types import AsyncGeneratorType, TracebackType
3431

3532
from typing_extensions import ParamSpec, Self, TypeVarTuple, Unpack
3633

34+
if sys.version_info < (3, 11):
35+
from exceptiongroup import BaseExceptionGroup
36+
3737
ArgsT = ParamSpec("ArgsT")
3838
PosArgsT = TypeVarTuple("PosArgsT")
3939

@@ -372,12 +372,19 @@ def raise_saving_context(exc: BaseException) -> NoReturn:
372372
del exc, context
373373

374374

375+
class MultipleExceptionError(Exception):
376+
"""Raised by raise_single_exception_from_group if encountering multiple
377+
non-cancelled exceptions."""
378+
379+
375380
def raise_single_exception_from_group(
376381
eg: BaseExceptionGroup[BaseException],
377382
) -> NoReturn:
378383
"""This function takes an exception group that is assumed to have at most
379384
one non-cancelled exception, which it reraises as a standalone exception.
380385
386+
This exception may be an exceptiongroup itself, in which case it will not be unwrapped.
387+
381388
If a :exc:`KeyboardInterrupt` is encountered, a new KeyboardInterrupt is immediately
382389
raised with the entire group as cause.
383390
@@ -389,30 +396,27 @@ def raise_single_exception_from_group(
389396
If multiple non-cancelled exceptions are encountered, it raises
390397
:exc:`AssertionError`.
391398
"""
392-
cancelled_exceptions = []
393-
noncancelled_exceptions = []
394-
395-
# subgroup/split retains excgroup structure, so we need to manually traverse
396-
def _parse_excg(e: BaseException) -> None:
399+
# immediately bail out if there's any KI or SystemExit
400+
for e in eg.exceptions:
397401
if isinstance(e, (KeyboardInterrupt, SystemExit)):
398-
# immediately bail out
399-
raise KeyboardInterrupt from eg
402+
raise type(e) from eg
403+
404+
cancelled_exception: trio.Cancelled | None = None
405+
noncancelled_exception: BaseException | None = None
400406

407+
for e in eg.exceptions:
401408
if isinstance(e, trio.Cancelled):
402-
cancelled_exceptions.append(e)
403-
elif isinstance(e, BaseExceptionGroup):
404-
for sub_e in e.exceptions:
405-
_parse_excg(sub_e)
409+
if cancelled_exception is None:
410+
cancelled_exception = e
411+
elif noncancelled_exception is None:
412+
noncancelled_exception = e
406413
else:
407-
noncancelled_exceptions.append(e)
408-
409-
_parse_excg(eg)
410-
411-
if len(noncancelled_exceptions) > 1:
412-
raise AssertionError(
413-
"Attempted to unwrap exceptiongroup with multiple non-cancelled exceptions. This is often caused by a bug in the caller."
414-
) from eg
415-
if len(noncancelled_exceptions) == 1:
416-
raise_saving_context(noncancelled_exceptions[0])
417-
assert cancelled_exceptions, "internal error"
418-
raise_saving_context(cancelled_exceptions[0])
414+
raise MultipleExceptionError(
415+
"Attempted to unwrap exceptiongroup with multiple non-cancelled exceptions. This is often caused by a bug in the caller."
416+
) from eg
417+
418+
if noncancelled_exception is not None:
419+
raise_saving_context(noncancelled_exception)
420+
421+
assert cancelled_exception is not None, "group can't be empty"
422+
raise_saving_context(cancelled_exception)

0 commit comments

Comments
 (0)