@@ -360,6 +360,8 @@ def __init__(
360
360
self ._args_cache : dict [tuple [int , ...], list [Type ]] = {}
361
361
362
362
def reset (self ) -> None :
363
+ assert self .overload_stack_depth == 0
364
+ assert not self ._args_cache
363
365
self .resolved_type = {}
364
366
365
367
def visit_name_expr (self , e : NameExpr ) -> Type :
@@ -1682,7 +1684,7 @@ def check_call(
1682
1684
def overload_context (self ) -> Iterator [None ]:
1683
1685
self .overload_stack_depth += 1
1684
1686
yield
1685
- self .overload_stack_depth - = 1
1687
+ self .overload_stack_depth + = 1
1686
1688
if self .overload_stack_depth == 0 :
1687
1689
self ._args_cache .clear ()
1688
1690
@@ -1949,7 +1951,12 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
1949
1951
In short, we basically recurse on each argument without considering
1950
1952
in what context the argument was called.
1951
1953
"""
1952
- can_cache = not any (isinstance (t , TempNode ) for t in args )
1954
+ # We can only use this hack locally while checking a single nested overloaded
1955
+ # call. This saves a lot of rechecking, but is not generally safe. Cache is
1956
+ # pruned upon leaving the outermost overload.
1957
+ can_cache = self .overload_stack_depth > 0 and not any (
1958
+ isinstance (t , TempNode ) for t in args
1959
+ )
1953
1960
key = tuple (map (id , args ))
1954
1961
if can_cache and key in self ._args_cache :
1955
1962
return self ._args_cache [key ]
0 commit comments