Skip to content

Commit 880091f

Browse files
committed
Add cooperative threads
1 parent 8c9365c commit 880091f

File tree

2 files changed

+151
-75
lines changed

2 files changed

+151
-75
lines changed

design/mvp/canonical-abi/definitions.py

Lines changed: 93 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -355,22 +355,6 @@ def write(self, vs):
355355
assert(all(v == () for v in vs))
356356
self.progress += len(vs)
357357

358-
#### Context-Local Storage
359-
360-
class ContextLocalStorage:
361-
LENGTH = 1
362-
array: list[int]
363-
364-
def __init__(self):
365-
self.array = [0] * ContextLocalStorage.LENGTH
366-
367-
def set(self, i, v):
368-
assert(types_match_values(['i32'], [v]))
369-
self.array[i] = v
370-
371-
def get(self, i):
372-
return self.array[i]
373-
374358
#### Waitable State
375359

376360
class EventCode(IntEnum):
@@ -472,7 +456,6 @@ class State(Enum):
472456
supertask: Optional[Task]
473457
on_resolve: Callable[[Optional[list[any]]], None]
474458
num_borrows: int
475-
context: ContextLocalStorage
476459

477460
def __init__(self, opts, inst, ft, supertask, on_resolve):
478461
self.state = Task.State.INITIAL
@@ -482,7 +465,6 @@ def __init__(self, opts, inst, ft, supertask, on_resolve):
482465
self.supertask = supertask
483466
self.on_resolve = on_resolve
484467
self.num_borrows = 0
485-
self.context = ContextLocalStorage()
486468

487469
async def enter(self, thread):
488470
self.trap_if_on_the_stack(self.inst)
@@ -875,13 +857,19 @@ def drop(self):
875857

876858
class Thread:
877859
task: Task
878-
future: Optional[asyncio.Future]
860+
index: int
861+
context: list[int]
862+
future: Optional[Awaitable]
879863
on_resume: Optional[asyncio.Future]
880864
on_suspend_or_exit: Optional[asyncio.Future]
881865
returned: bool
882866

867+
CONTEXT_LENGTH = 1
868+
883869
def __init__(self, task, coro):
884870
self.task = task
871+
self.index = task.inst.table.add(self)
872+
self.context = [0] * Thread.CONTEXT_LENGTH
885873
self.future = None
886874
self.on_resume = asyncio.Future()
887875
self.on_suspend_or_exit = None
@@ -891,6 +879,7 @@ async def async_impl():
891879
self.on_resume = None
892880
await coro
893881
self.on_suspend_or_exit.set_result(None)
882+
self.task.inst.table.remove(self.index)
894883
self.returned = True
895884
asyncio.create_task(async_impl())
896885

@@ -918,6 +907,30 @@ async def suspend(self, future) -> Cancelled:
918907
self.on_resume = None
919908
return cancelled
920909

910+
async def switch(self, other: Thread) -> Cancelled:
911+
assert(not self.future and not other.future)
912+
assert(self.on_suspend_or_exit and not other.on_suspend_or_exit)
913+
other.on_suspend_or_exit = self.on_suspend_or_exit
914+
self.on_suspend_or_exit = None
915+
other.on_resume.set_result(Cancelled.FALSE)
916+
assert(not self.on_resume)
917+
self.on_resume = asyncio.Future()
918+
cancelled = await self.on_resume
919+
self.on_resume = None
920+
return cancelled
921+
922+
def yield_(self, other: Thread) -> Cancelled:
923+
# deterministically switch to other, but leave this thread unblocked
924+
TODO
925+
926+
def unblock(self, other: Thread):
927+
# unblock other, but deterministically keep running here
928+
TODO
929+
930+
def wait(self) -> Cancelled:
931+
# perform just the first half of switch
932+
TODO
933+
921934
#### Store State / Embedding API
922935

923936
class Store:
@@ -2097,19 +2110,76 @@ async def canon_resource_rep(rt, thread, i):
20972110
trap_if(h.rt is not rt)
20982111
return [h.rep]
20992112

2113+
### 🧵 `canon thread.index`
2114+
2115+
async def canon_thread_index(shared, thread):
2116+
assert(not shared)
2117+
return [thread.index]
2118+
2119+
### 🧵 `canon thread.new_indirect`
2120+
2121+
async def canon_thread_new_indirect(shared, ft, ftbl, thread, i, c):
2122+
assert(not shared)
2123+
inst = thread.task.inst
2124+
trap_if(not inst.may_leave)
2125+
f = ftbl.get(i)
2126+
trap_if(f is None)
2127+
trap_if(f.type != ft)
2128+
thread = Thread(thread.task, f(c))
2129+
return [thread.index]
2130+
2131+
### 🧵 `canon thread.switch`
2132+
2133+
async def canon_thread_switch(shared, thread, i):
2134+
assert(not shared)
2135+
trap_if(not thread.task.inst.may_leave)
2136+
other = thread.task.inst.table.get(i)
2137+
trap_if(not isinstance(other, Thread))
2138+
cancelled = await thread.switch(other)
2139+
return [ 1 if cancelled else 0 ]
2140+
2141+
### 🧵 `canon thread.yield`
2142+
2143+
async def canon_thread_yield(shared, thread, i):
2144+
assert(not shared)
2145+
trap_if(not thread.task.inst.may_leave)
2146+
other = thread.task.inst.table.get(i)
2147+
trap_if(not isinstance(other, Thread))
2148+
other.yield_(other)
2149+
return []
2150+
2151+
### 🧵 `canon thread.unblock`
2152+
2153+
async def canon_thread_unblock(shared, thread, i):
2154+
trap_if(not thread.task.inst.may_leave)
2155+
other = thread.task.inst.table.get(i)
2156+
trap_if(not isinstance(other, Thread))
2157+
thread.unblock()
2158+
return []
2159+
2160+
### 🧵 `canon thread.wait`
2161+
2162+
async def canon_thread_wait(shared, thread, i):
2163+
assert(not shared)
2164+
trap_if(not thread.task.inst.may_leave)
2165+
other = thread.task.inst.table.get(i)
2166+
trap_if(not isinstance(other, Thread))
2167+
cancelled = await thread.suspend()
2168+
return [ 1 if cancelled else 0 ]
2169+
21002170
### 🔀 `canon context.get`
21012171

21022172
async def canon_context_get(t, i, thread):
21032173
assert(t == 'i32')
2104-
assert(i < ContextLocalStorage.LENGTH)
2105-
return [thread.task.context.get(i)]
2174+
assert(i < Thread.CONTEXT_LENGTH)
2175+
return [thread.context[i]]
21062176

21072177
### 🔀 `canon context.set`
21082178

21092179
async def canon_context_set(t, i, thread, v):
21102180
assert(t == 'i32')
2111-
assert(i < ContextLocalStorage.LENGTH)
2112-
thread.task.context.set(i, v)
2181+
assert(i < Thread.CONTEXT_LENGTH)
2182+
thread.context[i] = v
21132183
return []
21142184

21152185
### 🔀 `canon backpressure.set`

0 commit comments

Comments
 (0)