Skip to content

Commit 8d5e885

Browse files
committed
Tweak async lift/lower ABI to be more like sync ABI
1 parent cdad9c3 commit 8d5e885

File tree

3 files changed

+53
-95
lines changed

3 files changed

+53
-95
lines changed

design/mvp/CanonicalABI.md

Lines changed: 32 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1525,35 +1525,39 @@ MAX_FLAT_PARAMS = 16
15251525
MAX_FLAT_RESULTS = 1
15261526

15271527
def flatten_functype(opts, ft, context):
1528+
flat_params = flatten_types(ft.param_types())
1529+
flat_results = flatten_types(ft.result_types())
15281530
if opts.sync:
1529-
flat_params = flatten_types(ft.param_types())
15301531
if len(flat_params) > MAX_FLAT_PARAMS:
15311532
flat_params = ['i32']
1532-
1533-
flat_results = flatten_types(ft.result_types())
15341533
if len(flat_results) > MAX_FLAT_RESULTS:
15351534
match context:
15361535
case 'lift':
15371536
flat_results = ['i32']
15381537
case 'lower':
15391538
flat_params += ['i32']
15401539
flat_results = []
1541-
15421540
return CoreFuncType(flat_params, flat_results)
15431541
else:
15441542
match context:
15451543
case 'lift':
15461544
flat_params = []
15471545
flat_results = []
15481546
case 'lower':
1549-
flat_params = ['i32', 'i32']
1547+
if len(flat_params) > 1:
1548+
flat_params = ['i32']
1549+
if len(flat_results) > 0:
1550+
flat_params += ['i32']
15501551
flat_results = ['i32']
15511552
return CoreFuncType(flat_params, flat_results)
15521553

15531554
def flatten_types(ts):
15541555
return [ft for t in ts for ft in flatten_type(t)]
15551556
```
1556-
As shown here, the core signatures `async` functions are fixed and don't vary
1557+
As shown here, the core signatures `async` functions use a lower limit on the
1558+
maximum number of parameters (1) and results (0) passed as scalars before
1559+
falling back to passing through memory.
1560+
15571561
based on the function type (parameters and results are passed through memory
15581562
pointed to by the fixed `i32` parameters).
15591563

@@ -1859,32 +1863,17 @@ def lower_flat_flags(v, labels):
18591863

18601864
### Lifting and Lowering Values
18611865

1862-
The `lift_(sync|async)_values` functions define how to lift a list of core
1863-
parameters or results (given by the `CoreValueIter` `vi`) into a tuple of
1864-
component-level values with types `ts`. The sync and async variants differ in
1865-
how much they can pass in scalar "registers" before falling back to passing
1866-
values through linear memory: sync functions use up to `max_flat` scalars
1867-
whereas async functions have a single fixed `i32` that is either a single
1868-
scalar value or a pointer into linear memory:
1866+
The `lift_flat_values` function defines how to lift a list of core
1867+
parameters or results (given by the `CoreValueIter` `vi`) into a tuple
1868+
of component-level values with types `ts`.
18691869
```python
1870-
def lift_sync_values(cx, max_flat, vi, ts):
1870+
def lift_flat_values(cx, max_flat, vi, ts):
18711871
flat_types = flatten_types(ts)
18721872
if len(flat_types) > max_flat:
18731873
return lift_heap_values(cx, vi, ts)
18741874
else:
18751875
return [ lift_flat(cx, vi, t) for t in ts ]
18761876

1877-
def lift_async_values(cx, vi, ts):
1878-
if len(ts) == 0:
1879-
_ = vi.next('i32')
1880-
return []
1881-
flat_types = flatten_types(ts)
1882-
if len(flat_types) == 1 and flat_types[0] == 'i32':
1883-
assert(len(ts) == 1)
1884-
return [ lift_flat(cx, vi, ts[0]) ]
1885-
else:
1886-
return lift_heap_values(cx, vi, ts)
1887-
18881877
def lift_heap_values(cx, vi, ts):
18891878
ptr = vi.next('i32')
18901879
tuple_type = Tuple(ts)
@@ -1893,36 +1882,26 @@ def lift_heap_values(cx, vi, ts):
18931882
return list(load(cx, ptr, tuple_type).values())
18941883
```
18951884

1896-
Symmetrically, the `lower_(sync|async)_values` functions define how to lower a
1897-
list of component-level values `vs` of types `ts` into a list of core values.
1898-
As already described for [`flatten`](#flattening) above, lowering handles the
1899-
greater-than-`max_flat` case by either allocating storage with `realloc` or
1900-
accepting a caller-allocated buffer as an out-param:
1885+
Symmetrically, the `lower_flat_values` function defines how to lower a
1886+
list of component-level values `vs` of types `ts` into a list of core
1887+
values. As already described for [`flatten`](#flattening) above,
1888+
lowering handles the greater-than-`max_flat` case by either allocating
1889+
storage with `realloc` or accepting a caller-allocated buffer as an
1890+
out-param:
19011891
```python
1902-
def lower_sync_values(cx, max_flat, vs, ts, out_param = None):
1903-
inst = cx.inst
1904-
assert(inst.may_leave)
1905-
inst.may_leave = False
1892+
def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
1893+
assert(cx.inst.may_leave)
1894+
cx.inst.may_leave = False
19061895
flat_types = flatten_types(ts)
19071896
if len(flat_types) > max_flat:
19081897
flat_vals = lower_heap_values(cx, vs, ts, out_param)
19091898
else:
19101899
flat_vals = []
19111900
for i in range(len(vs)):
19121901
flat_vals += lower_flat(cx, vs[i], ts[i])
1913-
inst.may_leave = True
1902+
cx.inst.may_leave = True
19141903
return flat_vals
19151904

