Skip to content

Commit fb6a6da

Browse files
committed
avoid holding a reference to Outcome in to_thread_run_sync
1 parent b6813ed commit fb6a6da

File tree

4 files changed

+82
-23
lines changed

4 files changed

+82
-23
lines changed

src/trio/_core/_tests/test_run.py

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
create_asyncio_future_in_new_loop,
3434
gc_collect_harder,
3535
ignore_coroutine_never_awaited_warnings,
36+
no_other_refs,
3637
restore_unraisablehook,
3738
slow,
3839
)
@@ -2802,17 +2803,6 @@ async def spawn_tasks_in_old_nursery(task_status: _core.TaskStatus[None]) -> Non
28022803
assert RaisesGroup(ValueError, ValueError).matches(excinfo.value.__cause__)
28032804

28042805

2805-
if sys.version_info >= (3, 11):
2806-
2807-
def no_other_refs() -> list[object]:
2808-
return []
2809-
2810-
else:
2811-
2812-
def no_other_refs() -> list[object]:
2813-
return [sys._getframe(1)]
2814-
2815-
28162806
@pytest.mark.skipif(
28172807
sys.implementation.name != "cpython",
28182808
reason="Only makes sense with refcounting GC",

src/trio/_core/_tests/tutil.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,3 +115,14 @@ def check_sequence_matches(seq: Sequence[T], template: Iterable[T | set[T]]) ->
115115
def create_asyncio_future_in_new_loop() -> asyncio.Future[object]:
116116
with closing(asyncio.new_event_loop()) as loop:
117117
return loop.create_future()
118+
119+
120+
if sys.version_info >= (3, 11):
121+
122+
def no_other_refs() -> list[object]:
123+
return []
124+
125+
else:
126+
127+
def no_other_refs() -> list[object]:
128+
return [sys._getframe(1)]

src/trio/_tests/test_threads.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import contextvars
4+
import gc
45
import queue as stdlib_queue
56
import re
67
import sys
@@ -29,7 +30,7 @@
2930
sleep_forever,
3031
)
3132
from .._core._tests.test_ki import ki_self
32-
from .._core._tests.tutil import slow
33+
from .._core._tests.tutil import gc_collect_harder, no_other_refs, slow
3334
from .._threads import (
3435
active_thread_count,
3536
current_default_thread_limiter,
@@ -1141,3 +1142,58 @@ async def wait_no_threads_left() -> None:
11411142
async def test_wait_all_threads_completed_no_threads() -> None:
11421143
await wait_all_threads_completed()
11431144
assert active_thread_count() == 0
1145+
1146+
1147+
@pytest.mark.skipif(
1148+
sys.implementation.name == "pypy",
1149+
reason=(
1150+
"gc.get_referrers is broken on PyPy (see "
1151+
"https://github.com/pypy/pypy/issues/5075)"
1152+
),
1153+
)
1154+
async def test_run_sync_worker_references() -> None:
1155+
class Foo:
1156+
pass
1157+
1158+
def foo(_: Foo) -> Foo:
1159+
return Foo()
1160+
1161+
cvar = contextvars.ContextVar[Foo]("cvar")
1162+
contextval = Foo()
1163+
arg = Foo()
1164+
cvar.set(contextval)
1165+
v = await to_thread_run_sync(foo, arg)
1166+
1167+
cvar.set(Foo())
1168+
gc_collect_harder()
1169+
1170+
assert gc.get_referrers(contextval) == no_other_refs()
1171+
assert gc.get_referrers(foo) == no_other_refs()
1172+
assert gc.get_referrers(arg) == no_other_refs()
1173+
assert gc.get_referrers(v) == no_other_refs()
1174+
1175+
1176+
@pytest.mark.skipif(
1177+
sys.implementation.name == "pypy",
1178+
reason=(
1179+
"gc.get_referrers is broken on PyPy (see "
1180+
"https://github.com/pypy/pypy/issues/5075)"
1181+
),
1182+
)
1183+
async def test_run_sync_workerreferences_exc() -> None:
1184+
1185+
class MyException(Exception):
1186+
pass
1187+
1188+
def throw() -> None:
1189+
raise MyException
1190+
1191+
e = None
1192+
try:
1193+
await to_thread_run_sync(throw)
1194+
except MyException as err:
1195+
e = err
1196+
1197+
gc_collect_harder()
1198+
1199+
assert gc.get_referrers(e) == no_other_refs()

src/trio/_threads.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -443,17 +443,19 @@ def abort(raise_cancel: RaiseCancelT) -> trio.lowlevel.Abort:
443443
msg_from_thread: outcome.Outcome[RetT] | Run[object] | RunSync[object] = (
444444
await trio.lowlevel.wait_task_rescheduled(abort)
445445
)
446-
if isinstance(msg_from_thread, outcome.Outcome):
447-
return msg_from_thread.unwrap()
448-
elif isinstance(msg_from_thread, Run):
449-
await msg_from_thread.run()
450-
elif isinstance(msg_from_thread, RunSync):
451-
msg_from_thread.run_sync()
452-
else: # pragma: no cover, internal debugging guard TODO: use assert_never
453-
raise TypeError(
454-
f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.",
455-
)
456-
del msg_from_thread
446+
try:
447+
if isinstance(msg_from_thread, outcome.Outcome):
448+
return msg_from_thread.unwrap()
449+
elif isinstance(msg_from_thread, Run):
450+
await msg_from_thread.run()
451+
elif isinstance(msg_from_thread, RunSync):
452+
msg_from_thread.run_sync()
453+
else: # pragma: no cover, internal debugging guard TODO: use assert_never
454+
raise TypeError(
455+
f"trio.to_thread.run_sync received unrecognized thread message {msg_from_thread!r}.",
456+
)
457+
finally:
458+
del msg_from_thread
457459

458460

459461
def from_thread_check_cancelled() -> None:

0 commit comments

Comments
 (0)