@@ -221,22 +221,17 @@ class CanonicalOptions(LiftLowerOptions):
221
221
class ComponentInstance :
222
222
table : Table
223
223
may_leave : bool
224
- backpressure : bool
225
- calling_sync_export : bool
226
- calling_sync_import : bool
227
- pending_tasks : list [tuple [Task , asyncio .Future ]]
228
- starting_pending_task : bool
229
- async_waiting_tasks : asyncio .Condition
224
+ no_backpressure : asyncio .Event
225
+ num_backpressure_waiters : int
226
+ lock : asyncio .Lock
230
227
231
228
def __init__ (self ):
232
229
self .table = Table ()
233
230
self .may_leave = True
234
- self .backpressure = False
235
- self .calling_sync_export = False
236
- self .calling_sync_import = False
237
- self .pending_tasks = []
238
- self .starting_pending_task = False
239
- self .async_waiting_tasks = asyncio .Condition (scheduler )
231
+ self .no_backpressure = asyncio .Event ()
232
+ self .no_backpressure .set ()
233
+ self .num_backpressure_waiters = 0
234
+ self .lock = asyncio .Lock ()
240
235
241
236
#### Table State
242
237
@@ -464,7 +459,7 @@ class Cancelled(IntEnum):
464
459
465
460
OnStart = Callable [[], list [any ]]
466
461
OnResolve = Callable [[Optional [list [any ]]], None ]
467
- OnBlock = Callable [[Awaitable ], Awaitable [Cancelled ]]
462
+ OnBlock = Callable [[asyncio . Future ], Awaitable [Cancelled ]]
468
463
469
464
class Task :
470
465
class State (Enum ):
@@ -494,70 +489,65 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
494
489
self .num_borrows = 0
495
490
self .context = ContextLocalStorage ()
496
491
497
- async def enter (self ):
498
- assert (scheduler .locked ())
499
- self .trap_if_on_the_stack (self .inst )
500
- if not self .may_enter (self ) or self .inst .pending_tasks :
501
- f = asyncio .Future ()
502
- self .inst .pending_tasks .append ((self , f ))
503
- if await self .on_block (f ) == Cancelled .TRUE :
504
- [i ] = [i for i ,(t ,_ ) in enumerate (self .inst .pending_tasks ) if t == self ]
505
- self .inst .pending_tasks .pop (i )
506
- self .on_resolve (None )
507
- return Cancelled .FALSE
508
- assert (self .may_enter (self ) and self .inst .starting_pending_task )
509
- self .inst .starting_pending_task = False
510
- if self .opts .sync :
511
- self .inst .calling_sync_export = True
512
- return True
513
-
514
492
def trap_if_on_the_stack (self , inst ):
515
493
c = self .supertask
516
494
while c is not None :
517
495
trap_if (c .inst is inst )
518
496
c = c .supertask
519
497
520
- def may_enter (self , pending_task ):
521
- return not self .inst .backpressure and \
522
- not self .inst .calling_sync_import and \
523
- not (self .inst .calling_sync_export and pending_task .opts .sync )
524
-
525
- def maybe_start_pending_task (self ):
526
- if self .inst .starting_pending_task :
527
- return
528
- for i ,(pending_task ,pending_future ) in enumerate (self .inst .pending_tasks ):
529
- if self .may_enter (pending_task ):
530
- self .inst .pending_tasks .pop (i )
531
- self .inst .starting_pending_task = True
532
- pending_future .set_result (None )
533
- return
498
+ async def enter (self ):
499
+ if self .opts .sync or self .opts .callback :
500
+ if self .inst .lock .locked ():
501
+ acquired = asyncio .create_task (self .inst .lock .acquire ())
502
+ cancelled = await self .block_on (acquired , cancellable = True )
503
+ if cancelled :
504
+ if acquired .done ():
505
+ self .inst .lock .release ()
506
+ else :
507
+ acquired .cancel ()
508
+ return Cancelled .TRUE
509
+ else :
510
+ await self .inst .lock .acquire ()
511
+ if not self .inst .no_backpressure .is_set () or self .inst .num_backpressure_waiters > 0 :
512
+ while True :
513
+ self .inst .num_backpressure_waiters += 1
514
+ maybe_go = self .inst .no_backpressure .wait ()
515
+ cancelled = await self .block_on (maybe_go , cancellable = True , unlock = True )
516
+ self .inst .num_backpressure_waiters -= 1
517
+ if cancelled :
518
+ return Cancelled .TRUE
519
+ if self .inst .no_backpressure .is_set ():
520
+ break
521
+ return Cancelled .FALSE
534
522
535
- async def wait_on (self , awaitable , sync , cancellable = False ) -> bool :
536
- if sync :
537
- assert (not self .inst .calling_sync_import )
538
- self .inst .calling_sync_import = True
539
- else :
540
- self .maybe_start_pending_task ()
523
+ async def block_on (self , awaitable , cancellable = False , unlock = False ) -> Cancelled :
524
+ f = asyncio .ensure_future (awaitable )
525
+ if f .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
526
+ return Cancelled .FALSE
541
527
542
- awaitable = asyncio .ensure_future (awaitable )
543
- if awaitable .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
544
- cancelled = Cancelled .FALSE
545
- else :
546
- cancelled = await self .on_block (awaitable )
547
- if cancelled and not cancellable :
548
- assert (self .state == Task .State .INITIAL )
549
- self .state = Task .State .PENDING_CANCEL
550
- cancelled = await self .on_block (awaitable )
551
- assert (not cancelled )
528
+ if unlock and (self .opts .sync or self .opts .callback ):
529
+ self .inst .lock .release ()
552
530
553
- if sync :
554
- self .inst .calling_sync_import = False
555
- self .inst .async_waiting_tasks .notify_all ()
556
- else :
557
- while self .inst .calling_sync_import :
558
- await self .inst .async_waiting_tasks .wait ()
531
+ cancelled = await self .on_block (f )
532
+ if cancelled and not cancellable :
533
+ assert (await self .on_block (f ) == Cancelled .FALSE )
559
534
560
- return cancelled
535
+ if unlock and (self .opts .sync or self .opts .callback ):
536
+ acquired = asyncio .create_task (self .inst .lock .acquire ())
537
+ cancelled |= await self .on_block (acquired )
538
+ if cancelled :
539
+ assert (self .on_block (acquired ) == Cancelled .FALSE )
540
+
541
+ if cancelled :
542
+ assert (self .state == Task .State .INITIAL )
543
+ if not cancellable :
544
+ self .state = Task .State .PENDING_CANCEL
545
+ return Cancelled .FALSE
546
+ else :
547
+ self .state = Task .State .CANCEL_DELIVERED
548
+ return Cancelled .TRUE
549
+ else :
550
+ return Cancelled .FALSE
561
551
562
552
async def call_sync (self , callee , on_start , on_return ):
563
553
async def sync_on_block (awaitable ):
@@ -567,42 +557,36 @@ async def sync_on_block(awaitable):
567
557
assert (await self .on_block (awaitable ) == Cancelled .FALSE )
568
558
return Cancelled .FALSE
569
559
570
- assert (not self .inst .calling_sync_import )
571
- self .inst .calling_sync_import = True
572
560
await callee (self , on_start , on_return , sync_on_block )
573
- self .inst .calling_sync_import = False
574
- self .inst .async_waiting_tasks .notify_all ()
575
561
576
- async def wait_for_event (self , waitable_set , sync ) -> EventTuple :
577
- if self .state == Task .State .PENDING_CANCEL :
562
+ async def wait_for_event (self , waitable_set , cancellable , unlock ) -> EventTuple :
563
+ if self .state == Task .State .PENDING_CANCEL and cancellable :
578
564
self .state = Task .State .CANCEL_DELIVERED
579
565
return (EventCode .TASK_CANCELLED , 0 , 0 )
580
566
else :
581
567
waitable_set .num_waiting += 1
582
568
e = None
583
569
while not e :
584
570
maybe_event = waitable_set .maybe_has_pending_event .wait ()
585
- if await self .wait_on (maybe_event , sync , cancellable = True ):
586
- assert (self .state == Task .State .INITIAL )
587
- self .state = Task .State .CANCEL_DELIVERED
571
+ if await self .block_on (maybe_event , cancellable , unlock ) == Cancelled .TRUE :
588
572
return (EventCode .TASK_CANCELLED , 0 , 0 )
589
573
e = waitable_set .poll ()
590
574
waitable_set .num_waiting -= 1
591
575
return e
592
576
593
- async def yield_ (self , sync ) -> EventTuple :
594
- if self .state == Task .State .PENDING_CANCEL :
577
+ async def yield_ (self , cancellable , unlock ) -> EventTuple :
578
+ if self .state == Task .State .PENDING_CANCEL and cancellable :
595
579
self .state = Task .State .CANCEL_DELIVERED
596
580
return (EventCode .TASK_CANCELLED , 0 , 0 )
597
- elif await self .wait_on (asyncio .sleep (0 ), sync , cancellable = True ):
598
- assert (self .state == Task .State .INITIAL )
599
- self .state = Task .State .CANCEL_DELIVERED
581
+ elif await self .block_on (asyncio .sleep (0 ), cancellable , unlock ) == Cancelled .TRUE :
600
582
return (EventCode .TASK_CANCELLED , 0 , 0 )
601
583
else :
602
584
return (EventCode .NONE , 0 , 0 )
603
585
604
- async def poll_for_event (self , waitable_set , sync ) -> Optional [EventTuple ]:
605
- event_code ,_ ,_ = e = await self .yield_ (sync )
586
+ async def poll_for_event (self , waitable_set , cancellable , unlock ) -> Optional [EventTuple ]:
587
+ waitable_set .num_waiting += 1
588
+ event_code ,_ ,_ = e = await self .yield_ (cancellable , unlock )
589
+ waitable_set .num_waiting -= 1
606
590
if event_code == EventCode .TASK_CANCELLED :
607
591
return e
608
592
elif (e := waitable_set .poll ()):
@@ -624,13 +608,10 @@ def cancel(self):
624
608
self .state = Task .State .RESOLVED
625
609
626
610
def exit (self ):
627
- assert (scheduler .locked ())
628
611
trap_if (self .state != Task .State .RESOLVED )
629
612
assert (self .num_borrows == 0 )
630
- if self .opts .sync :
631
- assert (self .inst .calling_sync_export )
632
- self .inst .calling_sync_export = False
633
- self .maybe_start_pending_task ()
613
+ if self .opts .sync or self .opts .callback :
614
+ self .inst .lock .release ()
634
615
635
616
#### Subtask State
636
617
@@ -1932,7 +1913,10 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
1932
1913
1933
1914
async def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve , on_block ):
1934
1915
task = Task (opts , inst , ft , caller , on_resolve , on_block )
1935
- if not await task .enter ():
1916
+ task .trap_if_on_the_stack (inst )
1917
+ if await task .enter () == Cancelled .TRUE :
1918
+ task .cancel ()
1919
+ task .exit ()
1936
1920
return
1937
1921
1938
1922
cx = LiftLowerContext (opts , inst , task )
@@ -1967,15 +1951,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
1967
1951
task .exit ()
1968
1952
return
1969
1953
case CallbackCode .YIELD :
1970
- e = await task .yield_ (sync = False )
1954
+ e = await task .yield_ (cancellable = True , unlock = True )
1971
1955
case CallbackCode .WAIT :
1972
1956
s = task .inst .table .get (si )
1973
1957
trap_if (not isinstance (s , WaitableSet ))
1974
- e = await task .wait_for_event (s , sync = False )
1958
+ e = await task .wait_for_event (s , cancellable = True , unlock = True )
1975
1959
case CallbackCode .POLL :
1976
1960
s = task .inst .table .get (si )
1977
1961
trap_if (not isinstance (s , WaitableSet ))
1978
- e = await task .poll_for_event (s , sync = False )
1962
+ e = await task .poll_for_event (s , cancellable = True , unlock = True )
1979
1963
event_code , p1 , p2 = e
1980
1964
[packed ] = await call_and_trap_on_throw (opts .callback , task , [event_code , p1 , p2 ])
1981
1965
@@ -2115,7 +2099,11 @@ async def canon_context_set(t, i, task, v):
2115
2099
2116
2100
async def canon_backpressure_set (task , flat_args ):
2117
2101
trap_if (task .opts .sync )
2118
- task .inst .backpressure = bool (flat_args [0 ])
2102
+ assert (len (flat_args ) == 1 )
2103
+ if flat_args [0 ] == 0 :
2104
+ task .inst .no_backpressure .set ()
2105
+ else :
2106
+ task .inst .no_backpressure .clear ()
2119
2107
return []
2120
2108
2121
2109
### 🔀 `canon task.return`
@@ -2140,9 +2128,9 @@ async def canon_task_cancel(task):
2140
2128
2141
2129
### 🔀 `canon yield`
2142
2130
2143
- async def canon_yield (sync , task ):
2131
+ async def canon_yield (cancellable , task ):
2144
2132
trap_if (not task .inst .may_leave )
2145
- event_code ,_ ,_ = await task .yield_ (sync )
2133
+ event_code ,_ ,_ = await task .yield_ (cancellable , unlock = False )
2146
2134
match event_code :
2147
2135
case EventCode .NONE :
2148
2136
return [0 ]
@@ -2157,11 +2145,11 @@ async def canon_waitable_set_new(task):
2157
2145
2158
2146
### 🔀 `canon waitable-set.wait`
2159
2147
2160
- async def canon_waitable_set_wait (sync , mem , task , si , ptr ):
2148
+ async def canon_waitable_set_wait (cancellable , mem , task , si , ptr ):
2161
2149
trap_if (not task .inst .may_leave )
2162
2150
s = task .inst .table .get (si )
2163
2151
trap_if (not isinstance (s , WaitableSet ))
2164
- e = await task .wait_for_event (s , sync )
2152
+ e = await task .wait_for_event (s , cancellable , unlock = False )
2165
2153
return unpack_event (mem , task , ptr , e )
2166
2154
2167
2155
def unpack_event (mem , task , ptr , e : EventTuple ):
@@ -2173,11 +2161,11 @@ def unpack_event(mem, task, ptr, e: EventTuple):
2173
2161
2174
2162
### 🔀 `canon waitable-set.poll`
2175
2163
2176
- async def canon_waitable_set_poll (sync , mem , task , si , ptr ):
2164
+ async def canon_waitable_set_poll (cancellable , mem , task , si , ptr ):
2177
2165
trap_if (not task .inst .may_leave )
2178
2166
s = task .inst .table .get (si )
2179
2167
trap_if (not isinstance (s , WaitableSet ))
2180
- e = await task .poll_for_event (s , sync )
2168
+ e = await task .poll_for_event (s , cancellable , unlock = False )
2181
2169
return unpack_event (mem , task , ptr , e )
2182
2170
2183
2171
### 🔀 `canon waitable-set.drop`
@@ -2220,7 +2208,7 @@ async def canon_subtask_cancel(sync, task, i):
2220
2208
while not subtask .resolved ():
2221
2209
if subtask .has_pending_event ():
2222
2210
_ = subtask .get_event ()
2223
- await task .wait_on (subtask .wait_for_pending_event (), sync = True )
2211
+ await task .block_on (subtask .wait_for_pending_event ())
2224
2212
else :
2225
2213
if not subtask .resolved ():
2226
2214
return [BLOCKED ]
@@ -2296,7 +2284,7 @@ def on_copy_done(result):
2296
2284
e .copy (task .inst , buffer , on_copy , on_copy_done )
2297
2285
2298
2286
if opts .sync and not e .has_pending_event ():
2299
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2287
+ await task .block_on (e .wait_for_pending_event ())
2300
2288
2301
2289
if e .has_pending_event ():
2302
2290
code ,index ,payload = e .get_event ()
@@ -2342,7 +2330,7 @@ def on_copy_done(result):
2342
2330
e .copy (task .inst , buffer , on_copy_done )
2343
2331
2344
2332
if opts .sync and not e .has_pending_event ():
2345
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2333
+ await task .block_on (e .wait_for_pending_event ())
2346
2334
2347
2335
if e .has_pending_event ():
2348
2336
code ,index ,payload = e .get_event ()
@@ -2375,7 +2363,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
2375
2363
e .shared .cancel ()
2376
2364
if not e .has_pending_event ():
2377
2365
if sync :
2378
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2366
+ await task .block_on (e .wait_for_pending_event ())
2379
2367
else :
2380
2368
return [BLOCKED ]
2381
2369
code ,index ,payload = e .get_event ()
0 commit comments