Skip to content

Commit 4a427e9

Browse files
authored
[mypyc] Speed up native-to-native calls using await (#19398)
When calling a native async function using `await`, e.g. `await foo()`, avoid raising `StopIteration` to pass the return value, since this is expensive. Instead, pass an extra `PyObject **` argument to the generator helper method and use that to return the return value. This is mostly helpful when there are many calls using await that don't block (e.g. there is a fast path that is usually taken that doesn't block). When awaiting from non-compiled code, the slow path is still taken. This builds on top of #19376. This PR makes this microbenchmark about 3x faster, which is about the ideal scenario for this optimization: ``` import asyncio from time import time async def inc(x: int) -> int: return x + 1 async def bench(n: int) -> int: x = 0 for i in range(n): x = await inc(x) return x asyncio.run(bench(1000)) t0 = time() asyncio.run(bench(1000 * 1000 * 200)) print(time() - t0) ```
1 parent 503f5bd commit 4a427e9

File tree

7 files changed

+101
-11
lines changed

7 files changed

+101
-11
lines changed

mypyc/irbuild/context.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,11 @@ def __init__(self, ir: ClassIR) -> None:
167167
# Holds the arg passed to send
168168
self.send_arg_reg: Value | None = None
169169

170+
# Holds the PyObject ** pointer through which return value can be passed
171+
# instead of raising StopIteration(ret_value) (only if not NULL). This
172+
# is used for faster native-to-native calls.
173+
self.stop_iter_value_reg: Value | None = None
174+
170175
# The switch block is used to decide which instruction to go using the value held in the
171176
# next-label register.
172177
self.switch_block = BasicBlock()

mypyc/irbuild/generator.py

