Skip to content

Commit 78cffd0

Browse files
committed
Replace 'sync' with 'cancellable' in wait/poll/yield
Replace instance-suspension semantics of sync lower with instance-locking semantics in lift
1 parent 00f31f2 commit 78cffd0

File tree

3 files changed

+374
-137
lines changed

3 files changed

+374
-137
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 87 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -221,22 +221,17 @@ class CanonicalOptions(LiftLowerOptions):
221221
class ComponentInstance:
222222
table: Table
223223
may_leave: bool
224-
backpressure: bool
225-
calling_sync_export: bool
226-
calling_sync_import: bool
227-
pending_tasks: list[tuple[Task, asyncio.Future]]
228-
starting_pending_task: bool
229-
async_waiting_tasks: asyncio.Condition
224+
no_backpressure: asyncio.Event
225+
num_backpressure_waiters: int
226+
lock: asyncio.Lock
230227

231228
def __init__(self):
232229
self.table = Table()
233230
self.may_leave = True
234-
self.backpressure = False
235-
self.calling_sync_export = False
236-
self.calling_sync_import = False
237-
self.pending_tasks = []
238-
self.starting_pending_task = False
239-
self.async_waiting_tasks = asyncio.Condition(scheduler)
231+
self.no_backpressure = asyncio.Event()
232+
self.no_backpressure.set()
233+
self.num_backpressure_waiters = 0
234+
self.lock = asyncio.Lock()
240235

241236
#### Table State
242237

@@ -497,67 +492,64 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
497492
async def enter(self):
498493
assert(scheduler.locked())
499494
self.trap_if_on_the_stack(self.inst)
500-
if not self.may_enter(self) or self.inst.pending_tasks:
501-
f = asyncio.Future()
502-
self.inst.pending_tasks.append((self, f))
503-
if await self.on_block(f) == Cancelled.TRUE:
504-
[i] = [i for i,(t,_) in enumerate(self.inst.pending_tasks) if t == self]
505-
self.inst.pending_tasks.pop(i)
506-
self.on_resolve(None)
507-
return Cancelled.FALSE
508-
assert(self.may_enter(self) and self.inst.starting_pending_task)
509-
self.inst.starting_pending_task = False
510-
if self.opts.sync:
511-
self.inst.calling_sync_export = True
512-
return True
495+
if self.opts.sync or self.opts.callback:
496+
if self.inst.lock.locked():
497+
acquired = asyncio.create_task(self.inst.lock.acquire())
498+
cancelled = await self.wait_on(acquired, cancellable = True, for_callback = False)
499+
if cancelled:
500+
if acquired.done():
501+
self.inst.lock.release()
502+
else:
503+
acquired.cancel()
504+
return Cancelled.TRUE
505+
else:
506+
await self.inst.lock.acquire()
507+
if not self.inst.no_backpressure.is_set() or self.inst.num_backpressure_waiters > 0:
508+
while True:
509+
self.inst.num_backpressure_waiters += 1
510+
maybe_go = self.inst.no_backpressure.wait()
511+
cancelled = await self.wait_on(maybe_go, cancellable = True, for_callback = False)
512+
self.inst.num_backpressure_waiters -= 1
513+
if cancelled:
514+
return Cancelled.TRUE
515+
if self.inst.no_backpressure.is_set():
516+
break
517+
return Cancelled.FALSE
513518

514519
def trap_if_on_the_stack(self, inst):
515520
c = self.supertask
516521
while c is not None:
517522
trap_if(c.inst is inst)
518523
c = c.supertask
519524

