From e3ea99303a1fc86f25db339bc0756db6ab66e935 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 8 Jul 2025 21:39:56 +0200 Subject: [PATCH 1/5] Cache inner contexts of overloads and binary ops --- mypy/checkexpr.py | 95 +++++++++++++++++++++++++++++------------------ 1 file changed, 59 insertions(+), 36 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8223ccfe4ca0..65d433a58a1f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -356,6 +356,9 @@ def __init__( self._arg_infer_context_cache = None + self.overload_stack_depth = 0 + self._args_cache = {} + def reset(self) -> None: self.resolved_type = {} @@ -1613,9 +1616,10 @@ def check_call( object_type, ) elif isinstance(callee, Overloaded): - return self.check_overload_call( - callee, args, arg_kinds, arg_names, callable_name, object_type, context - ) + with self.overload_context(callee.name()): + return self.check_overload_call( + callee, args, arg_kinds, arg_names, callable_name, object_type, context + ) elif isinstance(callee, AnyType) or not self.chk.in_checked_function(): return self.check_any_type_call(args, callee) elif isinstance(callee, UnionType): @@ -1674,6 +1678,14 @@ def check_call( else: return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) + @contextmanager + def overload_context(self, fn): + self.overload_stack_depth += 1 + yield + self.overload_stack_depth -= 1 + if self.overload_stack_depth == 0: + self._args_cache.clear() + def check_callable_call( self, callee: CallableType, @@ -1937,6 +1949,10 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] In short, we basically recurse on each argument without considering in what context the argument was called. """ + can_cache = not any(isinstance(t, TempNode) for t in args) + key = tuple(map(id, args)) + if can_cache and key in self._args_cache: + return self._args_cache[key] res: list[Type] = [] for arg in args: @@ -1945,6 +1961,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] res.append(NoneType()) else: res.append(arg_type) + if can_cache: + self._args_cache[key] = res return res def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool: @@ -2917,17 +2935,16 @@ def infer_overload_return_type( for typ in plausible_targets: assert self.msg is self.chk.msg - with self.msg.filter_errors() as w: - with self.chk.local_type_map() as m: - ret_type, infer_type = self.check_call( - callee=typ, - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) + with self.msg.filter_errors() as w, self.chk.local_type_map() as m: + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) is_match = not w.has_new_errors() if is_match: # Return early if possible; otherwise record info, so we can @@ -3474,6 +3491,10 @@ def visit_op_expr(self, e: OpExpr) -> Type: return self.strfrm_checker.check_str_interpolation(e.left, e.right) if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) + + key = id(e) + if key in self._args_cache: + return self._args_cache[key] left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -3543,28 +3564,30 @@ def visit_op_expr(self, e: OpExpr) -> Type: ) if e.op in operators.op_methods: - method = operators.op_methods[e.op] - if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: - result, method_type = self.check_op( - method, - base_type=left_type, - arg=e.right, - context=e, - allow_reverse=use_reverse is UseReverse.DEFAULT, - ) - elif use_reverse is UseReverse.ALWAYS: - result, method_type = self.check_op( - # The reverse operator here gives better error messages: - operators.reverse_op_methods[method], - base_type=self.accept(e.right), - arg=e.left, - context=e, - allow_reverse=False, - ) - else: - assert_never(use_reverse) - e.method_type = method_type - return result + with self.overload_context(e.op): + method = operators.op_methods[e.op] + if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: + result, method_type = self.check_op( + method, + base_type=left_type, + arg=e.right, + context=e, + allow_reverse=use_reverse is UseReverse.DEFAULT, + ) + elif use_reverse is UseReverse.ALWAYS: + result, method_type = self.check_op( + # The reverse operator here gives better error messages: + operators.reverse_op_methods[method], + base_type=self.accept(e.right), + arg=e.left, + context=e, + allow_reverse=False, + ) + else: + assert_never(use_reverse) + e.method_type = method_type + self._args_cache[key] = result + return result else: raise RuntimeError(f"Unknown operator {e.op}") From 03af5b21d2c1c93b25fdd381687cc722bcf1c062 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Tue, 8 Jul 2025 21:59:14 +0200 Subject: [PATCH 2/5] Fix selfcheck --- mypy/checkexpr.py | 24 +++++++++++++++++------- 1 file changed, 17 insertions(+), 7 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 65d433a58a1f..89d0d0d75b61 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -357,7 +357,9 @@ def __init__( self._arg_infer_context_cache = None self.overload_stack_depth = 0 - self._args_cache = {} + self.ops_stack_depth = 0 + self._args_cache: dict[tuple[int, ...], list[Type]] = {} + self._ops_cache: dict[int, Type] = {} def reset(self) -> None: self.resolved_type = {} @@ -1616,7 +1618,7 @@ def check_call( object_type, ) elif isinstance(callee, Overloaded): - with self.overload_context(callee.name()): + with self.overload_context(): return self.check_overload_call( callee, args, arg_kinds, arg_names, callable_name, object_type, context ) @@ -1679,13 +1681,21 @@ def check_call( return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) @contextmanager - def overload_context(self, fn): + def overload_context(self) -> Iterator[None]: self.overload_stack_depth += 1 yield self.overload_stack_depth -= 1 if self.overload_stack_depth == 0: self._args_cache.clear() + @contextmanager + def ops_context(self) -> Iterator[None]: + self.ops_stack_depth += 1 + yield + self.ops_stack_depth -= 1 + if self.ops_stack_depth == 0: + self._ops_cache.clear() + def check_callable_call( self, callee: CallableType, @@ -3493,8 +3503,8 @@ def visit_op_expr(self, e: OpExpr) -> Type: return self.strfrm_checker.check_str_interpolation(e.left, e.right) key = id(e) - if key in self._args_cache: - return self._args_cache[key] + if key in self._ops_cache: + return self._ops_cache[key] left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -3564,7 +3574,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: ) if e.op in operators.op_methods: - with self.overload_context(e.op): + with self.ops_context(): method = operators.op_methods[e.op] if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: result, method_type = self.check_op( @@ -3586,7 +3596,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: else: assert_never(use_reverse) e.method_type = method_type - self._args_cache[key] = result + self._ops_cache[key] = result return result else: raise RuntimeError(f"Unknown operator {e.op}") From 665a3119605ca9b3909ac37d2320385cef1f64ef Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 9 Jul 2025 00:19:08 +0200 Subject: [PATCH 3/5] Only retain the overloads part --- mypy/checkexpr.py | 59 ++++++++++++++++++----------------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 89d0d0d75b61..b7292d8c82c3 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -357,9 +357,7 @@ def __init__( self._arg_infer_context_cache = None self.overload_stack_depth = 0 - self.ops_stack_depth = 0 self._args_cache: dict[tuple[int, ...], list[Type]] = {} - self._ops_cache: dict[int, Type] = {} def reset(self) -> None: self.resolved_type = {} @@ -1688,14 +1686,6 @@ def overload_context(self) -> Iterator[None]: if self.overload_stack_depth == 0: self._args_cache.clear() - @contextmanager - def ops_context(self) -> Iterator[None]: - self.ops_stack_depth += 1 - yield - self.ops_stack_depth -= 1 - if self.ops_stack_depth == 0: - self._ops_cache.clear() - def check_callable_call( self, callee: CallableType, @@ -3502,9 +3492,6 @@ def visit_op_expr(self, e: OpExpr) -> Type: if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) - key = id(e) - if key in self._ops_cache: - return self._ops_cache[key] left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -3574,30 +3561,28 @@ def visit_op_expr(self, e: OpExpr) -> Type: ) if e.op in operators.op_methods: - with self.ops_context(): - method = operators.op_methods[e.op] - if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: - result, method_type = self.check_op( - method, - base_type=left_type, - arg=e.right, - context=e, - allow_reverse=use_reverse is UseReverse.DEFAULT, - ) - elif use_reverse is UseReverse.ALWAYS: - result, method_type = self.check_op( - # The reverse operator here gives better error messages: - operators.reverse_op_methods[method], - base_type=self.accept(e.right), - arg=e.left, - context=e, - allow_reverse=False, - ) - else: - assert_never(use_reverse) - e.method_type = method_type - self._ops_cache[key] = result - return result + method = operators.op_methods[e.op] + if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER: + result, method_type = self.check_op( + method, + base_type=left_type, + arg=e.right, + context=e, + allow_reverse=use_reverse is UseReverse.DEFAULT, + ) + elif use_reverse is UseReverse.ALWAYS: + result, method_type = self.check_op( + # The reverse operator here gives better error messages: + operators.reverse_op_methods[method], + base_type=self.accept(e.right), + arg=e.left, + context=e, + allow_reverse=False, + ) + else: + assert_never(use_reverse) + e.method_type = method_type + return result else: raise RuntimeError(f"Unknown operator {e.op}") From a1f8c7721e28d42089353420b727a5f4ee4ff099 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 9 Jul 2025 16:57:36 +0200 Subject: [PATCH 4/5] Fix: the cache should not be touched outside of overloads --- mypy/checkexpr.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index b7292d8c82c3..ce1c58c38b21 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -360,6 +360,8 @@ def __init__( self._args_cache: dict[tuple[int, ...], list[Type]] = {} def reset(self) -> None: + assert self.overload_stack_depth == 0 + assert not self._args_cache self.resolved_type = {} def visit_name_expr(self, e: NameExpr) -> Type: @@ -1949,7 +1951,12 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] In short, we basically recurse on each argument without considering in what context the argument was called. """ - can_cache = not any(isinstance(t, TempNode) for t in args) + # We can only use this hack locally while checking a single nested overloaded + # call. This saves a lot of rechecking, but is not generally safe. Cache is + # pruned upon leaving the outermost overload. + can_cache = self.overload_stack_depth > 0 and not any( + isinstance(t, TempNode) for t in args + ) key = tuple(map(id, args)) if can_cache and key in self._args_cache: return self._args_cache[key] From 479550b97e33b0ef0c361042cc9610e4f77581c0 Mon Sep 17 00:00:00 2001 From: STerliakov Date: Wed, 9 Jul 2025 18:14:05 +0200 Subject: [PATCH 5/5] Poison cache when we encounter any lambda --- mypy/checkexpr.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index ce1c58c38b21..95c863407b37 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -233,6 +233,8 @@ "builtins.memoryview", } +POISON_KEY: Final = (-1,) + class TooManyUnions(Exception): """Indicates that we need to stop splitting unions in an attempt @@ -1954,8 +1956,10 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] # We can only use this hack locally while checking a single nested overloaded # call. This saves a lot of rechecking, but is not generally safe. Cache is # pruned upon leaving the outermost overload. - can_cache = self.overload_stack_depth > 0 and not any( - isinstance(t, TempNode) for t in args + can_cache = ( + self.overload_stack_depth > 0 + and POISON_KEY not in self._args_cache + and not any(isinstance(t, TempNode) for t in args) ) key = tuple(map(id, args)) if can_cache and key in self._args_cache: @@ -5426,6 +5430,9 @@ def find_typeddict_context( def visit_lambda_expr(self, e: LambdaExpr) -> Type: """Type check lambda expression.""" + if self.overload_stack_depth > 0: + # Poison cache when we encounter lambdas - it isn't safe to cache their types. + self._args_cache[POISON_KEY] = [] self.chk.check_default_args(e, body_is_trivial=False) inferred_type, type_override = self.infer_lambda_type_using_context(e) if not inferred_type: