Skip to content

Commit 718df5c

Browse files
committed
Clean up handling of Handles
1 parent eff5da8 commit 718df5c

File tree

4 files changed

+145
-212
lines changed

4 files changed

+145
-212
lines changed

trio_asyncio/_async.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import trio
2+
import asyncio
23

34
from ._base import BaseTrioEventLoop
4-
from ._handles import Handle
55

66

77
class TrioEventLoop(BaseTrioEventLoop):
@@ -69,7 +69,7 @@ def stop_me():
6969
if self._stopped.is_set():
7070
waiter.set()
7171
else:
72-
self._queue_handle(Handle(stop_me, (), self, context=None, is_sync=True))
72+
self._queue_handle(asyncio.Handle(stop_me, (), self))
7373
return waiter
7474

7575
def _close(self):

trio_asyncio/_base.py

Lines changed: 37 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import warnings
1010
import concurrent.futures
1111

12-
from ._handles import Handle, TimerHandle
12+
from ._handles import ScopedHandle, AsyncHandle
1313
from ._util import run_aio_future, run_aio_generator
1414
from ._deprecate import deprecated, deprecated_alias
1515
from . import _util
@@ -39,28 +39,6 @@ def clear(self):
3939
pass
4040

4141

42-
def _h_raise(handle, exc):
43-
"""
44-
Convince a handle to raise an error.
45-
46-
trio-asyncio enhanced handles have a method to do this
47-
but asyncio's native handles don't. Thus we need to fudge things.
48-
"""
49-
if hasattr(handle, '_raise'):
50-
handle._raise(exc)
51-
return
52-
53-
def _raise(exc):
54-
raise exc
55-
56-
cb, handle._callback = handle._callback, _raise
57-
ar, handle._args = handle._args, (exc,)
58-
try:
59-
handle._run()
60-
finally:
61-
handle._callback, handle._args = cb, ar
62-
63-
6442
class _TrioSelector(_BaseSelectorImpl):
6543
"""A selector that hooks into a ``TrioEventLoop``.
6644
@@ -241,25 +219,6 @@ async def run_aio_coroutine(self, coro):
241219
finally:
242220
sniffio.current_async_library_cvar.reset(t)
243221

244-
async def __run_trio(self, h):
245-
"""Helper for copying the result of a Trio task to an asyncio future"""
246-
f, proc, *args = h._args
247-
if f.cancelled(): # pragma: no cover
248-
return
249-
try:
250-
with trio.CancelScope() as scope:
251-
h._scope = scope
252-
res = await proc(*args)
253-
if scope.cancelled_caught:
254-
f.cancel()
255-
return
256-
except BaseException as exc:
257-
if not f.cancelled(): # pragma: no branch
258-
f.set_exception(exc)
259-
else:
260-
if not f.cancelled(): # pragma: no branch
261-
f.set_result(res)
262-
263222
def trio_as_future(self, proc, *args):
264223
"""Start a new Trio task to run ``await proc(*args)`` asynchronously.
265224
Return an `asyncio.Future` that will resolve to the value or exception
@@ -292,14 +251,7 @@ def trio_as_future(self, proc, *args):
292251
an `asyncio.Future` which will resolve to the result of the call to *proc*
293252
"""
294253
f = asyncio.Future(loop=self)
295-
h = Handle(
296-
self.__run_trio, (
297-
f,
298-
proc,
299-
) + args, self, context=None, is_sync=None
300-
)
301-
self._queue_handle(h)
302-
f.add_done_callback(h._cb_future_cancel)
254+
self._queue_handle(AsyncHandle(proc, args, self, result_future=f))
303255
return f
304256

305257
def run_trio_task(self, proc, *args):
@@ -314,7 +266,7 @@ def run_trio_task(self, proc, *args):
314266
Returns:
315267
an `asyncio.Handle` which can be used to cancel the background task
316268
"""
317-
return self._queue_handle(Handle(proc, args, self, is_sync=False))
269+
return self._queue_handle(AsyncHandle(proc, args, self))
318270

319271
# Callback handling #
320272

@@ -331,7 +283,7 @@ def _queue_handle(self, handle):
331283
def _call_soon(self, *arks, **kwargs):
332284
raise RuntimeError("_call_soon() should not have been called")
333285

334-
def call_later(self, delay, callback, *args, context=None):
286+
def call_later(self, delay, callback, *args, **context):
335287
"""asyncio's timer-based delay
336288
337289
Note that the callback is a sync function.
@@ -342,36 +294,36 @@ def call_later(self, delay, callback, *args, context=None):
342294
"""
343295
self._check_callback(callback, 'call_later')
344296
assert delay >= 0, delay
345-
h = TimerHandle(delay + self.time(), callback, args, self, context=context, is_sync=True)
297+
h = asyncio.TimerHandle(delay + self.time(), callback, args, self, **context)
346298
self._queue_handle(h)
347299
return h
348300

349-
def call_at(self, when, callback, *args, context=None):
301+
def call_at(self, when, callback, *args, **context):
350302
"""asyncio's time-based delay
351303
352304
Note that the callback is a sync function.
353305
"""
354306
self._check_callback(callback, 'call_at')
355307
return self._queue_handle(
356-
TimerHandle(when, callback, args, self, context=context, is_sync=True)
308+
asyncio.TimerHandle(when, callback, args, self, **context)
357309
)
358310

359-
def call_soon(self, callback, *args, context=None):
311+
def call_soon(self, callback, *args, **context):
360312
"""asyncio's defer-to-mainloop callback executor.
361313
362314
Note that the callback is a sync function.
363315
"""
364316
self._check_callback(callback, 'call_soon')
365-
return self._queue_handle(Handle(callback, args, self, context=context, is_sync=True))
317+
return self._queue_handle(asyncio.Handle(callback, args, self, **context))
366318

