@@ -356,6 +356,9 @@ def __init__(
356
356
357
357
self ._arg_infer_context_cache = None
358
358
359
+ self .overload_stack_depth = 0
360
+ self ._args_cache = {}
361
+
359
362
def reset (self ) -> None :
360
363
self .resolved_type = {}
361
364
@@ -1613,9 +1616,10 @@ def check_call(
1613
1616
object_type ,
1614
1617
)
1615
1618
elif isinstance (callee , Overloaded ):
1616
- return self .check_overload_call (
1617
- callee , args , arg_kinds , arg_names , callable_name , object_type , context
1618
- )
1619
+ with self .overload_context (callee .name ()):
1620
+ return self .check_overload_call (
1621
+ callee , args , arg_kinds , arg_names , callable_name , object_type , context
1622
+ )
1619
1623
elif isinstance (callee , AnyType ) or not self .chk .in_checked_function ():
1620
1624
return self .check_any_type_call (args , callee )
1621
1625
elif isinstance (callee , UnionType ):
@@ -1674,6 +1678,14 @@ def check_call(
1674
1678
else :
1675
1679
return self .msg .not_callable (callee , context ), AnyType (TypeOfAny .from_error )
1676
1680
1681
+ @contextmanager
1682
+ def overload_context (self , fn ):
1683
+ self .overload_stack_depth += 1
1684
+ yield
1685
+ self .overload_stack_depth -= 1
1686
+ if self .overload_stack_depth == 0 :
1687
+ self ._args_cache .clear ()
1688
+
1677
1689
def check_callable_call (
1678
1690
self ,
1679
1691
callee : CallableType ,
@@ -1937,6 +1949,10 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
1937
1949
In short, we basically recurse on each argument without considering
1938
1950
in what context the argument was called.
1939
1951
"""
1952
+ can_cache = not any (isinstance (t , TempNode ) for t in args )
1953
+ key = tuple (map (id , args ))
1954
+ if can_cache and key in self ._args_cache :
1955
+ return self ._args_cache [key ]
1940
1956
res : list [Type ] = []
1941
1957
1942
1958
for arg in args :
@@ -1945,6 +1961,8 @@ def infer_arg_types_in_empty_context(self, args: list[Expression]) -> list[Type]
1945
1961
res .append (NoneType ())
1946
1962
else :
1947
1963
res .append (arg_type )
1964
+ if can_cache :
1965
+ self ._args_cache [key ] = res
1948
1966
return res
1949
1967
1950
1968
def infer_more_unions_for_recursive_type (self , type_context : Type ) -> bool :
@@ -2917,17 +2935,16 @@ def infer_overload_return_type(
2917
2935
2918
2936
for typ in plausible_targets :
2919
2937
assert self .msg is self .chk .msg
2920
- with self .msg .filter_errors () as w :
2921
- with self .chk .local_type_map () as m :
2922
- ret_type , infer_type = self .check_call (
2923
- callee = typ ,
2924
- args = args ,
2925
- arg_kinds = arg_kinds ,
2926
- arg_names = arg_names ,
2927
- context = context ,
2928
- callable_name = callable_name ,
2929
- object_type = object_type ,
2930
- )
2938
+ with self .msg .filter_errors () as w , self .chk .local_type_map () as m :
2939
+ ret_type , infer_type = self .check_call (
2940
+ callee = typ ,
2941
+ args = args ,
2942
+ arg_kinds = arg_kinds ,
2943
+ arg_names = arg_names ,
2944
+ context = context ,
2945
+ callable_name = callable_name ,
2946
+ object_type = object_type ,
2947
+ )
2931
2948
is_match = not w .has_new_errors ()
2932
2949
if is_match :
2933
2950
# Return early if possible; otherwise record info, so we can
@@ -3474,6 +3491,10 @@ def visit_op_expr(self, e: OpExpr) -> Type:
3474
3491
return self .strfrm_checker .check_str_interpolation (e .left , e .right )
3475
3492
if isinstance (e .left , StrExpr ):
3476
3493
return self .strfrm_checker .check_str_interpolation (e .left , e .right )
3494
+
3495
+ key = id (e )
3496
+ if key in self ._args_cache :
3497
+ return self ._args_cache [key ]
3477
3498
left_type = self .accept (e .left )
3478
3499
3479
3500
proper_left_type = get_proper_type (left_type )
@@ -3543,28 +3564,30 @@ def visit_op_expr(self, e: OpExpr) -> Type:
3543
3564
)
3544
3565
3545
3566
if e .op in operators .op_methods :
3546
- method = operators .op_methods [e .op ]
3547
- if use_reverse is UseReverse .DEFAULT or use_reverse is UseReverse .NEVER :
3548
- result , method_type = self .check_op (
3549
- method ,
3550
- base_type = left_type ,
3551
- arg = e .right ,
3552
- context = e ,
3553
- allow_reverse = use_reverse is UseReverse .DEFAULT ,
3554
- )
3555
- elif use_reverse is UseReverse .ALWAYS :
3556
- result , method_type = self .check_op (
3557
- # The reverse operator here gives better error messages:
3558
- operators .reverse_op_methods [method ],
3559
- base_type = self .accept (e .right ),
3560
- arg = e .left ,
3561
- context = e ,
3562
- allow_reverse = False ,
3563
- )
3564
- else :
3565
- assert_never (use_reverse )
3566
- e .method_type = method_type
3567
- return result
3567
+ with self .overload_context (e .op ):
3568
+ method = operators .op_methods [e .op ]
3569
+ if use_reverse is UseReverse .DEFAULT or use_reverse is UseReverse .NEVER :
3570
+ result , method_type = self .check_op (
3571
+ method ,
3572
+ base_type = left_type ,
3573
+ arg = e .right ,
3574
+ context = e ,
3575
+ allow_reverse = use_reverse is UseReverse .DEFAULT ,
3576
+ )
3577
+ elif use_reverse is UseReverse .ALWAYS :
3578
+ result , method_type = self .check_op (
3579
+ # The reverse operator here gives better error messages:
3580
+ operators .reverse_op_methods [method ],
3581
+ base_type = self .accept (e .right ),
3582
+ arg = e .left ,
3583
+ context = e ,
3584
+ allow_reverse = False ,
3585
+ )
3586
+ else :
3587
+ assert_never (use_reverse )
3588
+ e .method_type = method_type
3589
+ self ._args_cache [key ] = result
3590
+ return result
3568
3591
else :
3569
3592
raise RuntimeError (f"Unknown operator { e .op } " )
3570
3593
0 commit comments