Skip to content

WIP: try to cache inner contexts of overloads #19408

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 32 additions & 14 deletions mypy/checkexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,9 @@ def __init__(

self._arg_infer_context_cache = None

self.overload_stack_depth = 0
self._args_cache: dict[tuple[int, ...], list[Type]] = {}

def reset(self) -> None:
self.resolved_type = {}

Expand Down Expand Up @@ -1613,9 +1616,10 @@ def check_call(
object_type,
)
elif isinstance(callee, Overloaded):
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
with self.overload_context():
return self.check_overload_call(
callee, args, arg_kinds, arg_names, callable_name, object_type, context
)
elif isinstance(callee, AnyType) or not self.chk.in_checked_function():
return self.check_any_type_call(args, callee)
elif isinstance(callee, UnionType):
Expand Down Expand Up @@ -1674,6 +1678,14 @@ def check_call(
else:
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)

@contextmanager
def overload_context(self) -> Iterator[None]:
self.overload_stack_depth += 1
yield
self.overload_stack_depth -= 1
if self.overload_stack_depth == 0:
self._args_cache.clear()

def check_callable_call(
self,
callee: CallableType,
Expand Down Expand Up @@ -1937,6 +1949,10 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
In short, we basically recurse on each argument without considering
in what context the argument was called.
"""
can_cache = not any(isinstance(t, TempNode) for t in args)
key = tuple(map(id, args))
if can_cache and key in self._args_cache:
return self._args_cache[key]
res: list[Type] = []

for arg in args:
Expand All @@ -1945,6 +1961,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
res.append(NoneType())
else:
res.append(arg_type)
if can_cache:
self._args_cache[key] = res
return res

def infer_more_unions_for_recursive_type(self, type_context: Type) -> bool:
Expand Down Expand Up @@ -2917,17 +2935,16 @@ def infer_overload_return_type(

for typ in plausible_targets:
assert self.msg is self.chk.msg
with self.msg.filter_errors() as w:
with self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
with self.msg.filter_errors() as w, self.chk.local_type_map() as m:
ret_type, infer_type = self.check_call(
callee=typ,
args=args,
arg_kinds=arg_kinds,
arg_names=arg_names,
context=context,
callable_name=callable_name,
object_type=object_type,
)
is_match = not w.has_new_errors()
if is_match:
# Return early if possible; otherwise record info, so we can
Expand Down Expand Up @@ -3474,6 +3491,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
if isinstance(e.left, StrExpr):
return self.strfrm_checker.check_str_interpolation(e.left, e.right)

left_type = self.accept(e.left)

proper_left_type = get_proper_type(left_type)
Expand Down