Skip to content

Commit 5c5ee08

Browse files
committed
Refactor CABI: reimplement async in terms of Thread abstraction
1 parent 55c3ec6 commit 5c5ee08

File tree

2 files changed

+337
-269
lines changed

2 files changed

+337
-269
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 110 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from __future__ import annotations
88
from dataclasses import dataclass
99
from functools import partial
10-
from typing import Any, Optional, Callable, Awaitable, TypeVar, Generic, Literal
10+
from typing import Any, Optional, Callable, TypeVar, Generic, Literal
1111
from enum import Enum, IntEnum
1212
import math
1313
import struct
@@ -214,18 +214,18 @@ class CanonicalOptions(LiftLowerOptions):
214214

215215
### Runtime State
216216

217-
scheduler = asyncio.Lock()
218-
219217
#### Component Instance State
220218

221219
class ComponentInstance:
220+
store: Store
222221
table: Table
223222
may_leave: bool
224223
no_backpressure: asyncio.Event
225224
num_backpressure_waiters: int
226225
lock: asyncio.Lock
227226

228-
def __init__(self):
227+
def __init__(self, store):
228+
self.store = store
229229
self.table = Table()
230230
self.may_leave = True
231231
self.no_backpressure = asyncio.Event()
@@ -457,10 +457,6 @@ class Cancelled(IntEnum):
457457
FALSE = 0
458458
TRUE = 1
459459

460-
OnStart = Callable[[], list[any]]
461-
OnResolve = Callable[[Optional[list[any]]], None]
462-
OnBlock = Callable[[Awaitable], Awaitable[Cancelled]]
463-
464460
class Task:
465461
class State(Enum):
466462
INITIAL = 1
@@ -473,24 +469,23 @@ class State(Enum):
473469
inst: ComponentInstance
474470
ft: FuncType
475471
supertask: Optional[Task]
476-
on_resolve: OnResolve
477-
on_block: OnBlock
472+
thread: Thread
473+
on_resolve: Callable[[Optional[list[any]]], None]
478474
num_borrows: int
479475
context: ContextLocalStorage
480476

481-
def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
477+
def __init__(self, opts, inst, ft, supertask, thread, on_resolve):
482478
self.state = Task.State.INITIAL
483479
self.opts = opts
484480
self.inst = inst
485481
self.ft = ft
486482
self.supertask = supertask
483+
self.thread = thread
487484
self.on_resolve = on_resolve
488-
self.on_block = on_block
489485
self.num_borrows = 0
490486
self.context = ContextLocalStorage()
491487

492488
async def enter(self):
493-
assert(scheduler.locked())
494489
self.trap_if_on_the_stack(self.inst)
495490
if self.opts.sync or self.opts.callback:
496491
if self.inst.lock.locked():
@@ -530,15 +525,15 @@ async def wait_on(self, awaitable, cancellable = False, for_callback = False) ->
530525
if for_callback:
531526
self.inst.lock.release()
532527

533-
cancelled = await self.on_block(f)
528+
cancelled = await self.thread.suspend(f)
534529
if cancelled and not cancellable:
535-
assert(await self.on_block(f) == Cancelled.FALSE)
530+
assert(await self.thread.suspend(f) == Cancelled.FALSE)
536531

537532
if for_callback:
538533
acquired = asyncio.create_task(self.inst.lock.acquire())
539-
cancelled |= await self.on_block(acquired)
534+
cancelled |= await self.thread.suspend(acquired)
540535
if cancelled:
541-
assert(self.on_block(acquired) == Cancelled.FALSE)
536+
assert(self.thread.suspend(acquired) == Cancelled.FALSE)
542537

543538
if cancelled:
544539
assert(self.state == Task.State.INITIAL)
@@ -551,16 +546,6 @@ async def wait_on(self, awaitable, cancellable = False, for_callback = False) ->
551546
else:
552547
return Cancelled.FALSE
553548

554-
async def call_sync(self, callee, on_start, on_return):
555-
async def sync_on_block(awaitable):
556-
if await self.on_block(awaitable) == Cancelled.TRUE:
557-
assert(self.state == Task.State.INITIAL)
558-
self.state = Task.State.PENDING_CANCEL
559-
assert(await self.on_block(awaitable) == Cancelled.FALSE)
560-
return Cancelled.FALSE
561-
562-
await callee(self, on_start, on_return, sync_on_block)
563-
564549
async def wait_for_event(self, waitable_set, cancellable, for_callback) -> EventTuple:
565550
if self.state == Task.State.PENDING_CANCEL and cancellable:
566551
self.state = Task.State.CANCEL_DELIVERED
@@ -626,18 +611,16 @@ class State(IntEnum):
626611
CANCELLED_BEFORE_RETURNED = 4
627612