520-
def may_enter(self, pending_task):
521-
return not self.inst.backpressure and \
522-
not self.inst.calling_sync_import and \
523-
not (self.inst.calling_sync_export and pending_task.opts.sync)
524-
525-
def maybe_start_pending_task(self):
526-
if self.inst.starting_pending_task:
527-
return
528-
for i,(pending_task,pending_future) in enumerate(self.inst.pending_tasks):
529-
if self.may_enter(pending_task):
530-
self.inst.pending_tasks.pop(i)
531-
self.inst.starting_pending_task = True
532-
pending_future.set_result(None)
533-
return
525+
async def wait_on(self, awaitable, cancellable = False, for_callback = False) -> Cancelled:
526+
f = asyncio.ensure_future(awaitable)
527+
if f.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
528+
return Cancelled.FALSE
534529

535-
async def wait_on(self, awaitable, sync, cancellable = False) -> bool:
536-
if sync:
537-
assert(not self.inst.calling_sync_import)
538-
self.inst.calling_sync_import = True
539-
else:
540-
self.maybe_start_pending_task()
530+
if for_callback:
531+
self.inst.lock.release()
541532

542-
awaitable = asyncio.ensure_future(awaitable)
543-
if awaitable.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
544-
cancelled = Cancelled.FALSE
545-
else:
546-
cancelled = await self.on_block(awaitable)
547-
if cancelled and not cancellable:
548-
assert(self.state == Task.State.INITIAL)
549-
self.state = Task.State.PENDING_CANCEL
550-
cancelled = await self.on_block(awaitable)
551-
assert(not cancelled)
533+
cancelled = await self.on_block(f)
534+
if cancelled and not cancellable:
535+
assert(await self.on_block(f) == Cancelled.FALSE)
552536

553-
if sync:
554-
self.inst.calling_sync_import = False
555-
self.inst.async_waiting_tasks.notify_all()
556-
else:
557-
while self.inst.calling_sync_import:
558-
await self.inst.async_waiting_tasks.wait()
537+
if for_callback:
538+
acquired = asyncio.create_task(self.inst.lock.acquire())
539+
cancelled |= await self.on_block(acquired)
540+
if cancelled:
541+
assert(self.on_block(acquired) == Cancelled.FALSE)
559542

560-
return cancelled
543+
if cancelled:
544+
assert(self.state == Task.State.INITIAL)
545+
if not cancellable:
546+
self.state = Task.State.PENDING_CANCEL
547+
return Cancelled.FALSE
548+
else:
549+
self.state = Task.State.CANCEL_DELIVERED
550+
return Cancelled.TRUE
551+
else:
552+
return Cancelled.FALSE
561553

562554
async def call_sync(self, callee, on_start, on_return):
563555
async def sync_on_block(awaitable):
@@ -567,42 +559,36 @@ async def sync_on_block(awaitable):
567559
assert(await self.on_block(awaitable) == Cancelled.FALSE)
568560
return Cancelled.FALSE
569561

570-
assert(not self.inst.calling_sync_import)
571-
self.inst.calling_sync_import = True
572562
await callee(self, on_start, on_return, sync_on_block)
573-
self.inst.calling_sync_import = False
574-
self.inst.async_waiting_tasks.notify_all()
575563

576-
async def wait_for_event(self, waitable_set, sync) -> EventTuple:
577-
if self.state == Task.State.PENDING_CANCEL:
564+
async def wait_for_event(self, waitable_set, cancellable, for_callback) -> EventTuple:
565+
if self.state == Task.State.PENDING_CANCEL and cancellable:
578566
self.state = Task.State.CANCEL_DELIVERED
579567
return (EventCode.TASK_CANCELLED, 0, 0)
580568
else:
581569
waitable_set.num_waiting += 1
582570
e = None
583571
while not e:
584572
maybe_event = waitable_set.maybe_has_pending_event.wait()
585-
if await self.wait_on(maybe_event, sync, cancellable = True):
586-
assert(self.state == Task.State.INITIAL)
587-
self.state = Task.State.CANCEL_DELIVERED
573+
if await self.wait_on(maybe_event, cancellable, for_callback) == Cancelled.TRUE:
588574
return (EventCode.TASK_CANCELLED, 0, 0)
589575
e = waitable_set.poll()
590576
waitable_set.num_waiting -= 1
591577
return e
592578