Lines changed: 37 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,12 @@
3232
Unreachable,
3333
Value,
3434
)
35-
from mypyc.ir.rtypes import RInstance, int32_rprimitive, object_rprimitive
35+
from mypyc.ir.rtypes import (
36+
RInstance,
37+
int32_rprimitive,
38+
object_pointer_rprimitive,
39+
object_rprimitive,
40+
)
3641
from mypyc.irbuild.builder import IRBuilder, calculate_arg_defaults, gen_arg_defaults
3742
from mypyc.irbuild.context import FuncInfo, GeneratorClass
3843
from mypyc.irbuild.env_class import (
@@ -256,7 +261,14 @@ def add_next_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl:
256261
result = builder.add(
257262
Call(
258263
fn_decl,
259-
[builder.self(), none_reg, none_reg, none_reg, none_reg],
264+
[
265+
builder.self(),
266+
none_reg,
267+
none_reg,
268+
none_reg,
269+
none_reg,
270+
Integer(0, object_pointer_rprimitive),
271+
],
260272
fn_info.fitem.line,
261273
)
262274
)
@@ -272,7 +284,14 @@ def add_send_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl:
272284
result = builder.add(
273285
Call(
274286
fn_decl,
275-
[builder.self(), none_reg, none_reg, none_reg, builder.read(arg)],
287+
[
288+
builder.self(),
289+
none_reg,
290+
none_reg,
291+
none_reg,
292+
builder.read(arg),
293+
Integer(0, object_pointer_rprimitive),
294+
],
276295
fn_info.fitem.line,
277296
)
278297
)
@@ -297,7 +316,14 @@ def add_throw_to_generator_class(builder: IRBuilder, fn_info: FuncInfo, fn_decl:
297316
result = builder.add(
298317
Call(
299318
fn_decl,
300-
[builder.self(), builder.read(typ), builder.read(val), builder.read(tb), none_reg],
319+
[
320+
builder.self(),
321+
builder.read(typ),
322+
builder.read(val),
323+
builder.read(tb),
324+
none_reg,
325+
Integer(0, object_pointer_rprimitive),
326+
],
301327
fn_info.fitem.line,
302328
)
303329
)
@@ -377,8 +403,15 @@ def setup_env_for_generator_class(builder: IRBuilder) -> None:
377403
# TODO: Use the right type here instead of object?
378404
exc_arg = builder.add_local(Var("arg"), object_rprimitive, is_arg=True)
379405

406+
# Parameter that can used to pass a pointer which can used instead of
407+
# raising StopIteration(value). If the value is NULL, this won't be used.
408+
stop_iter_value_arg = builder.add_local(
409+
Var("stop_iter_ptr"), object_pointer_rprimitive, is_arg=True
410+
)
411+
380412
cls.exc_regs = (exc_type, exc_val, exc_tb)
381413
cls.send_arg_reg = exc_arg
414+
cls.stop_iter_value_reg = stop_iter_value_arg
382415

383416
cls.self_reg = builder.read(self_target, fitem.line)
384417
if builder.fn_info.can_merge_generator_and_env_classes():

mypyc/irbuild/nonlocalcontrol.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,11 @@
1616
Integer,
1717
Register,
1818
Return,
19+
SetMem,
1920
Unreachable,
2021
Value,
2122
)
23+
from mypyc.ir.rtypes import object_rprimitive
2224
from mypyc.irbuild.targets import AssignmentTarget
2325
from mypyc.primitives.exc_ops import restore_exc_info_op, set_stop_iteration_value
2426

@@ -108,10 +110,27 @@ def gen_return(self, builder: IRBuilder, value: Value, line: int) -> None:
108110
# StopIteration instead of using RaiseStandardError because
109111
# the obvious thing doesn't work if the value is a tuple
110112
# (???).
113+
114+
true, false = BasicBlock(), BasicBlock()
115+
stop_iter_reg = builder.fn_info.generator_class.stop_iter_value_reg
116+
assert stop_iter_reg is not None
117+
118+
builder.add(Branch(stop_iter_reg, true, false, Branch.IS_ERROR))
119+
120+
builder.activate_block(true)
121+
# The default/slow path is to raise a StopIteration exception with
122+
# return value.
111123
builder.call_c(set_stop_iteration_value, [value], NO_TRACEBACK_LINE_NO)
112124
builder.add(Unreachable())
113125
builder.builder.pop_error_handler()
114126

127+
builder.activate_block(false)
128+
# The fast path is to store return value via caller-provided pointer
129+
# instead of raising an exception. This can only be used when the
130+
# caller is a native function.
131+
builder.add(SetMem(object_rprimitive, stop_iter_reg, value))
132+
builder.add(Return(Integer(0, object_rprimitive)))
133+
115134

116135
class CleanupNonlocalControl(NonlocalControl):
117136
"""Abstract nonlocal control that runs some cleanup code."""

mypyc/irbuild/prepare.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
RType,
5757
dict_rprimitive,
5858
none_rprimitive,
59+
object_pointer_rprimitive,
5960
object_rprimitive,
6061
tuple_rprimitive,
6162
)
@@ -220,6 +221,8 @@ def create_generator_class_if_needed(
220221
RuntimeArg("value", object_rprimitive),
221222
RuntimeArg("traceback", object_rprimitive),
222223
RuntimeArg("arg", object_rprimitive),
224+
# If non-NULL, used to store return value instead of raising StopIteration(retv)
225+
RuntimeArg("stop_iter_ptr", object_pointer_rprimitive),
223226
),
224227
object_rprimitive,
225228
)

