Skip to content

Commit 388a07f

Browse files
committed
Replace 'sync' with 'cancellable' in wait/poll/yield
Replace instance-suspension semantics of sync lower with instance-locking semantics in lift
1 parent 00f31f2 commit 388a07f

File tree

3 files changed

+378
-141
lines changed

3 files changed

+378
-141
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 91 additions & 103 deletions
Original file line numberDiff line numberDiff line change
@@ -221,22 +221,17 @@ class CanonicalOptions(LiftLowerOptions):
221221
class ComponentInstance:
222222
table: Table
223223
may_leave: bool
224-
backpressure: bool
225-
calling_sync_export: bool
226-
calling_sync_import: bool
227-
pending_tasks: list[tuple[Task, asyncio.Future]]
228-
starting_pending_task: bool
229-
async_waiting_tasks: asyncio.Condition
224+
no_backpressure: asyncio.Event
225+
num_backpressure_waiters: int
226+
lock: asyncio.Lock
230227

231228
def __init__(self):
232229
self.table = Table()
233230
self.may_leave = True
234-
self.backpressure = False
235-
self.calling_sync_export = False
236-
self.calling_sync_import = False
237-
self.pending_tasks = []
238-
self.starting_pending_task = False
239-
self.async_waiting_tasks = asyncio.Condition(scheduler)
231+
self.no_backpressure = asyncio.Event()
232+
self.no_backpressure.set()
233+
self.num_backpressure_waiters = 0
234+
self.lock = asyncio.Lock()
240235

241236
#### Table State
242237

@@ -464,7 +459,7 @@ class Cancelled(IntEnum):
464459

465460
OnStart = Callable[[], list[any]]
466461
OnResolve = Callable[[Optional[list[any]]], None]
467-
OnBlock = Callable[[Awaitable], Awaitable[Cancelled]]
462+
OnBlock = Callable[[asyncio.Future], Awaitable[Cancelled]]
468463

469464
class Task:
470465
class State(Enum):
@@ -494,70 +489,65 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
494489
self.num_borrows = 0
495490
self.context = ContextLocalStorage()
496491

497-
async def enter(self):
498-
assert(scheduler.locked())
499-
self.trap_if_on_the_stack(self.inst)
500-
if not self.may_enter(self) or self.inst.pending_tasks:
501-
f = asyncio.Future()
502-
self.inst.pending_tasks.append((self, f))
503-
if await self.on_block(f) == Cancelled.TRUE:
504-
[i] = [i for i,(t,_) in enumerate(self.inst.pending_tasks) if t == self]
505-
self.inst.pending_tasks.pop(i)
506-
self.on_resolve(None)
507-
return Cancelled.FALSE
508-
assert(self.may_enter(self) and self.inst.starting_pending_task)
509-
self.inst.starting_pending_task = False
510-
if self.opts.sync:
511-
self.inst.calling_sync_export = True
512-
return True
513-
514492
def trap_if_on_the_stack(self, inst):
515493
c = self.supertask
516494
while c is not None:
517495
trap_if(c.inst is inst)
518496
c = c.supertask
519497

520-
def may_enter(self, pending_task):
521-
return not self.inst.backpressure and \
522-
not self.inst.calling_sync_import and \
523-
not (self.inst.calling_sync_export and pending_task.opts.sync)
524-
525-
def maybe_start_pending_task(self):
526-
if self.inst.starting_pending_task:
527-
return
528-
for i,(pending_task,pending_future) in enumerate(self.inst.pending_tasks):
529-
if self.may_enter(pending_task):
530-
self.inst.pending_tasks.pop(i)
531-
self.inst.starting_pending_task = True
532-
pending_future.set_result(None)
533-
return
498+
async def enter(self):
499+
if self.opts.sync or self.opts.callback:
500+
if self.inst.lock.locked():
501+
acquired = asyncio.create_task(self.inst.lock.acquire())
502+
cancelled = await self.block_on(acquired, cancellable = True)
503+
if cancelled:
504+
if acquired.done():
505+
self.inst.lock.release()
506+
else:
507+
acquired.cancel()
508+
return Cancelled.TRUE
509+
else:
510+
await self.inst.lock.acquire()
511+
if not self.inst.no_backpressure.is_set() or self.inst.num_backpressure_waiters > 0:
512+
while True:
513+
self.inst.num_backpressure_waiters += 1
514+
maybe_go = self.inst.no_backpressure.wait()
515+
cancelled = await self.block_on(maybe_go, cancellable = True, unlock = True)
516+
self.inst.num_backpressure_waiters -= 1
517+
if cancelled:
518+
return Cancelled.TRUE
519+
if self.inst.no_backpressure.is_set():
520+
break
521+
return Cancelled.FALSE
534522

