Skip to content

Commit e3ea993

Browse files
committed
Cache inner contexts of overloads and binary ops
1 parent 4a427e9 commit e3ea993

File tree

1 file changed

+59
-36
lines changed

1 file changed

+59
-36
lines changed

mypy/checkexpr.py

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -356,6 +356,9 @@ def __init__(
356356

357357
self._arg_infer_context_cache = None
358358

359+
self.overload_stack_depth = 0
360+
self._args_cache = {}
361+
359362
def reset(self) -> None:
360363
self.resolved_type = {}
361364

@@ -1613,9 +1616,10 @@ def check_call(
16131616
object_type,
16141617
)
16151618
elif isinstance(callee, Overloaded):
1616-
return self.check_overload_call(
1617-
callee, args, arg_kinds, arg_names, callable_name, object_type, context
1618-
)
1619+
with self.overload_context(callee.name()):
1620+
return self.check_overload_call(
1621+
callee, args, arg_kinds, arg_names, callable_name, object_type, context
1622+
)
16191623
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
16201624
return self.check_any_type_call(args, callee)
16211625
elif isinstance(callee, UnionType):
@@ -1674,6 +1678,14 @@ def check_call(
16741678
else:
16751679
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)
16761680

1681+
@contextmanager
1682+
def overload_context(self, fn):
1683+
self.overload_stack_depth += 1
1684+
yield
1685+
self.overload_stack_depth -= 1
1686+
if self.overload_stack_depth == 0:
1687+
self._args_cache.clear()
1688+
16771689
def check_callable_call(
16781690
self,
16791691
callee: CallableType,
@@ -1937,6 +1949,10 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
19371949
In short, we basically recurse on each argument without considering
19381950
in what context the argument was called.
19391951
"""
1952+
can_cache = not any(isinstance(t, TempNode) for t in args)
1953+
key = tuple(map(id, args))
1954+
if can_cache and key in self._args_cache:
1955+
return self._args_cache[key]
19401956
res: list[Type] = []
19411957

19421958
for arg in args:
@@ -1945,6 +1961,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
19451961
res.append(NoneType())
19461962
else:
19471963
res.append(arg_type)
1964+
if can_cache:
1965+
self._args_cache[key] = res
19481966
return res
19491967

19501968
def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool:
@@ -2917,17 +2935,16 @@ def infer_overload_return_type(
29172935

29182936
for typ in plausible_targets:
29192937
assert self.msg is self.chk.msg
2920-
with self.msg.filter_errors() as w:
2921-
with self.chk.local_type_map() as m:
2922-
ret_type, infer_type = self.check_call(
2923-
callee=typ,
2924-
args=args,
2925-
arg_kinds=arg_kinds,
2926-
arg_names=arg_names,
2927-
context=context,
2928-
callable_name=callable_name,
2929-
object_type=object_type,
2930-
)
2938+
with self.msg.filter_errors() as w, self.chk.local_type_map() as m:
2939+
ret_type, infer_type = self.check_call(
2940+
callee=typ,
2941+
args=args,
2942+
arg_kinds=arg_kinds,
2943+
arg_names=arg_names,
2944+
context=context,
2945+
callable_name=callable_name,
2946+
object_type=object_type,
2947+
)
29312948
is_match = not w.has_new_errors()
29322949
if is_match:
29332950
# Return early if possible; otherwise record info, so we can
@@ -3474,6 +3491,10 @@ def visit_op_expr(self, e: OpExpr) -> Type:
34743491
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
34753492
if isinstance(e.left, StrExpr):
34763493
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
3494+
3495+
key = id(e)
3496+
if key in self._args_cache:
3497+
return self._args_cache[key]
34773498
left_type = self.accept(e.left)
34783499

34793500
proper_left_type = get_proper_type(left_type)
@@ -3543,28 +3564,30 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35433564
)
35443565

35453566
if e.op in operators.op_methods:
3546-
method = operators.op_methods[e.op]
3547-
if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER:
3548-
result, method_type = self.check_op(
3549-
method,
3550-
base_type=left_type,
3551-
arg=e.right,
3552-
context=e,
3553-
allow_reverse=use_reverse is UseReverse.DEFAULT,
3554-
)
3555-
elif use_reverse is UseReverse.ALWAYS:
3556-
result, method_type = self.check_op(
3557-
# The reverse operator here gives better error messages:
3558-
operators.reverse_op_methods[method],
3559-
base_type=self.accept(e.right),
3560-
arg=e.left,
3561-
context=e,
3562-
allow_reverse=False,
3563-
)
3564-
else:
3565-
assert_never(use_reverse)
3566-
e.method_type = method_type
3567-
return result
3567+
with self.overload_context(e.op):
3568+
method = operators.op_methods[e.op]
3569+
if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER:
3570+
result, method_type = self.check_op(
3571+
method,
3572+
base_type=left_type,
3573+
arg=e.right,
3574+
context=e,
3575+
allow_reverse=use_reverse is UseReverse.DEFAULT,
3576+
)
3577+
elif use_reverse is UseReverse.ALWAYS:
3578+
result, method_type = self.check_op(
3579+
# The reverse operator here gives better error messages:
3580+
operators.reverse_op_methods[method],
3581+
base_type=self.accept(e.right),
3582+
arg=e.left,
3583+
context=e,
3584+
allow_reverse=False,
3585+
)
3586+
else:
3587+
assert_never(use_reverse)
3588+
e.method_type = method_type
3589+
self._args_cache[key] = result
3590+
return result
35683591
else:
35693592
raise RuntimeError(f"Unknown operator {e.op}")
35703593

0 commit comments

Comments
 (0)