7
7
from __future__ import annotations
8
8
from dataclasses import dataclass
9
9
from functools import partial
10
- from typing import Any , Optional , Callable , Awaitable , TypeVar , Generic , Literal
10
+ from typing import Any , Optional , Callable , TypeVar , Generic , Literal
11
11
from enum import Enum , IntEnum
12
12
import math
13
13
import struct
@@ -214,18 +214,18 @@ class CanonicalOptions(LiftLowerOptions):
214
214
215
215
### Runtime State
216
216
217
- scheduler = asyncio .Lock ()
218
-
219
217
#### Component Instance State
220
218
221
219
class ComponentInstance :
220
+ store : Store
222
221
table : Table
223
222
may_leave : bool
224
223
no_backpressure : asyncio .Event
225
224
num_backpressure_waiters : int
226
225
lock : asyncio .Lock
227
226
228
- def __init__ (self ):
227
+ def __init__ (self , store ):
228
+ self .store = store
229
229
self .table = Table ()
230
230
self .may_leave = True
231
231
self .no_backpressure = asyncio .Event ()
@@ -457,10 +457,6 @@ class Cancelled(IntEnum):
457
457
FALSE = 0
458
458
TRUE = 1
459
459
460
- OnStart = Callable [[], list [any ]]
461
- OnResolve = Callable [[Optional [list [any ]]], None ]
462
- OnBlock = Callable [[Awaitable ], Awaitable [Cancelled ]]
463
-
464
460
class Task :
465
461
class State (Enum ):
466
462
INITIAL = 1
@@ -473,24 +469,23 @@ class State(Enum):
473
469
inst : ComponentInstance
474
470
ft : FuncType
475
471
supertask : Optional [Task ]
476
- on_resolve : OnResolve
477
- on_block : OnBlock
472
+ thread : Thread
473
+ on_resolve : Callable [[ Optional [ list [ any ]]], None ]
478
474
num_borrows : int
479
475
context : ContextLocalStorage
480
476
481
- def __init__ (self , opts , inst , ft , supertask , on_resolve , on_block ):
477
+ def __init__ (self , opts , inst , ft , supertask , thread , on_resolve ):
482
478
self .state = Task .State .INITIAL
483
479
self .opts = opts
484
480
self .inst = inst
485
481
self .ft = ft
486
482
self .supertask = supertask
483
+ self .thread = thread
487
484
self .on_resolve = on_resolve
488
- self .on_block = on_block
489
485
self .num_borrows = 0
490
486
self .context = ContextLocalStorage ()
491
487
492
488
async def enter (self ):
493
- assert (scheduler .locked ())
494
489
self .trap_if_on_the_stack (self .inst )
495
490
if self .opts .sync or self .opts .callback :
496
491
if self .inst .lock .locked ():
@@ -530,15 +525,15 @@ async def wait_on(self, awaitable, cancellable = False, for_callback = False) ->
530
525
if for_callback :
531
526
self .inst .lock .release ()
532
527
533
- cancelled = await self .on_block (f )
528
+ cancelled = await self .thread . suspend (f )
534
529
if cancelled and not cancellable :
535
- assert (await self .on_block (f ) == Cancelled .FALSE )
530
+ assert (await self .thread . suspend (f ) == Cancelled .FALSE )
536
531
537
532
if for_callback :
538
533
acquired = asyncio .create_task (self .inst .lock .acquire ())
539
- cancelled |= await self .on_block (acquired )
534
+ cancelled |= await self .thread . suspend (acquired )
540
535
if cancelled :
541
- assert (self .on_block (acquired ) == Cancelled .FALSE )
536
+ assert (self .thread . suspend (acquired ) == Cancelled .FALSE )
542
537
543
538
if cancelled :
544
539
assert (self .state == Task .State .INITIAL )
@@ -551,16 +546,6 @@ async def wait_on(self, awaitable, cancellable = False, for_callback = False) ->
551
546
else :
552
547
return Cancelled .FALSE
553
548
554
- async def call_sync (self , callee , on_start , on_return ):
555
- async def sync_on_block (awaitable ):
556
- if await self .on_block (awaitable ) == Cancelled .TRUE :
557
- assert (self .state == Task .State .INITIAL )
558
- self .state = Task .State .PENDING_CANCEL
559
- assert (await self .on_block (awaitable ) == Cancelled .FALSE )
560
- return Cancelled .FALSE
561
-
562
- await callee (self , on_start , on_return , sync_on_block )
563
-
564
549
async def wait_for_event (self , waitable_set , cancellable , for_callback ) -> EventTuple :
565
550
if self .state == Task .State .PENDING_CANCEL and cancellable :
566
551
self .state = Task .State .CANCEL_DELIVERED
@@ -626,18 +611,16 @@ class State(IntEnum):
626
611
CANCELLED_BEFORE_RETURNED = 4
627
612
628
613
state : State
629
- task : Task
614
+ thread : Optional [ Thread ]
630
615
lenders : Optional [list [ResourceHandle ]]
631
- request_cancel_begin : asyncio .Future
632
- request_cancel_end : asyncio .Future
616
+ cancellation_requested : bool
633
617
634
- def __init__ (self , task ):
618
+ def __init__ (self ):
635
619
Waitable .__init__ (self )
636
620
self .state = Subtask .State .STARTING
637
- self .task = task
621
+ self .thread = None
638
622
self .lenders = []
639
- self .request_cancel_begin = asyncio .Future ()
640
- self .request_cancel_end = asyncio .Future ()
623
+ self .cancellation_requested = False
641
624
642
625
def resolved (self ):
643
626
match self .state :
@@ -649,44 +632,6 @@ def resolved(self):
649
632
Subtask .State .CANCELLED_BEFORE_RETURNED ):
650
633
return True
651
634
652
- async def request_cancel (self ):
653
- assert (not self .cancellation_requested () and not self .resolved ())
654
- self .request_cancel_begin .set_result (None )
655
- await self .request_cancel_end
656
-
657
- def cancellation_requested (self ):
658
- return self .request_cancel_begin .done ()
659
-
660
- async def call_async (self , callee , on_start , on_resolve ):
661
- async def do_call ():
662
- await callee (self .task , on_start , on_resolve , async_on_block )
663
- relinquish_control ()
664
-
665
- async def async_on_block (awaitable ):
666
- relinquish_control ()
667
- if not self .request_cancel_end .done ():
668
- await asyncio .wait ([awaitable , self .request_cancel_begin ],
669
- return_when = asyncio .FIRST_COMPLETED )
670
- if self .request_cancel_begin .done ():
671
- return Cancelled .TRUE
672
- else :
673
- await awaitable
674
- assert (awaitable .done ())
675
- await scheduler .acquire ()
676
- return Cancelled .FALSE
677
-
678
- def relinquish_control ():
679
- if not ret .done ():
680
- ret .set_result (None )
681
- elif self .request_cancel_begin .done () and not self .request_cancel_end .done ():
682
- self .request_cancel_end .set_result (None )
683
- else :
684
- scheduler .release ()
685
-
686
- ret = asyncio .Future ()
687
- asyncio .create_task (do_call ())
688
- await ret
689
-
690
635
def add_lender (self , lending_handle ):
691
636
assert (not self .resolve_delivered () and not self .resolved ())
692
637
lending_handle .num_lends += 1
@@ -927,6 +872,84 @@ def drop(self):
927
872
trap_if (not self .done )
928
873
FutureEnd .drop (self )
929
874
875
+ #### Thread State
876
+
877
+ class Thread :
878
+ store : Store
879
+ future : Optional [asyncio .Future ]
880
+ on_resume : Optional [asyncio .Future ]
881
+ on_suspend_or_exit : Optional [asyncio .Future ]
882
+ returned : bool
883
+
884
+ def __init__ (self , store , lifted_func , caller , on_start , on_resolve ):
885
+ self .store = store
886
+ self .future = None
887
+ self .on_resume = asyncio .Future ()
888
+ self .on_suspend_or_exit = None
889
+ self .returned = False
890
+ async def async_impl ():
891
+ assert (await self .on_resume == Cancelled .FALSE )
892
+ self .on_resume = None
893
+ await lifted_func (caller , self , on_start , on_resolve )
894
+ self .on_suspend_or_exit .set_result (None )
895
+ self .returned = True
896
+ asyncio .create_task (async_impl ())
897
+
898
+ async def resume (self , cancelled = Cancelled .FALSE ):
899
+ if self .future :
900
+ assert (cancelled or self .future .done ())
901
+ self .future = None
902
+ self .store .waiting .remove (self )
903
+ self .on_resume .set_result (cancelled )
904
+ assert (not self .on_suspend_or_exit )
905
+ self .on_suspend_or_exit = asyncio .Future ()
906
+ await self .on_suspend_or_exit
907
+ self .on_suspend_or_exit = None
908
+ if self .future :
909
+ self .store .waiting .append (self )
910
+
911
+ async def suspend (self , future ) -> Cancelled :
912
+ assert (not self .future )
913
+ self .future = future
914
+ self .on_suspend_or_exit .set_result (None )
915
+ self .on_suspend_or_exit = None
916
+ assert (not self .on_resume )
917
+ self .on_resume = asyncio .Future ()
918
+ cancelled = await self .on_resume
919
+ self .on_resume = None
920
+ return cancelled
921
+
922
+ #### Store State / Embedding API
923
+
924
+ class Store :
925
+ loop : asyncio .AbstractEventLoop
926
+ waiting : list [Thread ]
927
+
928
+ def __init__ (self ):
929
+ self .loop = asyncio .new_event_loop ()
930
+ self .waiting = []
931
+
932
+ ExportCall = Thread
933
+
934
+ def start_export_call (self , lifted_func , on_start , on_resolve ) -> ExportCall :
935
+ async def async_impl ():
936
+ caller = None
937
+ thread = Thread (self , lifted_func , caller , on_start , on_resolve )
938
+ await thread .resume ()
939
+ return thread
940
+ return self .loop .run_until_complete (async_impl ())
941
+
942
+ def tick (self ):
943
+ if not DETERMINISTIC_PROFILE :
944
+ random .shuffle (self .waiting )
945
+ for thread in self .waiting :
946
+ if thread .future .done ():
947
+ self .loop .run_until_complete (thread .resume ())
948
+ return
949
+
950
+ def export_call_finished (self , export_call : ExportCall ):
951
+ return export_call .returned
952
+
930
953
### Despecialization
931
954
932
955
def despecialize (t ):
@@ -1882,8 +1905,8 @@ def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
1882
1905
1883
1906
### `canon lift`
1884
1907
1885
- async def canon_lift (opts , inst , ft , callee , caller , on_start , on_resolve , on_block ):
1886
- task = Task (opts , inst , ft , caller , on_resolve , on_block )
1908
+ async def canon_lift (opts , inst , ft , callee , caller , thread , on_start , on_resolve ):
1909
+ task = Task (opts , inst , ft , caller , thread , on_resolve )
1887
1910
if await task .enter () == Cancelled .TRUE :
1888
1911
task .cancel ()
1889
1912
task .exit ()
@@ -1958,7 +1981,7 @@ async def call_and_trap_on_throw(callee, task, args):
1958
1981
1959
1982
async def canon_lower (opts , ft , callee , task , flat_args ):
1960
1983
trap_if (not task .inst .may_leave )
1961
- subtask = Subtask (task )
1984
+ subtask = Subtask ()
1962
1985
1963
1986
cx = LiftLowerContext (opts , task .inst , subtask )
1964
1987
flat_ft = flatten_functype (opts , ft , 'lower' )
@@ -1984,7 +2007,7 @@ def on_start():
1984
2007
def on_resolve (result ):
1985
2008
on_progress ()
1986
2009
if result is None :
1987
- assert (subtask .cancellation_requested () )
2010
+ assert (subtask .cancellation_requested )
1988
2011
if subtask .state == Subtask .State .STARTING :
1989
2012
subtask .state = Subtask .State .CANCELLED_BEFORE_STARTED
1990
2013
else :
@@ -1996,13 +2019,19 @@ def on_resolve(result):
1996
2019
nonlocal flat_results
1997
2020
flat_results = lower_flat_values (cx , max_flat_results , result , ft .result_type (), flat_args )
1998
2021
2022
+ subtask .thread = Thread (task .inst .store , callee , task , on_start , on_resolve )
2023
+ await subtask .thread .resume ()
2024
+
1999
2025
if opts .sync :
2000
- await task .call_sync (callee , on_start , on_resolve )
2026
+ if not subtask .resolved ():
2027
+ done = asyncio .Event ()
2028
+ def on_progress ():
2029
+ done .set ()
2030
+ await task .wait_on (done .wait ())
2001
2031
assert (types_match_values (flat_ft .results , flat_results ))
2002
2032
subtask .deliver_resolve ()
2003
2033
return flat_results
2004
2034
else :
2005
- await subtask .call_async (callee , on_start , on_resolve )
2006
2035
if subtask .resolved ():
2007
2036
assert (flat_results == [])
2008
2037
subtask .deliver_resolve ()
@@ -2182,11 +2211,12 @@ async def canon_subtask_cancel(sync, task, i):
2182
2211
subtask = task .inst .table .get (i )
2183
2212
trap_if (not isinstance (subtask , Subtask ))
2184
2213
trap_if (subtask .resolve_delivered ())
2185
- trap_if (subtask .cancellation_requested () )
2214
+ trap_if (subtask .cancellation_requested )
2186
2215
if subtask .resolved ():
2187
2216
assert (subtask .has_pending_event ())
2188
2217
else :
2189
- await subtask .request_cancel ()
2218
+ subtask .cancellation_requested = True
2219
+ await subtask .thread .resume (Cancelled .TRUE )
2190
2220
if sync :
2191
2221
while not subtask .resolved ():
2192
2222
if subtask .has_pending_event ():
0 commit comments