628613
state: State
629-
task: Task
614+
thread: Optional[Thread]
630615
lenders: Optional[list[ResourceHandle]]
631-
request_cancel_begin: asyncio.Future
632-
request_cancel_end: asyncio.Future
616+
cancellation_requested: bool
633617

634-
def __init__(self, task):
618+
def __init__(self):
635619
Waitable.__init__(self)
636620
self.state = Subtask.State.STARTING
637-
self.task = task
621+
self.thread = None
638622
self.lenders = []
639-
self.request_cancel_begin = asyncio.Future()
640-
self.request_cancel_end = asyncio.Future()
623+
self.cancellation_requested = False
641624

642625
def resolved(self):
643626
match self.state:
@@ -649,44 +632,6 @@ def resolved(self):
649632
Subtask.State.CANCELLED_BEFORE_RETURNED):
650633
return True
651634

652-
async def request_cancel(self):
653-
assert(not self.cancellation_requested() and not self.resolved())
654-
self.request_cancel_begin.set_result(None)
655-
await self.request_cancel_end
656-
657-
def cancellation_requested(self):
658-
return self.request_cancel_begin.done()
659-
660-
async def call_async(self, callee, on_start, on_resolve):
661-
async def do_call():
662-
await callee(self.task, on_start, on_resolve, async_on_block)
663-
relinquish_control()
664-
665-
async def async_on_block(awaitable):
666-
relinquish_control()
667-
if not self.request_cancel_end.done():
668-
await asyncio.wait([awaitable, self.request_cancel_begin],
669-
return_when = asyncio.FIRST_COMPLETED)
670-
if self.request_cancel_begin.done():
671-
return Cancelled.TRUE
672-
else:
673-
await awaitable
674-
assert(awaitable.done())
675-
await scheduler.acquire()
676-
return Cancelled.FALSE
677-
678-
def relinquish_control():
679-
if not ret.done():
680-
ret.set_result(None)
681-
elif self.request_cancel_begin.done() and not self.request_cancel_end.done():
682-
self.request_cancel_end.set_result(None)
683-
else:
684-
scheduler.release()
685-
686-
ret = asyncio.Future()
687-
asyncio.create_task(do_call())
688-
await ret
689-
690635
def add_lender(self, lending_handle):
691636
assert(not self.resolve_delivered() and not self.resolved())
692637
lending_handle.num_lends += 1
@@ -927,6 +872,84 @@ def drop(self):
927872
trap_if(not self.done)
928873
FutureEnd.drop(self)
929874

875+
#### Thread State
876+
877+
class Thread:
878+
store: Store
879+
future: Optional[asyncio.Future]
880+
on_resume: Optional[asyncio.Future]
881+
on_suspend_or_exit: Optional[asyncio.Future]
882+
returned: bool
883+
884+
def __init__(self, store, lifted_func, caller, on_start, on_resolve):
885+
self.store = store
886+
self.future = None
887+
self.on_resume = asyncio.Future()
888+
self.on_suspend_or_exit = None
889+
self.returned = False
890+
async def async_impl():
891+
assert(await self.on_resume == Cancelled.FALSE)
892+
self.on_resume = None
893+
await lifted_func(caller, self, on_start, on_resolve)
894+
self.on_suspend_or_exit.set_result(None)
895+
self.returned = True
896+
asyncio.create_task(async_impl())
897+
898+
async def resume(self, cancelled = Cancelled.FALSE):
899+
if self.future:
900+
assert(cancelled or self.future.done())
901+
self.future = None
902+
self.store.waiting.remove(self)
903+
self.on_resume.set_result(cancelled)
904+
assert(not self.on_suspend_or_exit)
905+
self.on_suspend_or_exit = asyncio.Future()
906+
await self.on_suspend_or_exit
907+
self.on_suspend_or_exit = None
908+
if self.future:
909+
self.store.waiting.append(self)
910+
911+
async def suspend(self, future) -> Cancelled:
912+
assert(not self.future)
913+
self.future = future
914+
self.on_suspend_or_exit.set_result(None)
915+
self.on_suspend_or_exit = None
916+
assert(not self.on_resume)
917+
self.on_resume = asyncio.Future()
918+
cancelled = await self.on_resume
919+
self.on_resume = None
920+
return cancelled
921+
922+
#### Store State / Embedding API
923+
924+
class Store:
925+
loop: asyncio.AbstractEventLoop
926+
waiting: list[Thread]
927+
928+
def __init__(self):
929+
self.loop = asyncio.new_event_loop()
930+
self.waiting = []
931+
932+
ExportCall = Thread
933+
934+
def start_export_call(self, lifted_func, on_start, on_resolve) -> ExportCall:
935+
async def async_impl():
936+
caller = None
937+
thread = Thread(self, lifted_func, caller, on_start, on_resolve)
938+
await thread.resume()
939+
return thread
940+
return self.loop.run_until_complete(async_impl())
941+
942+
def tick(self):
943+
if not DETERMINISTIC_PROFILE:
944+
random.shuffle(self.waiting)
945+
for thread in self.waiting:
946+
if thread.future.done():
947+
self.loop.run_until_complete(thread.resume())
948+
return
949+
950+
def export_call_finished(self, export_call: ExportCall):
951+
return export_call.returned
952+
930953
### Despecialization
931954

