Skip to content

Commit 7aa905f

Browse files
committed
[trio.from_thread]
Add back in some unit tests for the legacy `BlockingTrioPortal` to ensure that the new `trio.from_thread` can handle being called from `portal.run()` and `portal.run_sync()`.
1 parent 84c5ad3 commit 7aa905f

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

trio/_threads.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import trio
99

1010
from ._sync import CapacityLimiter
11-
from ._core import enable_ki_protection, disable_ki_protection, RunVar
11+
from ._core import enable_ki_protection, disable_ki_protection, RunVar, _entry_queue
1212

1313
__all__ = [
1414
"run_sync_in_thread",
@@ -332,6 +332,9 @@ def _run_fn_as_system_task(cb, fn, *args, trio_token=None):
332332
raised exceptions canceling all tasks should be noted.
333333
"""
334334

335+
if trio_token and not isinstance(trio_token, _entry_queue.TrioToken):
336+
raise RuntimeError("Passed kwarg trio_token is not of type TrioToken")
337+
335338
if not trio_token:
336339
try:
337340
trio_token = TOKEN_LOCAL.token
@@ -340,6 +343,10 @@ def _run_fn_as_system_task(cb, fn, *args, trio_token=None):
340343
"this thread wasn't created by Trio, pass kwarg trio_token=..."
341344
)
342345

346+
# TODO: This is only necessary for compatibility with BlockingTrioPortal.
347+
# once that is deprecated, this check should no longer be necessary because
348+
# thread local storage (or the absence of) is sufficient to check if trio
349+
# is running in a thread or not.
343350
try:
344351
trio.hazmat.current_task()
345352
except RuntimeError:

trio/tests/test_threads.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,3 +515,40 @@ async def test_from_thread_no_token():
515515

516516
with pytest.raises(RuntimeError):
517517
run_sync(_core.current_time)
518+
519+
520+
def test_run_fn_as_system_task_catched_badly_typed_token():
521+
with pytest.raises(RuntimeError):
522+
run_sync(_core.current_time, trio_token="Not TrioTokentype")
523+
524+
525+
async def test_do_in_trio_thread_from_trio_thread():
526+
# This check specifically confirms that a RuntimeError will be raised if
527+
# the old BlockingTrIoPortal API calls into a trio loop while already
528+
# running inside of one.
529+
portal = BlockingTrioPortal()
530+
531+
with pytest.raises(RuntimeError):
532+
portal.run_sync(lambda: None) # pragma: no branch
533+
534+
async def foo(): # pragma: no cover
535+
pass
536+
537+
with pytest.raises(RuntimeError):
538+
portal.run(foo)
539+
540+
541+
async def test_BlockingTrioPortal_with_explicit_TrioToken():
542+
# This tests the deprecated BlockingTrioPortal with a token passed in to
543+
# confirm that both methods of making a portal are supported by
544+
# trio.from_thread
545+
token = _core.current_trio_token()
546+
547+
def worker_thread(token):
548+
with pytest.raises(RuntimeError):
549+
BlockingTrioPortal()
550+
portal = BlockingTrioPortal(token)
551+
return portal.run_sync(threading.current_thread)
552+
553+
t = await run_sync_in_thread(worker_thread, token)
554+
assert t == threading.current_thread()

0 commit comments

Comments
 (0)