@@ -357,7 +357,9 @@ def __init__(
357
357
self ._arg_infer_context_cache = None
358
358
359
359
self .overload_stack_depth = 0
360
- self ._args_cache = {}
360
+ self .ops_stack_depth = 0
361
+ self ._args_cache : dict [tuple [int , ...], list [Type ]] = {}
362
+ self ._ops_cache : dict [int , Type ] = {}
361
363
362
364
def reset (self ) -> None :
363
365
self .resolved_type = {}
@@ -1616,7 +1618,7 @@ def check_call(
1616
1618
object_type ,
1617
1619
)
1618
1620
elif isinstance (callee , Overloaded ):
1619
- with self .overload_context (callee . name () ):
1621
+ with self .overload_context ():
1620
1622
return self .check_overload_call (
1621
1623
callee , args , arg_kinds , arg_names , callable_name , object_type , context
1622
1624
)
@@ -1679,13 +1681,21 @@ def check_call(
1679
1681
return self .msg .not_callable (callee , context ), AnyType (TypeOfAny .from_error )
1680
1682
1681
1683
@contextmanager
1682
- def overload_context (self , fn ) :
1684
+ def overload_context (self ) -> Iterator [ None ] :
1683
1685
self .overload_stack_depth += 1
1684
1686
yield
1685
1687
self .overload_stack_depth -= 1
1686
1688
if self .overload_stack_depth == 0 :
1687
1689
self ._args_cache .clear ()
1688
1690
1691
+ @contextmanager
1692
+ def ops_context (self ) -> Iterator [None ]:
1693
+ self .ops_stack_depth += 1
1694
+ yield
1695
+ self .ops_stack_depth -= 1
1696
+ if self .ops_stack_depth == 0 :
1697
+ self ._ops_cache .clear ()
1698
+
1689
1699
def check_callable_call (
1690
1700
self ,
1691
1701
callee : CallableType ,
@@ -3493,8 +3503,8 @@ def visit_op_expr(self, e: OpExpr) -> Type:
3493
3503
return self .strfrm_checker .check_str_interpolation (e .left , e .right )
3494
3504
3495
3505
key = id (e )
3496
- if key in self ._args_cache :
3497
- return self ._args_cache [key ]
3506
+ if key in self ._ops_cache :
3507
+ return self ._ops_cache [key ]
3498
3508
left_type = self .accept (e .left )
3499
3509
3500
3510
proper_left_type = get_proper_type (left_type )
@@ -3564,7 +3574,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
3564
3574
)
3565
3575
3566
3576
if e .op in operators .op_methods :
3567
- with self .overload_context ( e . op ):
3577
+ with self .ops_context ( ):
3568
3578
method = operators .op_methods [e .op ]
3569
3579
if use_reverse is UseReverse .DEFAULT or use_reverse is UseReverse .NEVER :
3570
3580
result , method_type = self .check_op (
@@ -3586,7 +3596,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
3586
3596
else :
3587
3597
assert_never (use_reverse )
3588
3598
e .method_type = method_type
3589
- self ._args_cache [key ] = result
3599
+ self ._ops_cache [key ] = result
3590
3600
return result
3591
3601
else :
3592
3602
raise RuntimeError (f"Unknown operator { e .op } " )
0 commit comments