diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8223ccfe4ca0..95c863407b37 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -233,6 +233,8 @@ "builtins.memoryview", } +POISON_KEY: Final = (-1,) + class TooManyUnions(Exception): """Indicates that we need to stop splitting unions in an attempt @@ -356,7 +358,12 @@ def __init__( self._arg_infer_context_cache = None + self.overload_stack_depth = 0 + self._args_cache: dict[tuple[int, ...], list[Type]] = {} + def reset(self) -> None: + assert self.overload_stack_depth == 0 + assert not self._args_cache self.resolved_type = {} def visit_name_expr(self, e: NameExpr) -> Type: @@ -1613,9 +1620,10 @@ def check_call( object_type, ) elif isinstance(callee, Overloaded): - return self.check_overload_call( - callee, args, arg_kinds, arg_names, callable_name, object_type, context - ) + with self.overload_context(): + return self.check_overload_call( + callee, args, arg_kinds, arg_names, callable_name, object_type, context + ) elif isinstance(callee, AnyType) or not self.chk.in_checked_function(): return self.check_any_type_call(args, callee) elif isinstance(callee, UnionType): @@ -1674,6 +1682,14 @@ def check_call( else: return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error) + @contextmanager + def overload_context(self) -> Iterator[None]: + self.overload_stack_depth += 1 + yield + self.overload_stack_depth -= 1 + if self.overload_stack_depth == 0: + self._args_cache.clear() + def check_callable_call( self, callee: CallableType, @@ -1937,6 +1953,17 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] In short, we basically recurse on each argument without considering in what context the argument was called. """ + # We can only use this hack locally while checking a single nested overloaded + # call. This saves a lot of rechecking, but is not generally safe. Cache is + # pruned upon leaving the outermost overload. + can_cache = ( + self.overload_stack_depth > 0 + and POISON_KEY not in self._args_cache + and not any(isinstance(t, TempNode) for t in args) + ) + key = tuple(map(id, args)) + if can_cache and key in self._args_cache: + return self._args_cache[key] res: list[Type] = [] for arg in args: @@ -1945,6 +1972,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type] res.append(NoneType()) else: res.append(arg_type) + if can_cache: + self._args_cache[key] = res return res def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool: @@ -2917,17 +2946,16 @@ def infer_overload_return_type( for typ in plausible_targets: assert self.msg is self.chk.msg - with self.msg.filter_errors() as w: - with self.chk.local_type_map() as m: - ret_type, infer_type = self.check_call( - callee=typ, - args=args, - arg_kinds=arg_kinds, - arg_names=arg_names, - context=context, - callable_name=callable_name, - object_type=object_type, - ) + with self.msg.filter_errors() as w, self.chk.local_type_map() as m: + ret_type, infer_type = self.check_call( + callee=typ, + args=args, + arg_kinds=arg_kinds, + arg_names=arg_names, + context=context, + callable_name=callable_name, + object_type=object_type, + ) is_match = not w.has_new_errors() if is_match: # Return early if possible; otherwise record info, so we can @@ -3474,6 +3502,7 @@ def visit_op_expr(self, e: OpExpr) -> Type: return self.strfrm_checker.check_str_interpolation(e.left, e.right) if isinstance(e.left, StrExpr): return self.strfrm_checker.check_str_interpolation(e.left, e.right) + left_type = self.accept(e.left) proper_left_type = get_proper_type(left_type) @@ -5401,6 +5430,9 @@ def find_typeddict_context( def visit_lambda_expr(self, e: LambdaExpr) -> Type: """Type check lambda expression.""" + if self.overload_stack_depth > 0: + # Poison cache when we encounter lambdas - it isn't safe to cache their types. + self._args_cache[POISON_KEY] = [] self.chk.check_default_args(e, body_is_trivial=False) inferred_type, type_override = self.infer_lambda_type_using_context(e) if not inferred_type: