Skip to content

avoid holding a reference to exception and value in to_thread_run_sync #3229

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ci.sh
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ python -m uv pip install build
python -m build
wheel_package=$(ls dist/*.whl)
python -m uv pip install "trio @ $wheel_package" -c test-requirements.txt
python -m uv pip install https://github.com/python-trio/outcome/archive/e0f317813a499f1a3629b37c3b8caed72825d9c0.zip
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leftover?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No I've not got the change landed in Outcome

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We're going to leave a pinned commit in ci.sh? This seems sketchy.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, once everything's decided on and merged in I'll update the pyproject.toml and remove this

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If python-trio/outcome#45 (comment) is true then I'd at least like a PR that removes dependence on those, even if it doesn't come with tested guarantees. That way any sort of outcome 2.0 release isn't as annoying.

But I also see this PR doesn't change any .value or .error so I assume we already conform with clear-on-unwrap semantics?


# Actual tests
# expands to 0 != 1 if NO_TEST_REQUIREMENTS is not set, if set the `-0` has no effect
Expand Down
1 change: 1 addition & 0 deletions newsfragments/3229.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Avoid holding refs to result/exception from ``trio.to_thread.run_sync``.
2 changes: 2 additions & 0 deletions src/trio/_core/_entry_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ async def kill_everything( # noqa: RUF029 # await not used
"Internal error: `parent_nursery` should never be `None`",
) from exc # pragma: no cover
parent_nursery.start_soon(kill_everything, exc)
finally:
del sync_fn, args, job

# This has to be carefully written to be safe in the face of new items
# being queued while we iterate, and to do a bounded amount of work on
Expand Down
16 changes: 1 addition & 15 deletions src/trio/_core/_tests/test_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
create_asyncio_future_in_new_loop,
gc_collect_harder,
ignore_coroutine_never_awaited_warnings,
no_other_refs,
restore_unraisablehook,
slow,
)
Expand Down Expand Up @@ -2802,25 +2803,10 @@ async def spawn_tasks_in_old_nursery(task_status: _core.TaskStatus[None]) -> Non
assert RaisesGroup(ValueError, ValueError).matches(excinfo.value.__cause__)


if sys.version_info >= (3, 11):

def no_other_refs() -> list[object]:
return []

else:

def no_other_refs() -> list[object]:
return [sys._getframe(1)]


@pytest.mark.skipif(
sys.implementation.name != "cpython",
reason="Only makes sense with refcounting GC",
)
@pytest.mark.xfail(
sys.version_info >= (3, 14),
reason="https://github.com/python/cpython/issues/125603",
)
async def test_ki_protection_doesnt_leave_cyclic_garbage() -> None:
class MyException(Exception):
pass
Expand Down
17 changes: 17 additions & 0 deletions src/trio/_core/_tests/tutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,20 @@ def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) ->
def create_asyncio_future_in_new_loop() -> asyncio.Future[object]:
with closing(asyncio.new_event_loop()) as loop:
return loop.create_future()


if sys.version_info >= (3, 14):

def no_other_refs() -> list[object]:
gen = sys._getframe(1).f_generator
return [] if gen is None else [gen]

elif sys.version_info >= (3, 11):

def no_other_refs() -> list[object]:
return []

else:

def no_other_refs() -> list[object]:
return [sys._getframe(1)]
2 changes: 2 additions & 0 deletions src/trio/_core/_thread_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,8 @@ def _handle_job(self) -> None:
except BaseException as e:
print("Exception while delivering result of thread", file=sys.stderr)
traceback.print_exception(type(e), e, e.__traceback__)
finally:
del result

def _work(self) -> None:
while True:
Expand Down
55 changes: 54 additions & 1 deletion src/trio/_tests/test_threads.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextvars
import gc
import queue as stdlib_queue
import re
import sys
Expand Down Expand Up @@ -29,7 +30,7 @@
sleep_forever,
)
from .._core._tests.test_ki import ki_self
from .._core._tests.tutil import slow
from .._core._tests.tutil import no_other_refs, slow
from .._threads import (
active_thread_count,
current_default_thread_limiter,
Expand Down Expand Up @@ -1141,3 +1142,55 @@ async def wait_no_threads_left() -> None:
async def test_wait_all_threads_completed_no_threads() -> None:
await wait_all_threads_completed()
assert active_thread_count() == 0


@pytest.mark.skipif(
sys.implementation.name == "pypy",
reason=(
"gc.get_referrers is broken on PyPy (see "
"https://github.com/pypy/pypy/issues/5075)"
),
)
async def test_run_sync_worker_references() -> None:
class Foo:
pass

def foo(_: Foo) -> Foo:
return Foo()

cvar = contextvars.ContextVar[Foo]("cvar")
contextval = Foo()
arg = Foo()
cvar.set(contextval)
v = await to_thread_run_sync(foo, arg)

cvar.set(Foo())

assert gc.get_referrers(contextval) == no_other_refs()
assert gc.get_referrers(foo) == no_other_refs()
assert gc.get_referrers(arg) == no_other_refs()
assert gc.get_referrers(v) == no_other_refs()


@pytest.mark.skipif(
sys.implementation.name == "pypy",
reason=(
"gc.get_referrers is broken on PyPy (see "
"https://github.com/pypy/pypy/issues/5075)"
),
)
async def test_run_sync_workerreferences_exc() -> None:

class MyException(Exception):
pass

def throw() -> None:
raise MyException

e = None
try:
await to_thread_run_sync(throw)
except MyException as err:
e = err

assert gc.get_referrers(e) == no_other_refs()
63 changes: 49 additions & 14 deletions src/trio/_threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import queue as stdlib_queue
import threading
from itertools import count
from typing import TYPE_CHECKING, Generic, TypeVar
from typing import TYPE_CHECKING, Final, Generic, NoReturn, Protocol, TypeVar

import attrs
import outcome
Expand Down Expand Up @@ -36,6 +36,7 @@
Ts = TypeVarTuple("Ts")

RetT = TypeVar("RetT")
T_co = TypeVar("T_co", covariant=True)


class _ParentTaskData(threading.local):
Expand Down Expand Up @@ -253,6 +254,32 @@ def run_in_system_nursery(self, token: TrioToken) -> None:
token.run_sync_soon(self.run_sync)


class _SupportsUnwrap(Protocol, Generic[T_co]):
def unwrap(self) -> T_co: ...


class _Value(_SupportsUnwrap[T_co]):
def __init__(self, v: T_co) -> None:
self._v: Final = v

def unwrap(self) -> T_co:
try:
return self._v
finally:
del self._v


class _Error(_SupportsUnwrap[NoReturn]):
def __init__(self, e: BaseException) -> None:
self._e: Final = e

def unwrap(self) -> NoReturn:
try:
raise self._e
finally:
del self._e


@enable_ki_protection
async def to_thread_run_sync(
sync_fn: Callable[[Unpack[Ts]], RetT],
Expand Down Expand Up @@ -363,7 +390,7 @@ async def to_thread_run_sync(

# This function gets scheduled into the Trio run loop to deliver the
# thread's result.
def report_back_in_trio_thread_fn(result: outcome.Outcome[RetT]) -> None:
def report_back_in_trio_thread_fn(result: _SupportsUnwrap[RetT]) -> None:
def do_release_then_return_result() -> RetT:
# release_on_behalf_of is an arbitrary user-defined method, so it
# might raise an error. If it does, we want that error to
Expand All @@ -375,6 +402,12 @@ def do_release_then_return_result() -> RetT:
limiter.release_on_behalf_of(placeholder)

result = outcome.capture(do_release_then_return_result)
if isinstance(result, outcome.Error):
result = _Error(result.error)
elif isinstance(result, outcome.Value):
result = _Value(result.value)
else: # pragma: no cover
raise RuntimeError("invalid outcome")
if task_register[0] is not None:
trio.lowlevel.reschedule(task_register[0], outcome.Value(result))

Expand Down Expand Up @@ -440,20 +473,22 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:

while True:
# wait_task_rescheduled return value cannot be typed
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[object] = (
msg_from_thread: _Value[RetT] | _Error | Run[object] | RunSync[object] = (
await trio.lowlevel.wait_task_rescheduled(abort)
)
if isinstance(msg_from_thread, outcome.Outcome):
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.",
)
del msg_from_thread
try:
if isinstance(msg_from_thread, (_Value, _Error)):
return msg_from_thread.unwrap()
elif isinstance(msg_from_thread, Run):
await msg_from_thread.run()
elif isinstance(msg_from_thread, RunSync):
msg_from_thread.run_sync()
else: # pragma: no cover, internal debugging guard TODO: use assert_never
raise TypeError(
f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.",
)
finally:
del msg_from_thread


def from_thread_check_cancelled() -> None:
Expand Down
Loading