Skip to content

Commit 03af5b2

Browse files
committed
Fix selfcheck
1 parent e3ea993 commit 03af5b2

File tree

1 file changed

+17
-7
lines changed

1 file changed

+17
-7
lines changed

mypy/checkexpr.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,9 @@ def __init__(
357357
self._arg_infer_context_cache = None
358358

359359
self.overload_stack_depth = 0
360-
self._args_cache = {}
360+
self.ops_stack_depth = 0
361+
self._args_cache: dict[tuple[int, ...], list[Type]] = {}
362+
self._ops_cache: dict[int, Type] = {}
361363

362364
def reset(self) -> None:
363365
self.resolved_type = {}
@@ -1616,7 +1618,7 @@ def check_call(
16161618
object_type,
16171619
)
16181620
elif isinstance(callee, Overloaded):
1619-
with self.overload_context(callee.name()):
1621+
with self.overload_context():
16201622
return self.check_overload_call(
16211623
callee, args, arg_kinds, arg_names, callable_name, object_type, context
16221624
)
@@ -1679,13 +1681,21 @@ def check_call(
16791681
return self.msg.not_callable(callee, context), AnyType(TypeOfAny.from_error)
16801682

16811683
@contextmanager
1682-
def overload_context(self, fn):
1684+
def overload_context(self) -> Iterator[None]:
16831685
self.overload_stack_depth += 1
16841686
yield
16851687
self.overload_stack_depth -= 1
16861688
if self.overload_stack_depth == 0:
16871689
self._args_cache.clear()
16881690

1691+
@contextmanager
1692+
def ops_context(self) -> Iterator[None]:
1693+
self.ops_stack_depth += 1
1694+
yield
1695+
self.ops_stack_depth -= 1
1696+
if self.ops_stack_depth == 0:
1697+
self._ops_cache.clear()
1698+
16891699
def check_callable_call(
16901700
self,
16911701
callee: CallableType,
@@ -3493,8 +3503,8 @@ def visit_op_expr(self, e: OpExpr) -> Type:
34933503
return self.strfrm_checker.check_str_interpolation(e.left, e.right)
34943504

34953505
key = id(e)
3496-
if key in self._args_cache:
3497-
return self._args_cache[key]
3506+
if key in self._ops_cache:
3507+
return self._ops_cache[key]
34983508
left_type = self.accept(e.left)
34993509

35003510
proper_left_type = get_proper_type(left_type)
@@ -3564,7 +3574,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35643574
)
35653575

35663576
if e.op in operators.op_methods:
3567-
with self.overload_context(e.op):
3577+
with self.ops_context():
35683578
method = operators.op_methods[e.op]
35693579
if use_reverse is UseReverse.DEFAULT or use_reverse is UseReverse.NEVER:
35703580
result, method_type = self.check_op(
@@ -3586,7 +3596,7 @@ def visit_op_expr(self, e: OpExpr) -> Type:
35863596
else:
35873597
assert_never(use_reverse)
35883598
e.method_type = method_type
3589-
self._args_cache[key] = result
3599+
self._ops_cache[key] = result
35903600
return result
35913601
else:
35923602
raise RuntimeError(f"Unknown operator {e.op}")

0 commit comments

Comments
 (0)