|
6 | 6 | import queue as stdlib_queue
|
7 | 7 | import threading
|
8 | 8 | from itertools import count
|
9 |
| -from typing import TYPE_CHECKING, Generic, TypeVar |
| 9 | +from typing import TYPE_CHECKING, Generic, TypeVar, Protocol, Final, NoReturn |
10 | 10 |
|
11 | 11 | import attrs
|
12 | 12 | import outcome
|
|
36 | 36 | Ts = TypeVarTuple("Ts")
|
37 | 37 |
|
38 | 38 | RetT = TypeVar("RetT")
|
| 39 | +T_co = TypeVar("T_co", covariant=True) |
39 | 40 |
|
40 | 41 |
|
41 | 42 | class _ParentTaskData(threading.local):
|
@@ -253,6 +254,32 @@ def run_in_system_nursery(self, token: TrioToken) -> None:
|
253 | 254 | token.run_sync_soon(self.run_sync)
|
254 | 255 |
|
255 | 256 |
|
| 257 | +class _SupportsUnwrap(Protocol, Generic[T_co]): |
| 258 | + def unwrap(self) -> T_co: ... |
| 259 | + |
| 260 | + |
| 261 | +class _Value(_SupportsUnwrap[T_co]): |
| 262 | + def __init__(self, v: T_co) -> None: |
| 263 | + self._v: Final = v |
| 264 | + |
| 265 | + def unwrap(self) -> T_co: |
| 266 | + try: |
| 267 | + return self._v |
| 268 | + finally: |
| 269 | + del self._v |
| 270 | + |
| 271 | + |
| 272 | +class _Error(_SupportsUnwrap[NoReturn]): |
| 273 | + def __init__(self, e: BaseException) -> None: |
| 274 | + self._e: Final = e |
| 275 | + |
| 276 | + def unwrap(self) -> NoReturn: |
| 277 | + try: |
| 278 | + raise self._e |
| 279 | + finally: |
| 280 | + del self._e |
| 281 | + |
| 282 | + |
256 | 283 | @enable_ki_protection
|
257 | 284 | async def to_thread_run_sync(
|
258 | 285 | sync_fn: Callable[[Unpack[Ts]], RetT],
|
@@ -372,11 +399,15 @@ def do_release_then_return_result() -> RetT:
|
372 | 399 | try:
|
373 | 400 | return result.unwrap()
|
374 | 401 | finally:
|
| 402 | + del result |
375 | 403 | limiter.release_on_behalf_of(placeholder)
|
376 | 404 |
|
377 | 405 | result = outcome.capture(do_release_then_return_result)
|
| 406 | + if isinstance(result, outcome.Error): |
| 407 | + result2: _SupportsUnwrap[RetT] = _Error(result.error) |
| 408 | + result2 = _Value(result.value) |
378 | 409 | if task_register[0] is not None:
|
379 |
| - trio.lowlevel.reschedule(task_register[0], outcome.Value(result)) |
| 410 | + trio.lowlevel.reschedule(task_register[0], outcome.Value(result2)) |
380 | 411 |
|
381 | 412 | current_trio_token = trio.lowlevel.current_trio_token()
|
382 | 413 |
|
|
0 commit comments