535-
async def wait_on(self, awaitable, sync, cancellable = False) -> bool:
536-
if sync:
537-
assert(not self.inst.calling_sync_import)
538-
self.inst.calling_sync_import = True
539-
else:
540-
self.maybe_start_pending_task()
523+
async def block_on(self, awaitable, cancellable = False, unlock = False) -> Cancelled:
524+
f = asyncio.ensure_future(awaitable)
525+
if f.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
526+
return Cancelled.FALSE
541527

542-
awaitable = asyncio.ensure_future(awaitable)
543-
if awaitable.done() and not DETERMINISTIC_PROFILE and random.randint(0,1):
544-
cancelled = Cancelled.FALSE
545-
else:
546-
cancelled = await self.on_block(awaitable)
547-
if cancelled and not cancellable:
548-
assert(self.state == Task.State.INITIAL)
549-
self.state = Task.State.PENDING_CANCEL
550-
cancelled = await self.on_block(awaitable)
551-
assert(not cancelled)
528+
if unlock and (self.opts.sync or self.opts.callback):
529+
self.inst.lock.release()
552530

553-
if sync:
554-
self.inst.calling_sync_import = False
555-
self.inst.async_waiting_tasks.notify_all()
556-
else:
557-
while self.inst.calling_sync_import:
558-
await self.inst.async_waiting_tasks.wait()
531+
cancelled = await self.on_block(f)
532+
if cancelled and not cancellable:
533+
assert(await self.on_block(f) == Cancelled.FALSE)
559534

560-
return cancelled
535+
if unlock and (self.opts.sync or self.opts.callback):
536+
acquired = asyncio.create_task(self.inst.lock.acquire())
537+
cancelled |= await self.on_block(acquired)
538+
if cancelled:
539+
assert(self.on_block(acquired) == Cancelled.FALSE)
540+
541+
if cancelled:
542+
assert(self.state == Task.State.INITIAL)
543+
if not cancellable:
544+
self.state = Task.State.PENDING_CANCEL
545+
return Cancelled.FALSE
546+
else:
547+
self.state = Task.State.CANCEL_DELIVERED
548+
return Cancelled.TRUE
549+
else:
550+
return Cancelled.FALSE
561551

562552
async def call_sync(self, callee, on_start, on_return):
563553
async def sync_on_block(awaitable):
@@ -567,42 +557,36 @@ async def sync_on_block(awaitable):
567557
assert(await self.on_block(awaitable) == Cancelled.FALSE)
568558
return Cancelled.FALSE
569559

570-
assert(not self.inst.calling_sync_import)
571-
self.inst.calling_sync_import = True
572560
await callee(self, on_start, on_return, sync_on_block)
573-
self.inst.calling_sync_import = False
574-
self.inst.async_waiting_tasks.notify_all()
575561

576-
async def wait_for_event(self, waitable_set, sync) -> EventTuple:
577-
if self.state == Task.State.PENDING_CANCEL:
562+
async def wait_for_event(self, waitable_set, cancellable, unlock) -> EventTuple:
563+
if self.state == Task.State.PENDING_CANCEL and cancellable:
578564
self.state = Task.State.CANCEL_DELIVERED
579565
return (EventCode.TASK_CANCELLED, 0, 0)
580566
else:
581567
waitable_set.num_waiting += 1
582568
e = None
583569
while not e:
584570
maybe_event = waitable_set.maybe_has_pending_event.wait()
585-
if await self.wait_on(maybe_event, sync, cancellable = True):
586-
assert(self.state == Task.State.INITIAL)
587-
self.state = Task.State.CANCEL_DELIVERED
571+
if await self.block_on(maybe_event, cancellable, unlock) == Cancelled.TRUE:
588572
return (EventCode.TASK_CANCELLED, 0, 0)
589573
e = waitable_set.poll()
590574
waitable_set.num_waiting -= 1
591575
return e
592576