593-
async def yield_(self, sync) -> EventTuple:
594-
if self.state == Task.State.PENDING_CANCEL:
579+
async def yield_(self, cancellable, for_callback) -> EventTuple:
580+
if self.state == Task.State.PENDING_CANCEL and cancellable:
595581
self.state = Task.State.CANCEL_DELIVERED
596582
return (EventCode.TASK_CANCELLED, 0, 0)
597-
elif await self.wait_on(asyncio.sleep(0), sync, cancellable = True):
598-
assert(self.state == Task.State.INITIAL)
599-
self.state = Task.State.CANCEL_DELIVERED
583+
elif await self.wait_on(asyncio.sleep(0), cancellable, for_callback) == Cancelled.TRUE:
600584
return (EventCode.TASK_CANCELLED, 0, 0)
601585
else:
602586
return (EventCode.NONE, 0, 0)
603587

604-
async def poll_for_event(self, waitable_set, sync) -> Optional[EventTuple]:
605-
event_code,_,_ = e = await self.yield_(sync)
588+
async def poll_for_event(self, waitable_set, cancellable, for_callback) -> Optional[EventTuple]:
589+
waitable_set.num_waiting += 1
590+
event_code,_,_ = e = await self.yield_(cancellable, for_callback)
591+
waitable_set.num_waiting -= 1
606592
if event_code == EventCode.TASK_CANCELLED:
607593
return e
608594
elif (e := waitable_set.poll()):
@@ -624,13 +610,10 @@ def cancel(self):
624610
self.state = Task.State.RESOLVED
625611

626612
def exit(self):
627-
assert(scheduler.locked())
628613
trap_if(self.state != Task.State.RESOLVED)
629614
assert(self.num_borrows == 0)
630-
if self.opts.sync:
631-
assert(self.inst.calling_sync_export)
632-
self.inst.calling_sync_export = False
633-
self.maybe_start_pending_task()
615+
if self.opts.sync or self.opts.callback:
616+
self.inst.lock.release()
634617

635618
#### Subtask State
636619

@@ -1932,7 +1915,9 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19321915

19331916
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_block):
19341917
task = Task(opts, inst, ft, caller, on_resolve, on_block)
1935-
if not await task.enter():
1918+
if await task.enter() == Cancelled.TRUE:
1919+
task.cancel()
1920+
task.exit()
19361921
return
19371922

19381923
cx = LiftLowerContext(opts, inst, task)
@@ -1967,15 +1952,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
19671952
task.exit()
19681953
return
19691954
case CallbackCode.YIELD:
1970-
e = await task.yield_(sync = False)
1955+
e = await task.yield_(cancellable = True, for_callback = True)
19711956
case CallbackCode.WAIT:
19721957
s = task.inst.table.get(si)
19731958
trap_if(not isinstance(s, WaitableSet))
1974-
e = await task.wait_for_event(s, sync = False)
1959+
e = await task.wait_for_event(s, cancellable = True, for_callback = True)
19751960
case CallbackCode.POLL:
19761961
s = task.inst.table.get(si)
19771962
trap_if(not isinstance(s, WaitableSet))
1978-
e = await task.poll_for_event(s, sync = False)
1963+
e = await task.poll_for_event(s, cancellable = True, for_callback = True)
19791964
event_code, p1, p2 = e
19801965
[packed] = await call_and_trap_on_throw(opts.callback, task, [event_code, p1, p2])
19811966

@@ -2114,8 +2099,11 @@ async def canon_context_set(t, i, task, v):
21142099
### 🔀 `canon backpressure.set`
21152100