mypyc/irbuild/statement.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@
103103
get_exc_info_op,
104104
get_exc_value_op,
105105
keep_propagating_op,
106+
propagate_if_error_op,
106107
raise_exception_op,
107108
reraise_exception_op,
108109
restore_exc_info_op,
@@ -958,21 +959,34 @@ def emit_yield_from_or_await(
958959

959960
if isinstance(iter_reg.type, RInstance) and iter_reg.type.class_ir.has_method(helper_method):
960961
# Second fast path optimization: call helper directly (see also comment above).
962+
#
963+
# Calling a generated generator, so avoid raising StopIteration by passing
964+
# an extra PyObject ** argument to helper where the stop iteration value is stored.
965+
fast_path = True
961966
obj = builder.read(iter_reg)
962967
nn = builder.none_object()
963-
m = MethodCall(obj, helper_method, [nn, nn, nn, nn], line)
968+
stop_iter_val = Register(object_rprimitive)
969+
err = builder.add(LoadErrorValue(object_rprimitive, undefines=True))
970+
builder.assign(stop_iter_val, err, line)
971+
ptr = builder.add(LoadAddress(object_pointer_rprimitive, stop_iter_val))
972+
m = MethodCall(obj, helper_method, [nn, nn, nn, nn, ptr], line)
964973
# Generators have custom error handling, so disable normal error handling.
965974
m.error_kind = ERR_NEVER
966975
_y_init = builder.add(m)
967976
else:
977+
fast_path = False
968978
_y_init = builder.call_c(next_raw_op, [builder.read(iter_reg)], line)
969979

970980
builder.add(Branch(_y_init, stop_block, main_block, Branch.IS_ERROR))
971981

972-
# Try extracting a return value from a StopIteration and return it.
973-
# If it wasn't, this reraises the exception.
974982
builder.activate_block(stop_block)
975-
builder.assign(result, builder.call_c(check_stop_op, [], line), line)
983+
if fast_path:
984+
builder.primitive_op(propagate_if_error_op, [stop_iter_val], line)
985+
builder.assign(result, stop_iter_val, line)
986+
else:
987+
# Try extracting a return value from a StopIteration and return it.
988+
# If it wasn't, this reraises the exception.
989+
builder.assign(result, builder.call_c(check_stop_op, [], line), line)
976990
# Clear the spilled iterator/coroutine so that it will be freed.
977991
# Otherwise, the freeing of the spilled register would likely be delayed.
978992
err = builder.add(LoadErrorValue(iter_reg.type))

mypyc/lower/misc_ops.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

3-
from mypyc.ir.ops import GetElementPtr, LoadMem, Value
4-
from mypyc.ir.rtypes import PyVarObject, c_pyssize_t_rprimitive
3+
from mypyc.ir.ops import ComparisonOp, GetElementPtr, Integer, LoadMem, Value
4+
from mypyc.ir.rtypes import PyVarObject, c_pyssize_t_rprimitive, object_rprimitive
55
from mypyc.irbuild.ll_builder import LowLevelIRBuilder
66
from mypyc.lower.registry import lower_primitive_op
77

@@ -10,3 +10,9 @@
1010
def var_object_size(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
1111
elem_address = builder.add(GetElementPtr(args[0], PyVarObject, "ob_size"))
1212
return builder.add(LoadMem(c_pyssize_t_rprimitive, elem_address))
13+
14+
15+
@lower_primitive_op("propagate_if_error")
16+
def propagate_if_error_op(builder: LowLevelIRBuilder, args: list[Value], line: int) -> Value:
17+
# Return False on NULL. The primitive uses ERR_FALSE, so this is an error.
18+
return builder.add(ComparisonOp(args[0], Integer(0, object_rprimitive), ComparisonOp.NEQ))

mypyc/primitives/exc_ops.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
from mypyc.ir.ops import ERR_ALWAYS, ERR_FALSE, ERR_NEVER
66
from mypyc.ir.rtypes import bit_rprimitive, exc_rtuple, object_rprimitive, void_rtype
7-
from mypyc.primitives.registry import custom_op
7+
from mypyc.primitives.registry import custom_op, custom_primitive_op
88

99
# If the argument is a class, raise an instance of the class. Otherwise, assume
1010
# that the argument is an exception object, and raise it.
@@ -62,6 +62,16 @@
6262
error_kind=ERR_FALSE,
6363
)
6464

65+
# If argument is NULL, propagate currently raised exception (in this case
66+
# an exception must have been raised). If this can be used, it's faster
67+
# than using PyErr_Occurred().
68+
propagate_if_error_op = custom_primitive_op(
69+
"propagate_if_error",
70+
arg_types=[object_rprimitive],
71+
return_type=bit_rprimitive,
72+
error_kind=ERR_FALSE,
73+
)
74+
6575
# Catches a propagating exception and makes it the "currently
6676
# handled exception" (by sticking it into sys.exc_info()). Returns the
6777
# exception that was previously being handled, which must be restored

0 commit comments

Comments
 (0)