593-
async def yield_(self, sync) -> EventTuple:
594-
if self.state == Task.State.PENDING_CANCEL:
577+
async def yield_(self, cancellable, unlock) -> EventTuple:
578+
if self.state == Task.State.PENDING_CANCEL and cancellable:
595579
self.state = Task.State.CANCEL_DELIVERED
596580
return (EventCode.TASK_CANCELLED, 0, 0)
597-
elif await self.wait_on(asyncio.sleep(0), sync, cancellable = True):
598-
assert(self.state == Task.State.INITIAL)
599-
self.state = Task.State.CANCEL_DELIVERED
581+
elif await self.block_on(asyncio.sleep(0), cancellable, unlock) == Cancelled.TRUE:
600582
return (EventCode.TASK_CANCELLED, 0, 0)
601583
else:
602584
return (EventCode.NONE, 0, 0)
603585

604-
async def poll_for_event(self, waitable_set, sync) -> Optional[EventTuple]:
605-
event_code,_,_ = e = await self.yield_(sync)
586+
async def poll_for_event(self, waitable_set, cancellable, unlock) -> Optional[EventTuple]:
587+
waitable_set.num_waiting += 1
588+
event_code,_,_ = e = await self.yield_(cancellable, unlock)
589+
waitable_set.num_waiting -= 1
606590
if event_code == EventCode.TASK_CANCELLED:
607591
return e
608592
elif (e := waitable_set.poll()):
@@ -624,13 +608,10 @@ def cancel(self):
624608
self.state = Task.State.RESOLVED
625609

626610
def exit(self):
627-
assert(scheduler.locked())
628611
trap_if(self.state != Task.State.RESOLVED)
629612
assert(self.num_borrows == 0)
630-
if self.opts.sync:
631-
assert(self.inst.calling_sync_export)
632-
self.inst.calling_sync_export = False
633-
self.maybe_start_pending_task()
613+
if self.opts.sync or self.opts.callback:
614+
self.inst.lock.release()
634615

635616
#### Subtask State
636617

@@ -1932,7 +1913,10 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
19321913

19331914
async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_block):
19341915
task = Task(opts, inst, ft, caller, on_resolve, on_block)
1935-
if not await task.enter():
1916+
task.trap_if_on_the_stack(inst)
1917+
if await task.enter() == Cancelled.TRUE:
1918+
task.cancel()
1919+
task.exit()
19361920
return
19371921

19381922
cx = LiftLowerContext(opts, inst, task)
@@ -1967,15 +1951,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
19671951
task.exit()
19681952
return
19691953
case CallbackCode.YIELD:
1970-
e = await task.yield_(sync = False)
1954+
e = await task.yield_(cancellable = True, unlock = True)
19711955
case CallbackCode.WAIT:
19721956
s = task.inst.table.get(si)
19731957
trap_if(not isinstance(s, WaitableSet))
1974-
e = await task.wait_for_event(s, sync = False)
1958+
e = await task.wait_for_event(s, cancellable = True, unlock = True)
19751959
case CallbackCode.POLL:
19761960
s = task.inst.table.get(si)
19771961
trap_if(not isinstance(s, WaitableSet))
1978-
e = await task.poll_for_event(s, sync = False)
1962+
e = await task.poll_for_event(s, cancellable = True, unlock = True)
19791963
event_code, p1, p2 = e
19801964
[packed] = await call_and_trap_on_throw(opts.callback, task, [event_code, p1, p2])
19811965

@@ -2115,7 +2099,11 @@ async def canon_context_set(t, i, task, v):
21152099

