@@ -221,22 +221,17 @@ class CanonicalOptions(LiftLowerOptions):
221
221
class ComponentInstance :
222
222
table : Table
223
223
may_leave : bool
224
- backpressure : bool
225
- calling_sync_export : bool
226
- calling_sync_import : bool
227
- pending_tasks : list [tuple [Task , asyncio .Future ]]
228
- starting_pending_task : bool
229
- async_waiting_tasks : asyncio .Condition
224
+ no_backpressure : asyncio .Event
225
+ num_backpressure_waiters : int
226
+ lock : asyncio .Lock
230
227
231
228
def __init__ (self ):
232
229
self .table = Table ()
233
230
self .may_leave = True
234
- self .backpressure = False
235
- self .calling_sync_export = False
236
- self .calling_sync_import = False
237
- self .pending_tasks = []
238
- self .starting_pending_task = False
239
- self .async_waiting_tasks = asyncio .Condition (scheduler )
231
+ self .no_backpressure = asyncio .Event ()
232
+ self .no_backpressure .set ()
233
+ self .num_backpressure_waiters = 0
234
+ self .lock = asyncio .Lock ()
240
235
241
236
#### Table State
242
237
@@ -464,7 +459,7 @@ class Cancelled(IntEnum):
464
459
465
460
OnStart = Callable [[], list [any ]]
466
461
OnResolve = Callable [[Optional [list [any ]]], None ]
467
- OnBlock = Callable [[Awaitable ], Awaitable [Cancelled ]]
462
+ OnBlock = Callable [[asyncio . Future ], Awaitable [Cancelled ]]
468
463
469
464
class Task :
470
465
class State (Enum ):
@@ -497,67 +492,64 @@ def __init__(self, opts, inst, ft, supertask, on_resolve, on_block):
497
492
async def enter (self ):
498
493
assert (scheduler .locked ())
499
494
self .trap_if_on_the_stack (self .inst )
500
- if not self .may_enter (self ) or self .inst .pending_tasks :
501
- f = asyncio .Future ()
502
- self .inst .pending_tasks .append ((self , f ))
503
- if await self .on_block (f ) == Cancelled .TRUE :
504
- [i ] = [i for i ,(t ,_ ) in enumerate (self .inst .pending_tasks ) if t == self ]
505
- self .inst .pending_tasks .pop (i )
506
- self .on_resolve (None )
507
- return Cancelled .FALSE
508
- assert (self .may_enter (self ) and self .inst .starting_pending_task )
509
- self .inst .starting_pending_task = False
510
- if self .opts .sync :
511
- self .inst .calling_sync_export = True
512
- return True
495
+ if self .opts .sync or self .opts .callback :
496
+ if self .inst .lock .locked ():
497
+ acquired = asyncio .create_task (self .inst .lock .acquire ())
498
+ cancelled = await self .wait_on (acquired , cancellable = True , for_callback = False )
499
+ if cancelled :
500
+ if acquired .done ():
501
+ self .inst .lock .release ()
502
+ else :
503
+ acquired .cancel ()
504
+ return Cancelled .TRUE
505
+ else :
506
+ await self .inst .lock .acquire ()
507
+ if not self .inst .no_backpressure .is_set () or self .inst .num_backpressure_waiters > 0 :
508
+ while True :
509
+ self .inst .num_backpressure_waiters += 1
510
+ maybe_go = self .inst .no_backpressure .wait ()
511
+ cancelled = await self .wait_on (maybe_go , cancellable = True , for_callback = False )
512
+ self .inst .num_backpressure_waiters -= 1
513
+ if cancelled :
514
+ return Cancelled .TRUE
515
+ if self .inst .no_backpressure .is_set ():
516
+ break
517
+ return Cancelled .FALSE
513
518
514
519
def trap_if_on_the_stack (self , inst ):
515
520
c = self .supertask
516
521
while c is not None :
517
522
trap_if (c .inst is inst )
518
523
c = c .supertask
519
524
520
- def may_enter (self , pending_task ):
521
- return not self .inst .backpressure and \
522
- not self .inst .calling_sync_import and \
523
- not (self .inst .calling_sync_export and pending_task .opts .sync )
524
-
525
- def maybe_start_pending_task (self ):
526
- if self .inst .starting_pending_task :
527
- return
528
- for i ,(pending_task ,pending_future ) in enumerate (self .inst .pending_tasks ):
529
- if self .may_enter (pending_task ):
530
- self .inst .pending_tasks .pop (i )
531
- self .inst .starting_pending_task = True
532
- pending_future .set_result (None )
533
- return
525
+ async def wait_on (self , awaitable , cancellable = False , for_callback = False ) -> Cancelled :
526
+ f = asyncio .ensure_future (awaitable )
527
+ if f .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
528
+ return Cancelled .FALSE
534
529
535
- async def wait_on (self , awaitable , sync , cancellable = False ) -> bool :
536
- if sync :
537
- assert (not self .inst .calling_sync_import )
538
- self .inst .calling_sync_import = True
539
- else :
540
- self .maybe_start_pending_task ()
530
+ if for_callback :
531
+ self .inst .lock .release ()
541
532
542
- awaitable = asyncio .ensure_future (awaitable )
543
- if awaitable .done () and not DETERMINISTIC_PROFILE and random .randint (0 ,1 ):
544
- cancelled = Cancelled .FALSE
545
- else :
546
- cancelled = await self .on_block (awaitable )
547
- if cancelled and not cancellable :
548
- assert (self .state == Task .State .INITIAL )
549
- self .state = Task .State .PENDING_CANCEL
550
- cancelled = await self .on_block (awaitable )
551
- assert (not cancelled )
533
+ cancelled = await self .on_block (f )
534
+ if cancelled and not cancellable :
535
+ assert (await self .on_block (f ) == Cancelled .FALSE )
552
536
553
- if sync :
554
- self .inst .calling_sync_import = False
555
- self .inst .async_waiting_tasks .notify_all ()
556
- else :
557
- while self .inst .calling_sync_import :
558
- await self .inst .async_waiting_tasks .wait ()
537
+ if for_callback :
538
+ acquired = asyncio .create_task (self .inst .lock .acquire ())
539
+ cancelled |= await self .on_block (acquired )
540
+ if cancelled :
541
+ assert (self .on_block (acquired ) == Cancelled .FALSE )
559
542
560
- return cancelled
543
+ if cancelled :
544
+ assert (self .state == Task .State .INITIAL )
545
+ if not cancellable :
546
+ self .state = Task .State .PENDING_CANCEL
547
+ return Cancelled .FALSE
548
+ else :
549
+ self .state = Task .State .CANCEL_DELIVERED
550
+ return Cancelled .TRUE
551
+ else :
552
+ return Cancelled .FALSE
561
553
562
554
async def call_sync (self , callee , on_start , on_return ):
563
555
async def sync_on_block (awaitable ):
@@ -567,42 +559,36 @@ async def sync_on_block(awaitable):
567
559
assert (await self .on_block (awaitable ) == Cancelled .FALSE )
568
560
return Cancelled .FALSE
569
561
570
- assert (not self .inst .calling_sync_import )
571
- self .inst .calling_sync_import = True
572
562
await callee (self , on_start , on_return , sync_on_block )
573
- self .inst .calling_sync_import = False
574
- self .inst .async_waiting_tasks .notify_all ()
575
563
576
- async def wait_for_event (self , waitable_set , sync ) -> EventTuple :
577
- if self .state == Task .State .PENDING_CANCEL :
564
+ async def wait_for_event (self , waitable_set , cancellable , for_callback ) -> EventTuple :
565
+ if self .state == Task .State .PENDING_CANCEL and cancellable :
578
566
self .state = Task .State .CANCEL_DELIVERED
579
567
return (EventCode .TASK_CANCELLED , 0 , 0 )
580
568
else :
581
569
waitable_set .num_waiting += 1
582
570
e = None
583
571
while not e :
584
572
maybe_event = waitable_set .maybe_has_pending_event .wait ()
585
- if await self .wait_on (maybe_event , sync , cancellable = True ):
586
- assert (self .state == Task .State .INITIAL )
587
- self .state = Task .State .CANCEL_DELIVERED
573
+ if await self .wait_on (maybe_event , cancellable , for_callback ) == Cancelled .TRUE :
588
574
return (EventCode .TASK_CANCELLED , 0 , 0 )
589
575
e = waitable_set .poll ()
590
576
waitable_set .num_waiting -= 1
591
577
return e
592
578
593
- async def yield_ (self , sync ) -> EventTuple :
594
- if self .state == Task .State .PENDING_CANCEL :
579
+ async def yield_ (self , cancellable , for_callback ) -> EventTuple :
580
+ if self .state == Task .State .PENDING_CANCEL and cancellable :
595
581
self .state = Task .State .CANCEL_DELIVERED
596
582
return (EventCode .TASK_CANCELLED , 0 , 0 )
597
- elif await self .wait_on (asyncio .sleep (0 ), sync , cancellable = True ):
598
- assert (self .state == Task .State .INITIAL )
599
- self .state = Task .State .CANCEL_DELIVERED
583
+ elif await self .wait_on (asyncio .sleep (0 ), cancellable , for_callback ) == Cancelled .TRUE :
600
584
return (EventCode .TASK_CANCELLED , 0 , 0 )
601
585
else :
602
586
return (EventCode .NONE , 0 , 0 )
603
587
604
- async def poll_for_event (self , waitable_set , sync ) -> Optional [EventTuple ]:
605
- event_code ,_ ,_ = e = await self .yield_ (sync )
588
+ async def poll_for_event (self , waitable_set , cancellable , for_callback ) -> Optional [EventTuple ]:
589
+ waitable_set .num_waiting += 1
590
+ event_code ,_ ,_ = e = await self .yield_ (cancellable , for_callback )
591
+ waitable_set .num_waiting -= 1
606
592
if event_code == EventCode .TASK_CANCELLED :
607
593
return e
608
594
elif (e := waitable_set .poll ()):
@@ -624,13 +610,10 @@ def cancel(self):
624
610
self .state = Task .State .RESOLVED
625
611
626
612
def exit (self ):
627
- assert (scheduler .locked ())
628
613
trap_if (self .state != Task .State .RESOLVED )
629
614
assert (self .num_borrows == 0 )
630
- if self .opts .sync :
631
- assert (self .inst .calling_sync_export )
632
- self .inst .calling_sync_export = False
633
- self .maybe_start_pending_task ()
615
+ if self .opts .sync or self .opts .callback :
616
+ self .inst .lock .release ()
634
617
635
618
#### Subtask State
636
619
@@ -1932,7 +1915,9 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
1932
1915
1933
1916
async def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve , on_block ):
1934
1917
task = Task (opts , inst , ft , caller , on_resolve , on_block )
1935
- if not await task .enter ():
1918
+ if await task .enter () == Cancelled .TRUE :
1919
+ task .cancel ()
1920
+ task .exit ()
1936
1921
return
1937
1922
1938
1923
cx = LiftLowerContext (opts , inst , task )
@@ -1967,15 +1952,15 @@ async def canon_lift(opts, inst, ft, callee, caller, on_start, on_resolve, on_bl
1967
1952
task .exit ()
1968
1953
return
1969
1954
case CallbackCode .YIELD :
1970
- e = await task .yield_ (sync = False )
1955
+ e = await task .yield_ (cancellable = True , for_callback = True )
1971
1956
case CallbackCode .WAIT :
1972
1957
s = task .inst .table .get (si )
1973
1958
trap_if (not isinstance (s , WaitableSet ))
1974
- e = await task .wait_for_event (s , sync = False )
1959
+ e = await task .wait_for_event (s , cancellable = True , for_callback = True )
1975
1960
case CallbackCode .POLL :
1976
1961
s = task .inst .table .get (si )
1977
1962
trap_if (not isinstance (s , WaitableSet ))
1978
- e = await task .poll_for_event (s , sync = False )
1963
+ e = await task .poll_for_event (s , cancellable = True , for_callback = True )
1979
1964
event_code , p1 , p2 = e
1980
1965
[packed ] = await call_and_trap_on_throw (opts .callback , task , [event_code , p1 , p2 ])
1981
1966
@@ -2114,8 +2099,11 @@ async def canon_context_set(t, i, task, v):
2114
2099
### 🔀 `canon backpressure.set`
2115
2100
2116
2101
async def canon_backpressure_set (task , flat_args ):
2117
- trap_if (task .opts .sync )
2118
- task .inst .backpressure = bool (flat_args [0 ])
2102
+ assert (len (flat_args ) == 1 )
2103
+ if flat_args [0 ] == 0 :
2104
+ task .inst .no_backpressure .set ()
2105
+ else :
2106
+ task .inst .no_backpressure .clear ()
2119
2107
return []
2120
2108
2121
2109
### 🔀 `canon task.return`
@@ -2140,9 +2128,9 @@ async def canon_task_cancel(task):
2140
2128
2141
2129
### 🔀 `canon yield`
2142
2130
2143
- async def canon_yield (sync , task ):
2131
+ async def canon_yield (cancellable , task ):
2144
2132
trap_if (not task .inst .may_leave )
2145
- event_code ,_ ,_ = await task .yield_ (sync )
2133
+ event_code ,_ ,_ = await task .yield_ (cancellable , for_callback = False )
2146
2134
match event_code :
2147
2135
case EventCode .NONE :
2148
2136
return [0 ]
@@ -2157,11 +2145,11 @@ async def canon_waitable_set_new(task):
2157
2145
2158
2146
### 🔀 `canon waitable-set.wait`
2159
2147
2160
- async def canon_waitable_set_wait (sync , mem , task , si , ptr ):
2148
+ async def canon_waitable_set_wait (cancellable , mem , task , si , ptr ):
2161
2149
trap_if (not task .inst .may_leave )
2162
2150
s = task .inst .table .get (si )
2163
2151
trap_if (not isinstance (s , WaitableSet ))
2164
- e = await task .wait_for_event (s , sync )
2152
+ e = await task .wait_for_event (s , cancellable , for_callback = False )
2165
2153
return unpack_event (mem , task , ptr , e )
2166
2154
2167
2155
def unpack_event (mem , task , ptr , e : EventTuple ):
@@ -2173,11 +2161,11 @@ def unpack_event(mem, task, ptr, e: EventTuple):
2173
2161
2174
2162
### 🔀 `canon waitable-set.poll`
2175
2163
2176
- async def canon_waitable_set_poll (sync , mem , task , si , ptr ):
2164
+ async def canon_waitable_set_poll (cancellable , mem , task , si , ptr ):
2177
2165
trap_if (not task .inst .may_leave )
2178
2166
s = task .inst .table .get (si )
2179
2167
trap_if (not isinstance (s , WaitableSet ))
2180
- e = await task .poll_for_event (s , sync )
2168
+ e = await task .poll_for_event (s , cancellable , for_callback = False )
2181
2169
return unpack_event (mem , task , ptr , e )
2182
2170
2183
2171
### 🔀 `canon waitable-set.drop`
@@ -2220,7 +2208,7 @@ async def canon_subtask_cancel(sync, task, i):
2220
2208
while not subtask .resolved ():
2221
2209
if subtask .has_pending_event ():
2222
2210
_ = subtask .get_event ()
2223
- await task .wait_on (subtask .wait_for_pending_event (), sync = True )
2211
+ await task .wait_on (subtask .wait_for_pending_event ())
2224
2212
else :
2225
2213
if not subtask .resolved ():
2226
2214
return [BLOCKED ]
@@ -2296,7 +2284,7 @@ def on_copy_done(result):
2296
2284
e .copy (task .inst , buffer , on_copy , on_copy_done )
2297
2285
2298
2286
if opts .sync and not e .has_pending_event ():
2299
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2287
+ await task .wait_on (e .wait_for_pending_event ())
2300
2288
2301
2289
if e .has_pending_event ():
2302
2290
code ,index ,payload = e .get_event ()
@@ -2342,7 +2330,7 @@ def on_copy_done(result):
2342
2330
e .copy (task .inst , buffer , on_copy_done )
2343
2331
2344
2332
if opts .sync and not e .has_pending_event ():
2345
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2333
+ await task .wait_on (e .wait_for_pending_event ())
2346
2334
2347
2335
if e .has_pending_event ():
2348
2336
code ,index ,payload = e .get_event ()
@@ -2375,7 +2363,7 @@ async def cancel_copy(EndT, event_code, stream_or_future_t, sync, task, i):
2375
2363
e .shared .cancel ()
2376
2364
if not e .has_pending_event ():
2377
2365
if sync :
2378
- await task .wait_on (e .wait_for_pending_event (), sync = True )
2366
+ await task .wait_on (e .wait_for_pending_event ())
2379
2367
else :
2380
2368
return [BLOCKED ]
2381
2369
code ,index ,payload = e .get_event ()
0 commit comments