|
1 | 1 | from __future__ import annotations
|
2 | 2 |
|
| 3 | +import sys |
3 | 4 | from collections import OrderedDict, deque
|
| 5 | +from collections.abc import AsyncGenerator, Callable # noqa: TC003 # Needed for Sphinx |
| 6 | +from contextlib import AbstractAsyncContextManager, asynccontextmanager |
| 7 | +from functools import wraps |
4 | 8 | from math import inf
|
5 | 9 | from typing import (
|
6 | 10 | TYPE_CHECKING,
|
|
14 | 18 |
|
15 | 19 | from ._abc import ReceiveChannel, ReceiveType, SendChannel, SendType, T
|
16 | 20 | from ._core import Abort, RaiseCancelT, Task, enable_ki_protection
|
17 |
| -from ._util import NoPublicConstructor, final, generic_function |
| 21 | +from ._util import ( |
| 22 | + MultipleExceptionError, |
| 23 | + NoPublicConstructor, |
| 24 | + final, |
| 25 | + generic_function, |
| 26 | + raise_single_exception_from_group, |
| 27 | +) |
| 28 | + |
| 29 | +if sys.version_info < (3, 11): |
| 30 | + from exceptiongroup import BaseExceptionGroup |
18 | 31 |
|
19 | 32 | if TYPE_CHECKING:
|
20 | 33 | from types import TracebackType
|
21 | 34 |
|
22 |
| - from typing_extensions import Self |
| 35 | + from typing_extensions import ParamSpec, Self |
| 36 | + |
| 37 | + P = ParamSpec("P") |
| 38 | +elif "sphinx" in sys.modules: |
| 39 | + # P needs to exist for Sphinx to parse the type hints successfully. |
| 40 | + try: |
| 41 | + from typing_extensions import ParamSpec |
| 42 | + except ImportError: |
| 43 | + P = ... # This is valid in Callable, though not correct |
| 44 | + else: |
| 45 | + P = ParamSpec("P") |
23 | 46 |
|
24 | 47 |
|
25 | 48 | def _open_memory_channel(
|
@@ -440,3 +463,124 @@ async def aclose(self) -> None:
|
440 | 463 | See `MemoryReceiveChannel.close`."""
|
441 | 464 | self.close()
|
442 | 465 | await trio.lowlevel.checkpoint()
|
| 466 | + |
| 467 | + |
| 468 | +class RecvChanWrapper(ReceiveChannel[T]): |
| 469 | + def __init__( |
| 470 | + self, recv_chan: MemoryReceiveChannel[T], send_semaphore: trio.Semaphore |
| 471 | + ) -> None: |
| 472 | + self._recv_chan = recv_chan |
| 473 | + self._send_semaphore = send_semaphore |
| 474 | + |
| 475 | + async def receive(self) -> T: |
| 476 | + self._send_semaphore.release() |
| 477 | + return await self._recv_chan.receive() |
| 478 | + |
| 479 | + async def aclose(self) -> None: |
| 480 | + await self._recv_chan.aclose() |
| 481 | + |
| 482 | + def __enter__(self) -> Self: |
| 483 | + return self |
| 484 | + |
| 485 | + def __exit__( |
| 486 | + self, |
| 487 | + exc_type: type[BaseException] | None, |
| 488 | + exc_value: BaseException | None, |
| 489 | + traceback: TracebackType | None, |
| 490 | + ) -> None: |
| 491 | + self._recv_chan.close() |
| 492 | + |
| 493 | + |
| 494 | +def as_safe_channel( |
| 495 | + fn: Callable[P, AsyncGenerator[T, None]], |
| 496 | +) -> Callable[P, AbstractAsyncContextManager[ReceiveChannel[T]]]: |
| 497 | + """Decorate an async generator function to make it cancellation-safe. |
| 498 | +
|
| 499 | + The ``yield`` keyword offers a very convenient way to write iterators... |
| 500 | + which makes it really unfortunate that async generators are so difficult |
| 501 | + to call correctly. Yielding from the inside of a cancel scope or a nursery |
| 502 | + to the outside `violates structured concurrency <https://xkcd.com/292/>`_ |
| 503 | + with consequences explained in :pep:`789`. Even then, resource cleanup |
| 504 | + errors remain common (:pep:`533`) unless you wrap every call in |
| 505 | + :func:`~contextlib.aclosing`. |
| 506 | +
|
| 507 | + This decorator gives you the best of both worlds: with careful exception |
| 508 | + handling and a background task we preserve structured concurrency by |
| 509 | + offering only the safe interface, and you can still write your iterables |
| 510 | + with the convenience of ``yield``. For example:: |
| 511 | +
|
| 512 | + @as_safe_channel |
| 513 | + async def my_async_iterable(arg, *, kwarg=True): |
| 514 | + while ...: |
| 515 | + item = await ... |
| 516 | + yield item |
| 517 | +
|
| 518 | + async with my_async_iterable(...) as recv_chan: |
| 519 | + async for item in recv_chan: |
| 520 | + ... |
| 521 | +
|
| 522 | + While the combined async-with-async-for can be inconvenient at first, |
| 523 | + the context manager is indispensable for both correctness and for prompt |
| 524 | + cleanup of resources. |
| 525 | + """ |
| 526 | + # Perhaps a future PEP will adopt `async with for` syntax, like |
| 527 | + # https://coconut.readthedocs.io/en/master/DOCS.html#async-with-for |
| 528 | + |
| 529 | + @asynccontextmanager |
| 530 | + @wraps(fn) |
| 531 | + async def context_manager( |
| 532 | + *args: P.args, **kwargs: P.kwargs |
| 533 | + ) -> AsyncGenerator[trio._channel.RecvChanWrapper[T], None]: |
| 534 | + send_chan, recv_chan = trio.open_memory_channel[T](0) |
| 535 | + try: |
| 536 | + async with trio.open_nursery(strict_exception_groups=True) as nursery: |
| 537 | + agen = fn(*args, **kwargs) |
| 538 | + send_semaphore = trio.Semaphore(0) |
| 539 | + # `nursery.start` to make sure that we will clean up send_chan & agen |
| 540 | + # If this errors we don't close `recv_chan`, but the caller |
| 541 | + # never gets access to it, so that's not a problem. |
| 542 | + await nursery.start( |
| 543 | + _move_elems_to_channel, agen, send_chan, send_semaphore |
| 544 | + ) |
| 545 | + # `async with recv_chan` could eat exceptions, so use sync cm |
| 546 | + with RecvChanWrapper(recv_chan, send_semaphore) as wrapped_recv_chan: |
| 547 | + yield wrapped_recv_chan |
| 548 | + # User has exited context manager, cancel to immediately close the |
| 549 | + # abandoned generator if it's still alive. |
| 550 | + nursery.cancel_scope.cancel() |
| 551 | + except BaseExceptionGroup as eg: |
| 552 | + try: |
| 553 | + raise_single_exception_from_group(eg) |
| 554 | + except MultipleExceptionError: |
| 555 | + # In case user has except* we make it possible for them to handle the |
| 556 | + # exceptions. |
| 557 | + raise BaseExceptionGroup( |
| 558 | + "Encountered exception during cleanup of generator object, as well as exception in the contextmanager body - unable to unwrap.", |
| 559 | + [eg], |
| 560 | + ) from None |
| 561 | + |
| 562 | + async def _move_elems_to_channel( |
| 563 | + agen: AsyncGenerator[T, None], |
| 564 | + send_chan: trio.MemorySendChannel[T], |
| 565 | + send_semaphore: trio.Semaphore, |
| 566 | + task_status: trio.TaskStatus, |
| 567 | + ) -> None: |
| 568 | + # `async with send_chan` will eat exceptions, |
| 569 | + # see https://github.com/python-trio/trio/issues/1559 |
| 570 | + with send_chan: |
| 571 | + try: |
| 572 | + task_status.started() |
| 573 | + while True: |
| 574 | + # wait for receiver to call next on the aiter |
| 575 | + await send_semaphore.acquire() |
| 576 | + try: |
| 577 | + value = await agen.__anext__() |
| 578 | + except StopAsyncIteration: |
| 579 | + return |
| 580 | + # Send the value to the channel |
| 581 | + await send_chan.send(value) |
| 582 | + finally: |
| 583 | + # replace try-finally with contextlib.aclosing once python39 is dropped |
| 584 | + await agen.aclose() |
| 585 | + |
| 586 | + return context_manager |
0 commit comments