367-
def call_soon_threadsafe(self, callback, *args, context=None):
319+
def call_soon_threadsafe(self, callback, *args, **context):
368320
"""asyncio's thread-safe defer-to-mainloop
369321
370322
Note that the callback is a sync function.
371323
"""
372324
self._check_callback(callback, 'call_soon_threadsafe')
373325
self._check_closed()
374-
h = Handle(callback, args, self, context=context, is_sync=True)
326+
h = asyncio.Handle(callback, args, self, **context)
375327
self._token.run_sync_soon(self._q_send.send_nowait, h)
376328

377329
# drop all timers
@@ -471,7 +423,7 @@ async def synchronize(self):
471423
472424
"""
473425
w = trio.Event()
474-
self._queue_handle(Handle(w.set, (), self, is_sync=True))
426+
self._queue_handle(asyncio.Handle(w.set, (), self))
475427
await w.wait()
476428

477429
# Signal handling #
@@ -488,7 +440,7 @@ def add_signal_handler(self, sig, callback, *args):
488440
self._check_signal(sig)
489441
if sig == signal.SIGKILL:
490442
raise RuntimeError("SIGKILL cannot be caught")
491-
h = Handle(callback, args, self, context=None, is_sync=True)
443+
h = asyncio.Handle(callback, args, self)
492444
assert sig not in self._signal_handlers, \
493445
"Signal %d is already being caught" % (sig,)
494446
self._orig_signals[sig] = signal.signal(sig, self._handle_sig)
@@ -528,7 +480,7 @@ def add_reader(self, fd, callback, *args):
528480

529481
def _add_reader(self, fd, callback, *args):
530482
self._check_closed()
531-
handle = Handle(callback, args, self, context=None, is_sync=True)
483+
handle = ScopedHandle(callback, args, self)
532484
reader = self._set_read_handle(fd, handle)
533485
if reader is not None:
534486
reader.cancel()
@@ -547,20 +499,17 @@ def _set_read_handle(self, fd, handle):
547499
self._selector.modify(fd, mask | EVENT_READ, (handle, writer))
548500
return reader
549501

550-
async def _reader_loop(self, fd, handle, task_status=trio.TASK_STATUS_IGNORED):
551-
task_status.started()
552-
with trio.CancelScope() as scope:
553-
handle._scope = scope
502+
async def _reader_loop(self, fd, handle):
503+
with handle._scope:
554504
try:
555-
while not handle._cancelled: # pragma: no branch
505+
while True:
556506
await _wait_readable(fd)
557-
handle._call_sync()
507+
if handle._cancelled:
508+
break
509+
handle._run()
558510
await self.synchronize()
559511
except Exception as exc:
560-
_h_raise(handle, exc)
561-
return
562-
finally:
563-
handle._scope = None
512+
handle._raise(exc)
564513

565514
# writing to a file descriptor
566515

@@ -583,7 +532,7 @@ def add_writer(self, fd, callback, *args):
583532

584533
def _add_writer(self, fd, callback, *args):
585534
self._check_closed()
586-
handle = Handle(callback, args, self, context=None, is_sync=True)
535+
handle = ScopedHandle(callback, args, self)
587536
writer = self._set_write_handle(fd, handle)
588537
if writer is not None:
589538
writer.cancel()
@@ -601,20 +550,17 @@ def _set_write_handle(self, fd, handle):
601550
self._selector.modify(fd, mask | EVENT_WRITE, (reader, handle))
602551
return writer
603552

604-
async def _writer_loop(self, fd, handle, task_status=trio.TASK_STATUS_IGNORED):
605-
with trio.CancelScope() as scope:
606-
handle._scope = scope
607-
task_status.started()
553+
async def _writer_loop(self, fd, handle):
554+
with handle._scope:
608555
try:
609-
while not handle._cancelled: # pragma: no branch
556+
while True:
610557
await _wait_writable(fd)
611-
handle._call_sync()
558+
if handle._cancelled:
559+
break
560+
handle._run()
612561
await self.synchronize()
613562
except Exception as exc:
614-
_h_raise(handle, exc)
615-
return
616-
finally:
617-
handle._scope = None
563+
handle._raise(exc)
618564

619565
def autoclose(self, fd):
620566
"""
@@ -717,7 +663,7 @@ async def _main_loop_one(self, no_wait=False):
717663
# so restart from the beginning.
718664
return
719665

720-
if isinstance(obj, TimerHandle):
666+
if isinstance(obj, asyncio.TimerHandle):
721667
# A TimerHandle is added to the list of timers.
722668
heapq.heappush(self._timers, obj)
723669
return
@@ -732,13 +678,17 @@ async def _main_loop_one(self, no_wait=False):
732678

733679
# Don't go through the expensive nursery dance
734680
# if this is a sync function.
735-
if getattr(obj, '_is_sync', True):
681+
if isinstance(obj, AsyncHandle):
682+
if hasattr(obj, '_context'):
683+
obj._context.run(self._nursery.start_soon, obj._run, name=obj._callback)
684+
else:
685+
self._nursery.start_soon(obj._run, name=obj._callback)
686+
await obj._started.wait()
687+
else:
736688
if hasattr(obj, '_context'):
737689
obj._context.run(obj._callback, *obj._args)
738690
else:
739691
obj._callback(*obj._args)
740-
else:
741-
await self._nursery.start(obj._call_async)
742692

743693
async def _main_loop_exit(self):
744694
"""Finalize the loop. It may not be re-entered."""

0 commit comments

Comments
 (0)