21162101
async def canon_backpressure_set(task, flat_args):
2117-
trap_if(task.opts.sync)
2118-
task.inst.backpressure = bool(flat_args[0])
2102+
assert(len(flat_args) == 1)
2103+
if flat_args[0] == 0:
2104+
task.inst.no_backpressure.set()
2105+
else:
2106+
task.inst.no_backpressure.clear()
21192107
return []
21202108

21212109
### 🔀 `canon task.return`
@@ -2140,9 +2128,9 @@ async def canon_task_cancel(task):
21402128

21412129
### 🔀 `canon yield`
21422130

2143-
async def canon_yield(sync, task):
2131+
async def canon_yield(cancellable, task):
21442132
trap_if(not task.inst.may_leave)
2145-
event_code,_,_ = await task.yield_(sync)
2133+
event_code,_,_ = await task.yield_(cancellable, for_callback = False)
21462134
match event_code:
21472135
case EventCode.NONE:
21482136
return [0]
@@ -2157,11 +2145,11 @@ async def canon_waitable_set_new(task):
21572145

21582146
### 🔀 `canon waitable-set.wait`
21592147

2160-
async def canon_waitable_set_wait(sync, mem, task, si, ptr):
2148+
async def canon_waitable_set_wait(cancellable, mem, task, si, ptr):
21612149
trap_if(not task.inst.may_leave)
21622150
s = task.inst.table.get(si)
21632151
trap_if(not isinstance(s, WaitableSet))
2164-
e = await task.wait_for_event(s, sync)
2152+
e = await task.wait_for_event(s, cancellable, for_callback = False)
21652153
return unpack_event(mem, task, ptr, e)
21662154

21672155
def unpack_event(mem, task, ptr, e: EventTuple):
@@ -2173,11 +2161,11 @@ def unpack_event(mem, task, ptr, e: EventTuple):
21732161

21742162
### 🔀 `canon waitable-set.poll`
21752163

2176-
async def canon_waitable_set_poll(sync, mem, task, si, ptr):
2164+
async def canon_waitable_set_poll(cancellable, mem, task, si, ptr):
21772165
trap_if(not task.inst.may_leave)
21782166
s = task.inst.table.get(si)
21792167
trap_if(not isinstance(s, WaitableSet))
2180-
e = await task.poll_for_event(s, sync)
2168+
e = await task.poll_for_event(s, cancellable, for_callback = False)
21812169
return unpack_event(mem, task, ptr, e)
21822170

21832171
### 🔀 `canon waitable-set.drop`
@@ -2220,7 +2208,7 @@ async def canon_subtask_cancel(sync, task, i):
22202208
while not subtask.resolved():
22212209
if subtask.has_pending_event():
22222210
_ = subtask.get_event()
2223-
await task.wait_on(subtask.wait_for_pending_event(), sync = True)
2211+
await task.wait_on(subtask.wait_for_pending_event())
22242212
else:
22252213
if not subtask.resolved():
22262214
return [BLOCKED]
@@ -2296,7 +2284,7 @@ def on_copy_done(result):
22962284
e.copy(task.inst, buffer, on_copy, on_copy_done)
22972285

22982286
if opts.sync and not e.has_pending_event():
2299-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2287+
await task.wait_on(e.wait_for_pending_event())
23002288

23012289
if e.has_pending_event():
23022290
code,index,payload = e.get_event()
@@ -2342,7 +2330,7 @@ def on_copy_done(result):
23422330
e.copy(task.inst, buffer, on_copy_done)
23432331

23442332
if opts.sync and not e.has_pending_event():
2345-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2333+
await task.wait_on(e.wait_for_pending_event())
23462334

23472335
if e.has_pending_event():
23482336
code,index,payload = e.get_event()
@@ -2375,7 +2363,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
23752363
e.shared.cancel()
23762364
if not e.has_pending_event():
23772365
if sync:
2378-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2366+
await task.wait_on(e.wait_for_pending_event())
23792367
else:
23802368
return [BLOCKED]
23812369
code,index,payload = e.get_event()

0 commit comments

Comments
 (0)