932955
def despecialize(t):
@@ -1882,8 +1905,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
18821905

18831906
### `canon lift`
18841907

1885-
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_block):
1886-
task = Task(opts, inst, ft, caller, on_resolve, on_block)
1908+
async def canon_lift(opts, inst, ft, callee, caller, thread, on_start, on_resolve):
1909+
task = Task(opts, inst, ft, caller, thread, on_resolve)
18871910
if await task.enter() == Cancelled.TRUE:
18881911
task.cancel()
18891912
task.exit()
@@ -1958,7 +1981,7 @@ async def call_and_trap_on_throw(callee, task, args):
19581981

19591982
async def canon_lower(opts, ft, callee, task, flat_args):
19601983
trap_if(not task.inst.may_leave)
1961-
subtask = Subtask(task)
1984+
subtask = Subtask()
19621985

19631986
cx = LiftLowerContext(opts, task.inst, subtask)
19641987
flat_ft = flatten_functype(opts, ft, 'lower')
@@ -1984,7 +2007,7 @@ def on_start():
19842007
def on_resolve(result):
19852008
on_progress()
19862009
if result is None:
1987-
assert(subtask.cancellation_requested())
2010+
assert(subtask.cancellation_requested)
19882011
if subtask.state == Subtask.State.STARTING:
19892012
subtask.state = Subtask.State.CANCELLED_BEFORE_STARTED
19902013
else:
@@ -1996,13 +2019,19 @@ def on_resolve(result):
19962019
nonlocal flat_results
19972020
flat_results = lower_flat_values(cx, max_flat_results, result, ft.result_type(), flat_args)
19982021

2022+
subtask.thread = Thread(task.inst.store, callee, task, on_start, on_resolve)
2023+
await subtask.thread.resume()
2024+
19992025
if opts.sync:
2000-
await task.call_sync(callee, on_start, on_resolve)
2026+
if not subtask.resolved():
2027+
done = asyncio.Event()
2028+
def on_progress():
2029+
done.set()
2030+
await task.wait_on(done.wait())
20012031
assert(types_match_values(flat_ft.results, flat_results))
20022032
subtask.deliver_resolve()
20032033
return flat_results
20042034
else:
2005-
await subtask.call_async(callee, on_start, on_resolve)
20062035
if subtask.resolved():
20072036
assert(flat_results == [])
20082037
subtask.deliver_resolve()
@@ -2182,11 +2211,12 @@ async def canon_subtask_cancel(sync, task, i):
21822211
subtask = task.inst.table.get(i)
21832212
trap_if(not isinstance(subtask, Subtask))
21842213
trap_if(subtask.resolve_delivered())
2185-
trap_if(subtask.cancellation_requested())
2214+
trap_if(subtask.cancellation_requested)
21862215
if subtask.resolved():
21872216
assert(subtask.has_pending_event())
21882217
else:
2189-
await subtask.request_cancel()
2218+
subtask.cancellation_requested = True
2219+
await subtask.thread.resume(Cancelled.TRUE)
21902220
if sync:
21912221
while not subtask.resolved():
21922222
if subtask.has_pending_event():

0 commit comments

Comments
 (0)