21162100
async def canon_backpressure_set(task, flat_args):
21172101
trap_if(task.opts.sync)
2118-
task.inst.backpressure = bool(flat_args[0])
2102+
assert(len(flat_args) == 1)
2103+
if flat_args[0] == 0:
2104+
task.inst.no_backpressure.set()
2105+
else:
2106+
task.inst.no_backpressure.clear()
21192107
return []
21202108

21212109
### 🔀 `canon task.return`
@@ -2140,9 +2128,9 @@ async def canon_task_cancel(task):
21402128

21412129
### 🔀 `canon yield`
21422130

2143-
async def canon_yield(sync, task):
2131+
async def canon_yield(cancellable, task):
21442132
trap_if(not task.inst.may_leave)
2145-
event_code,_,_ = await task.yield_(sync)
2133+
event_code,_,_ = await task.yield_(cancellable, unlock = False)
21462134
match event_code:
21472135
case EventCode.NONE:
21482136
return [0]
@@ -2157,11 +2145,11 @@ async def canon_waitable_set_new(task):
21572145

21582146
### 🔀 `canon waitable-set.wait`
21592147

2160-
async def canon_waitable_set_wait(sync, mem, task, si, ptr):
2148+
async def canon_waitable_set_wait(cancellable, mem, task, si, ptr):
21612149
trap_if(not task.inst.may_leave)
21622150
s = task.inst.table.get(si)
21632151
trap_if(not isinstance(s, WaitableSet))
2164-
e = await task.wait_for_event(s, sync)
2152+
e = await task.wait_for_event(s, cancellable, unlock = False)
21652153
return unpack_event(mem, task, ptr, e)
21662154

21672155
def unpack_event(mem, task, ptr, e: EventTuple):
@@ -2173,11 +2161,11 @@ def unpack_event(mem, task, ptr, e: EventTuple):
21732161

21742162
### 🔀 `canon waitable-set.poll`
21752163

2176-
async def canon_waitable_set_poll(sync, mem, task, si, ptr):
2164+
async def canon_waitable_set_poll(cancellable, mem, task, si, ptr):
21772165
trap_if(not task.inst.may_leave)
21782166
s = task.inst.table.get(si)
21792167
trap_if(not isinstance(s, WaitableSet))
2180-
e = await task.poll_for_event(s, sync)
2168+
e = await task.poll_for_event(s, cancellable, unlock = False)
21812169
return unpack_event(mem, task, ptr, e)
21822170

21832171
### 🔀 `canon waitable-set.drop`
@@ -2220,7 +2208,7 @@ async def canon_subtask_cancel(sync, task, i):
22202208
while not subtask.resolved():
22212209
if subtask.has_pending_event():
22222210
_ = subtask.get_event()
2223-
await task.wait_on(subtask.wait_for_pending_event(), sync = True)
2211+
await task.block_on(subtask.wait_for_pending_event())
22242212
else:
22252213
if not subtask.resolved():
22262214
return [BLOCKED]
@@ -2296,7 +2284,7 @@ def on_copy_done(result):
22962284
e.copy(task.inst, buffer, on_copy, on_copy_done)
22972285

22982286
if opts.sync and not e.has_pending_event():
2299-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2287+
await task.block_on(e.wait_for_pending_event())
23002288

23012289
if e.has_pending_event():
23022290
code,index,payload = e.get_event()
@@ -2342,7 +2330,7 @@ def on_copy_done(result):
23422330
e.copy(task.inst, buffer, on_copy_done)
23432331

23442332
if opts.sync and not e.has_pending_event():
2345-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2333+
await task.block_on(e.wait_for_pending_event())
23462334

23472335
if e.has_pending_event():
23482336
code,index,payload = e.get_event()
@@ -2375,7 +2363,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
23752363
e.shared.cancel()
23762364
if not e.has_pending_event():
23772365
if sync:
2378-
await task.wait_on(e.wait_for_pending_event(), sync = True)
2366+
await task.block_on(e.wait_for_pending_event())
23792367
else:
23802368
return [BLOCKED]
23812369
code,index,payload = e.get_event()

0 commit comments

Comments
 (0)