Skip to content

Commit 26d4b02

Browse files
committed
[trio.from_thread]
Rework thread local storage to use the canonical threading.local() accessor, update from_thread.* unit tests to better reflect use cases, updated no token error msg to give a specific reason for failure.
1 parent 7f36105 commit 26d4b02

File tree

2 files changed

+38
-36
lines changed

2 files changed

+38
-36
lines changed

trio/_threads.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
"BlockingTrioPortal",
1717
]
1818

19+
# Global due to Threading API, thread local storage for trio token
20+
TOKEN_LOCAL = threading.local()
21+
1922

2023
class BlockingTrioPortal:
2124
"""A portal that synchronous threads can reach through to run code in the
@@ -375,23 +378,31 @@ def do_release_then_return_result():
375378

376379
# This is the function that runs in the worker thread to do the actual
377380
# work and then schedule the call to report_back_in_trio_thread_fn
378-
def worker_thread_fn():
381+
# Since this is spawned in a new thread, the trio token needs to be passed
382+
# explicitly to it so it can inject it into thread local storage
383+
def worker_thread_fn(trio_token):
384+
TOKEN_LOCAL.token = trio_token
379385
result = outcome.capture(sync_fn, *args)
380386
try:
381387
token.run_sync_soon(report_back_in_trio_thread_fn, result)
382388
except trio.RunFinishedError:
383389
# The entire run finished, so our particular task is certainly
384390
# long gone -- it must have cancelled.
385391
pass
392+
finally:
393+
del TOKEN_LOCAL.token
386394

387395
await limiter.acquire_on_behalf_of(placeholder)
388396
try:
389397
# daemon=True because it might get left behind if we cancel, and in
390398
# this case shouldn't block process exit.
399+
current_trio_token = trio.hazmat.current_trio_token()
391400
thread = threading.Thread(
392-
target=worker_thread_fn, name=name, daemon=True
401+
target=worker_thread_fn,
402+
args=(current_trio_token,),
403+
name=name,
404+
daemon=True
393405
)
394-
setattr(thread, 'current_trio_token', trio.hazmat.current_trio_token())
395406
thread.start()
396407
except:
397408
limiter.release_on_behalf_of(placeholder)
@@ -412,11 +423,15 @@ def _run_fn_as_system_task(cb, fn, *args, trio_token=None):
412423
413424
Since this internally uses TrioToken.run_sync_soon, all warnings about
414425
raised exceptions canceling all tasks should be noted.
415-
416426
"""
427+
417428
if not trio_token:
418-
current_thread = threading.current_thread()
419-
trio_token = getattr(current_thread, 'current_trio_token')
429+
try:
430+
trio_token = TOKEN_LOCAL.token
431+
except AttributeError:
432+
raise RuntimeError(
433+
"this thread wasn't created by Trio, pass kwarg trio_token=..."
434+
)
420435

421436
try:
422437
trio.hazmat.current_task()

trio/tests/test_threads.py

Lines changed: 17 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,8 @@ def bad_start(self):
463463
async def test_trio_run_sync_in_thread_token():
464464
# Test that run_sync_in_thread automatically injects the current trio token
465465
# into a spawned thread
466-
467466
def thread_fn():
468-
current_thread = threading.current_thread()
469-
callee_token = getattr(current_thread, 'current_trio_token')
467+
callee_token = run_sync(_core.current_trio_token)
470468
return callee_token
471469

472470
caller_token = _core.current_trio_token()
@@ -478,25 +476,28 @@ async def test_trio_from_thread_run_sync():
478476
# Test that run_sync_in_thread correctly "hands off" the trio token to
479477
# trio.from_thread.run_sync()
480478
def thread_fn():
481-
start = run_sync(_core.current_time)
482-
end = run_sync(_core.current_time)
483-
return end - start
479+
trio_time = run_sync(_core.current_time)
480+
return trio_time
484481

485-
duration = await run_sync_in_thread(thread_fn)
486-
assert duration > 0
482+
trio_time = await run_sync_in_thread(thread_fn)
483+
assert isinstance(trio_time, float)
487484

488485

489486
async def test_trio_from_thread_run():
490487
# Test that run_sync_in_thread correctly "hands off" the trio token to
491488
# trio.from_thread.run()
489+
record = []
490+
491+
async def back_in_trio_fn():
492+
_core.current_time() # implicitly checks that we're in trio
493+
record.append("back in trio")
494+
492495
def thread_fn():
493-
start = time.perf_counter()
494-
run(sleep, 0.05)
495-
end = time.perf_counter()
496-
return end - start
496+
record.append("in thread")
497+
run(back_in_trio_fn)
497498

498-
duration = await run_sync_in_thread(thread_fn)
499-
assert duration > 0
499+
await run_sync_in_thread(thread_fn)
500+
assert record == ["in thread", "back in trio"]
500501

501502

502503
async def test_trio_from_thread_token():
@@ -523,23 +524,9 @@ def thread_fn(token):
523524
assert callee_token == caller_token
524525

525526

526-
async def test_trio_from_thread_both_run():
527-
# Test that trio.from_thread.run() and from_thread.run_sync() can run in
528-
# the same thread together
529-
530-
def thread_fn():
531-
start = run_sync(_core.current_time)
532-
run(sleep, 0.05)
533-
end = run_sync(_core.current_time)
534-
return end - start
535-
536-
duration = await run_sync_in_thread(thread_fn)
537-
assert duration > 0
538-
539-
540-
async def test_trio_from_thread_raw_call():
527+
async def test_from_thread_no_token():
541528
# Test that a "raw call" to trio.from_thread.run() fails because no token
542529
# has been provided
543530

544-
with pytest.raises(AttributeError):
531+
with pytest.raises(RuntimeError):
545532
run_sync(_core.current_time)

0 commit comments

Comments
 (0)