1916-
def lower_async_values(cx, vs, ts, out_param):
1917-
if len(ts) == 0:
1918-
_ = out_param.next('i32')
1919-
return
1920-
inst = cx.inst
1921-
assert(inst.may_leave)
1922-
inst.may_leave = False
1923-
lower_heap_values(cx, vs, ts, out_param)
1924-
inst.may_leave = True
1925-
19261905
def lower_heap_values(cx, vs, ts, out_param):
19271906
tuple_type = Tuple(ts)
19281907
tuple_value = {str(i): v for i,v in enumerate(vs)}
@@ -1999,9 +1978,9 @@ async def canon_lift(opts, inst, callee, ft, caller, on_block, on_start, on_retu
19991978
task = SyncTask(opts, inst, caller, on_block)
20001979
await task.enter()
20011980

2002-
flat_args = lower_sync_values(task, MAX_FLAT_PARAMS, on_start(), ft.param_types())
1981+
flat_args = lower_flat_values(task, MAX_FLAT_PARAMS, on_start(), ft.param_types())
20031982
flat_results = await call_and_trap_on_throw(callee, task, flat_args)
2004-
on_return(lift_sync_values(task, MAX_FLAT_RESULTS, CoreValueIter(flat_results), ft.result_types()))
1983+
on_return(lift_flat_values(task, MAX_FLAT_RESULTS, CoreValueIter(flat_results), ft.result_types()))
20051984

20061985
if opts.post_return is not None:
20071986
[] = await call_and_trap_on_throw(opts.post_return, task, flat_results)
@@ -2095,10 +2074,10 @@ async def canon_lower(opts, callee, ft, task, flat_args):
20952074
task.on_block()
20962075
task.on_block = None
20972076
def on_start():
2098-
return lift_sync_values(subtask, MAX_FLAT_PARAMS, flat_args, ft.param_types())
2077+
return lift_flat_values(subtask, MAX_FLAT_PARAMS, flat_args, ft.param_types())
20992078
def on_return(results):
21002079
nonlocal flat_results
2101-
flat_results = lower_sync_values(subtask, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args)
2080+
flat_results = lower_flat_values(subtask, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args)
21022081
await callee(task, on_block, on_start, on_return)
21032082
task.inst.calling_sync_import = False
21042083
subtask.finish()
@@ -2110,10 +2089,10 @@ async def canon_lower(opts, callee, ft, task, flat_args):
21102089
eager_result.set_result('block')
21112090
def on_start():
21122091
subtask.start()
2113-
return lift_async_values(subtask, flat_args, ft.param_types())
2092+
return lift_flat_values(subtask, 1, flat_args, ft.param_types())
21142093
def on_return(results):
21152094
subtask.return_()
2116-
lower_async_values(subtask, results, ft.result_types(), flat_args)
2095+
lower_flat_values(subtask, 0, results, ft.result_types(), flat_args)
21172096
await callee(task, on_block, on_start, on_return)
21182097
subtask.finish()
21192098
if not eager_result.done():
@@ -2313,7 +2292,7 @@ async def canon_task_start(task, core_ft, flat_args):
23132292
trap_if(core_ft != flatten_functype(CanonicalOptions(), FuncType([], task.ft.params), 'lower'))
23142293
task.start()
23152294
args = task.on_start()
2316-
flat_results = lower_sync_values(task, MAX_FLAT_RESULTS, args, task.ft.param_types(), CoreValueIter(flat_args))
2295+
flat_results = lower_flat_values(task, MAX_FLAT_RESULTS, args, task.ft.param_types(), CoreValueIter(flat_args))
23172296
assert(len(core_ft.results) == len(flat_results))
23182297
return flat_results
23192298
```
@@ -2342,7 +2321,7 @@ async def canon_task_return(task, core_ft, flat_args):
23422321
trap_if(task.opts.sync)
23432322
trap_if(core_ft != flatten_functype(CanonicalOptions(), FuncType(task.ft.results, []), 'lower'))
23442323
task.return_()
2345-
results = lift_sync_values(task, MAX_FLAT_PARAMS, CoreValueIter(flat_args), task.ft.result_types())
2324+
results = lift_flat_values(task, MAX_FLAT_PARAMS, CoreValueIter(flat_args), task.ft.result_types())
23462325
task.on_return(results)
23472326
assert(len(core_ft.results) == 0)
23482327
return []

design/mvp/canonical-abi/definitions.py

Lines changed: 19 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1055,28 +1055,29 @@ def lower_borrow(cx, rep, t):
10551055
MAX_FLAT_RESULTS = 1
10561056

10571057
def flatten_functype(opts, ft, context):
1058+
flat_params = flatten_types(ft.param_types())
1059+
flat_results = flatten_types(ft.result_types())
10581060
if opts.sync:
1059-
flat_params = flatten_types(ft.param_types())
10601061
if len(flat_params) > MAX_FLAT_PARAMS:
10611062
flat_params = ['i32']
1062-
1063-
flat_results = flatten_types(ft.result_types())
10641063
if len(flat_results) > MAX_FLAT_RESULTS:
10651064
match context:
10661065
case 'lift':
10671066
flat_results = ['i32']
10681067
case 'lower':
10691068
flat_params += ['i32']
10701069
flat_results = []
1071-
10721070
return CoreFuncType(flat_params, flat_results)
10731071
else:
10741072
match context:
10751073
case 'lift':
10761074
flat_params = []
10771075
flat_results = []
10781076
case 'lower':
1079-
flat_params = ['i32', 'i32']
1077+
if len(flat_params) > 1:
1078+
flat_params = ['i32']
1079+
if len(flat_results) > 0:
1080+
flat_params += ['i32']
10801081
flat_results = ['i32']
10811082
return CoreFuncType(flat_params, flat_results)
10821083

@@ -1299,55 +1300,33 @@ def lower_flat_flags(v, labels):
12991300

13001301
### Lifting and Lowering Values
13011302

1302-
def lift_sync_values(cx, max_flat, vi, ts):
1303+
def lift_flat_values(cx, max_flat, vi, ts):
13031304
flat_types = flatten_types(ts)
13041305
if len(flat_types) > max_flat:
13051306
return lift_heap_values(cx, vi, ts)
13061307
else:
13071308
return [ lift_flat(cx, vi, t) for t in ts ]
13081309

1309-
def lift_async_values(cx, vi, ts):
1310-
if len(ts) == 0:
1311-
_ = vi.next('i32')
1312-
return []
1313-
flat_types = flatten_types(ts)
1314-
if len(flat_types) == 1 and flat_types[0] == 'i32':
1315-
assert(len(ts) == 1)
1316-
return [ lift_flat(cx, vi, ts[0]) ]
1317-
else:
1318-
return lift_heap_values(cx, vi, ts)
1319-
13201310
def lift_heap_values(cx, vi, ts):
13211311
ptr = vi.next('i32')
13221312
tuple_type = Tuple(ts)
13231313
trap_if(ptr != align_to(ptr, alignment(tuple_type)))
13241314
trap_if(ptr + elem_size(tuple_type) > len(cx.opts.memory))
13251315
return list(load(cx, ptr, tuple_type).values())
13261316

1327-
def lower_sync_values(cx, max_flat, vs, ts, out_param = None):
1328-
inst = cx.inst
1329-
assert(inst.may_leave)
1330-
inst.may_leave = False
1317+
def lower_flat_values(cx, max_flat, vs, ts, out_param = None):
1318+
assert(cx.inst.may_leave)
1319+
cx.inst.may_leave = False
13311320
flat_types = flatten_types(ts)
13321321
if len(flat_types) > max_flat:
13331322
flat_vals = lower_heap_values(cx, vs, ts, out_param)
13341323
else:
13351324
flat_vals = []
13361325
for i in range(len(vs)):
13371326
flat_vals += lower_flat(cx, vs[i], ts[i])
1338-
inst.may_leave = True
1327+
cx.inst.may_leave = True
13391328
return flat_vals
13401329

1341-
def lower_async_values(cx, vs, ts, out_param):
1342-
if len(ts) == 0:
1343-
_ = out_param.next('i32')
1344-
return
1345-
inst = cx.inst
1346-
assert(inst.may_leave)
1347-
inst.may_leave = False
1348-
lower_heap_values(cx, vs, ts, out_param)
1349-
inst.may_leave = True
1350-
13511330
def lower_heap_values(cx, vs, ts, out_param):
13521331
tuple_type = Tuple(ts)
13531332
tuple_value = {str(i): v for i,v in enumerate(vs)}
@@ -1369,9 +1348,9 @@ async def canon_lift(opts, inst, callee, ft, caller, on_block, on_start, on_retu
13691348
task = SyncTask(opts, inst, caller, on_block)
13701349
await task.enter()
13711350

1372-
flat_args = lower_sync_values(task, MAX_FLAT_PARAMS, on_start(), ft.param_types())
1351+
flat_args = lower_flat_values(task, MAX_FLAT_PARAMS, on_start(), ft.param_types())
13731352
flat_results = await call_and_trap_on_throw(callee, task, flat_args)
1374-
on_return(lift_sync_values(task, MAX_FLAT_RESULTS, CoreValueIter(flat_results), ft.result_types()))
1353+
on_return(lift_flat_values(task, MAX_FLAT_RESULTS, CoreValueIter(flat_results), ft.result_types()))
13751354

13761355
if opts.post_return is not None:
13771356
[] = await call_and_trap_on_throw(opts.post_return, task, flat_results)
@@ -1419,10 +1398,10 @@ def on_block():
14191398
task.on_block()
14201399
task.on_block = None
14211400
def on_start():
1422-
return lift_sync_values(subtask, MAX_FLAT_PARAMS, flat_args, ft.param_types())
1401+
return lift_flat_values(subtask, MAX_FLAT_PARAMS, flat_args, ft.param_types())
14231402
def on_return(results):
14241403
nonlocal flat_results
1425-
flat_results = lower_sync_values(subtask, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args)
1404+
flat_results = lower_flat_values(subtask, MAX_FLAT_RESULTS, results, ft.result_types(), flat_args)
14261405
await callee(task, on_block, on_start, on_return)
14271406
task.inst.calling_sync_import = False
14281407
subtask.finish()
@@ -1434,10 +1413,10 @@ def on_block():
14341413
eager_result.set_result('block')
14351414
def on_start():
14361415
subtask.start()
1437-
return lift_async_values(subtask, flat_args, ft.param_types())
1416+
return lift_flat_values(subtask, 1, flat_args, ft.param_types())
14381417
def on_return(results):
14391418
subtask.return_()
1440-
lower_async_values(subtask, results, ft.result_types(), flat_args)
1419+
lower_flat_values(subtask, 0, results, ft.result_types(), flat_args)
14411420
await callee(task, on_block, on_start, on_return)
14421421
subtask.finish()
14431422
if not eager_result.done():
@@ -1510,7 +1489,7 @@ async def canon_task_start(task, core_ft, flat_args):
15101489
trap_if(core_ft != flatten_functype(CanonicalOptions(), FuncType([], task.ft.params), 'lower'))
15111490
task.start()
15121491
args = task.on_start()
1513-
flat_results = lower_sync_values(task, MAX_FLAT_RESULTS, args, task.ft.param_types(), CoreValueIter(flat_args))
1492+
flat_results = lower_flat_values(task, MAX_FLAT_RESULTS, args, task.ft.param_types(), CoreValueIter(flat_args))
15141493
assert(len(core_ft.results) == len(flat_results))
15151494
return flat_results
15161495

@@ -1521,7 +1500,7 @@ async def canon_task_return(task, core_ft, flat_args):
15211500
trap_if(task.opts.sync)
15221501
trap_if(core_ft != flatten_functype(CanonicalOptions(), FuncType(task.ft.results, []), 'lower'))
15231502
task.return_()
1524-
results = lift_sync_values(task, MAX_FLAT_PARAMS, CoreValueIter(flat_args), task.ft.result_types())
1503+
results = lift_flat_values(task, MAX_FLAT_PARAMS, CoreValueIter(flat_args), task.ft.result_types())
15251504
task.on_return(results)
15261505
assert(len(core_ft.results) == 0)
15271506
return []

design/mvp/canonical-abi/run_tests.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,13 +359,13 @@ async def callee(task, x):
359359

360360
asyncio.run(caller_task.enter())
361361

362-
flat_args = lower_sync_values(caller_task, definitions.MAX_FLAT_PARAMS, [v], [t])
362+
flat_args = lower_flat_values(caller_task, definitions.MAX_FLAT_PARAMS, [v], [t])
363363
if return_in_heap:
364364
flat_args += [ caller_heap.realloc(0, 0, alignment(t), elem_size(t)) ]
365365
flat_results = asyncio.run(canon_lower(caller_opts, lifted_callee, ft, caller_task, flat_args))
366366
if return_in_heap:
367367
flat_results = [ flat_args[-1] ]
368-
[got] = lift_sync_values(caller_task, definitions.MAX_FLAT_PARAMS, CoreValueIter(flat_results), [t])
368+
[got] = lift_flat_values(caller_task, definitions.MAX_FLAT_PARAMS, CoreValueIter(flat_results), [t])
369369
caller_task.exit()
370370

371371
if got != v:

0 commit comments

Comments
 (0)