Skip to content

Commit fddd5c5

Browse files
authored
Use asserts instead of casts where possible (#14860)
There are many places in mypy's code where `cast`s are currently used unnecessarily. These can be replaced with `assert`s, which are much more type-safe, and more mypyc-friendly.
1 parent 4b3722f commit fddd5c5

File tree

11 files changed

+45
-27
lines changed

11 files changed

+45
-27
lines changed

mypy/api.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
import sys
4949
from io import StringIO
50-
from typing import Callable, TextIO, cast
50+
from typing import Callable, TextIO
5151

5252

5353
def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> tuple[str, str, int]:
@@ -59,7 +59,8 @@ def _run(main_wrapper: Callable[[TextIO, TextIO], None]) -> tuple[str, str, int]
5959
main_wrapper(stdout, stderr)
6060
exit_status = 0
6161
except SystemExit as system_exit:
62-
exit_status = cast(int, system_exit.code)
62+
assert isinstance(system_exit.code, int)
63+
exit_status = system_exit.code
6364

6465
return stdout.getvalue(), stderr.getvalue(), exit_status
6566

mypy/checker.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,8 @@ def _visit_overloaded_func_def(self, defn: OverloadedFuncDef) -> None:
629629

630630
if defn.is_property:
631631
# HACK: Infer the type of the property.
632-
self.visit_decorator(cast(Decorator, defn.items[0]))
632+
assert isinstance(defn.items[0], Decorator)
633+
self.visit_decorator(defn.items[0])
633634
for fdef in defn.items:
634635
assert isinstance(fdef, Decorator)
635636
self.check_func_item(fdef.func, name=fdef.func.name, allow_empty=True)
@@ -2575,9 +2576,8 @@ def check_import(self, node: ImportBase) -> None:
25752576
if lvalue_type is None:
25762577
# TODO: This is broken.
25772578
lvalue_type = AnyType(TypeOfAny.special_form)
2578-
message = message_registry.INCOMPATIBLE_IMPORT_OF.format(
2579-
cast(NameExpr, assign.rvalue).name
2580-
)
2579+
assert isinstance(assign.rvalue, NameExpr)
2580+
message = message_registry.INCOMPATIBLE_IMPORT_OF.format(assign.rvalue.name)
25812581
self.check_simple_assignment(
25822582
lvalue_type,
25832583
assign.rvalue,
@@ -3657,8 +3657,8 @@ def check_lvalue(self, lvalue: Lvalue) -> tuple[Type | None, IndexExpr | None, V
36573657
not isinstance(lvalue, NameExpr) or isinstance(lvalue.node, Var)
36583658
):
36593659
if isinstance(lvalue, NameExpr):
3660-
inferred = cast(Var, lvalue.node)
3661-
assert isinstance(inferred, Var)
3660+
assert isinstance(lvalue.node, Var)
3661+
inferred = lvalue.node
36623662
else:
36633663
assert isinstance(lvalue, MemberExpr)
36643664
self.expr_checker.accept(lvalue.expr)
@@ -4984,7 +4984,8 @@ def intersect_instance_callable(self, typ: Instance, callable_type: CallableType
49844984
# In order for this to work in incremental mode, the type we generate needs to
49854985
# have a valid fullname and a corresponding entry in a symbol table. We generate
49864986
# a unique name inside the symbol table of the current module.
4987-
cur_module = cast(MypyFile, self.scope.stack[0])
4987+
cur_module = self.scope.stack[0]
4988+
assert isinstance(cur_module, MypyFile)
49884989
gen_name = gen_unique_name(f"<callable subtype of {typ.type.name}>", cur_module.names)
49894990

49904991
# Synthesize a fake TypeInfo
@@ -6196,7 +6197,8 @@ def lookup(self, name: str) -> SymbolTableNode:
61966197
else:
61976198
b = self.globals.get("__builtins__", None)
61986199
if b:
6199-
table = cast(MypyFile, b.node).names
6200+
assert isinstance(b.node, MypyFile)
6201+
table = b.node.names
62006202
if name in table:
62016203
return table[name]
62026204
raise KeyError(f"Failed lookup: {name}")
@@ -6210,7 +6212,8 @@ def lookup_qualified(self, name: str) -> SymbolTableNode:
62106212
for i in range(1, len(parts) - 1):
62116213
sym = n.names.get(parts[i])
62126214
assert sym is not None, "Internal error: attempted lookup of unknown name"
6213-
n = cast(MypyFile, sym.node)
6215+
assert isinstance(sym.node, MypyFile)
6216+
n = sym.node
62146217
last = parts[-1]
62156218
if last in n.names:
62166219
return n.names[last]
@@ -6514,7 +6517,8 @@ def is_writable_attribute(self, node: Node) -> bool:
65146517
return False
65156518
return True
65166519
elif isinstance(node, OverloadedFuncDef) and node.is_property:
6517-
first_item = cast(Decorator, node.items[0])
6520+
first_item = node.items[0]
6521+
assert isinstance(first_item, Decorator)
65186522
return first_item.var.is_settable_property
65196523
return False
65206524

mypy/checkmember.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,8 @@ def analyze_instance_member_access(
312312

313313
if method.is_property:
314314
assert isinstance(method, OverloadedFuncDef)
315-
first_item = cast(Decorator, method.items[0])
315+
first_item = method.items[0]
316+
assert isinstance(first_item, Decorator)
316317
return analyze_var(name, first_item.var, typ, info, mx)
317318
if mx.is_lvalue:
318319
mx.msg.cant_assign_to_method(mx.context)

mypy/fastparse.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -665,7 +665,9 @@ def fix_function_overloads(self, stmts: list[Statement]) -> list[Statement]:
665665
if current_overload and current_overload_name == last_if_stmt_overload_name:
666666
# Remove last stmt (IfStmt) from ret if the overload names matched
667667
# Only happens if no executable block had been found in IfStmt
668-
skipped_if_stmts.append(cast(IfStmt, ret.pop()))
668+
popped = ret.pop()
669+
assert isinstance(popped, IfStmt)
670+
skipped_if_stmts.append(popped)
669671
if current_overload and skipped_if_stmts:
670672
# Add bare IfStmt (without overloads) to ret
671673
# Required for mypy to be able to still check conditions

mypy/nodes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2183,7 +2183,8 @@ def name(self) -> str:
21832183

21842184
def expr(self) -> Expression:
21852185
"""Return the expression (the body) of the lambda."""
2186-
ret = cast(ReturnStmt, self.body.body[-1])
2186+
ret = self.body.body[-1]
2187+
assert isinstance(ret, ReturnStmt)
21872188
expr = ret.expr
21882189
assert expr is not None # lambda can't have empty body
21892190
return expr

mypy/report.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
import tokenize
1313
from abc import ABCMeta, abstractmethod
1414
from operator import attrgetter
15-
from typing import Any, Callable, Dict, Iterator, Tuple, cast
15+
from typing import Any, Callable, Dict, Iterator, Tuple
1616
from typing_extensions import Final, TypeAlias as _TypeAlias
1717
from urllib.request import pathname2url
1818

@@ -704,8 +704,9 @@ def __init__(self, reports: Reports, output_dir: str) -> None:
704704
super().__init__(reports, output_dir)
705705

706706
memory_reporter = reports.add_report("memory-xml", "<memory>")
707+
assert isinstance(memory_reporter, MemoryXmlReporter)
707708
# The dependency will be called first.
708-
self.memory_xml = cast(MemoryXmlReporter, memory_reporter)
709+
self.memory_xml = memory_reporter
709710

710711

711712
class XmlReporter(AbstractXmlReporter):

mypy/semanal.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1314,7 +1314,8 @@ def analyze_property_with_multi_part_definition(self, defn: OverloadedFuncDef) -
13141314
"""
13151315
defn.is_property = True
13161316
items = defn.items
1317-
first_item = cast(Decorator, defn.items[0])
1317+
first_item = defn.items[0]
1318+
assert isinstance(first_item, Decorator)
13181319
deleted_items = []
13191320
for i, item in enumerate(items[1:]):
13201321
if isinstance(item, Decorator):
@@ -1357,7 +1358,8 @@ def analyze_function_body(self, defn: FuncItem) -> None:
13571358
# Bind the type variables again to visit the body.
13581359
if defn.type:
13591360
a = self.type_analyzer()
1360-
typ = cast(CallableType, defn.type)
1361+
typ = defn.type
1362+
assert isinstance(typ, CallableType)
13611363
a.bind_function_type_variables(typ, defn)
13621364
for i in range(len(typ.arg_types)):
13631365
store_argument_type(defn, i, typ, self.named_type)

mypy/server/astmerge.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -358,7 +358,8 @@ def fixup_and_reset_typeinfo(self, node: TypeInfo) -> TypeInfo:
358358
if node in self.replacements:
359359
# The subclass relationships may change, so reset all caches relevant to the
360360
# old MRO.
361-
new = cast(TypeInfo, self.replacements[node])
361+
new = self.replacements[node]
362+
assert isinstance(new, TypeInfo)
362363
type_state.reset_all_subtype_caches_for(new)
363364
return self.fixup(node)
364365

mypy/stats.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import os
66
from collections import Counter
77
from contextlib import contextmanager
8-
from typing import Iterator, cast
8+
from typing import Iterator
99
from typing_extensions import Final
1010

1111
from mypy import nodes
@@ -154,10 +154,12 @@ def visit_func_def(self, o: FuncDef) -> None:
154154
)
155155
return
156156
for defn in o.expanded:
157-
self.visit_func_def(cast(FuncDef, defn))
157+
assert isinstance(defn, FuncDef)
158+
self.visit_func_def(defn)
158159
else:
159160
if o.type:
160-
sig = cast(CallableType, o.type)
161+
assert isinstance(o.type, CallableType)
162+
sig = o.type
161163
arg_types = sig.arg_types
162164
if sig.arg_names and sig.arg_names[0] == "self" and not self.inferred:
163165
arg_types = arg_types[1:]

mypy/test/testfinegrained.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import re
1919
import sys
2020
import unittest
21-
from typing import Any, cast
21+
from typing import Any
2222

2323
import pytest
2424

@@ -169,7 +169,8 @@ def get_options(self, source: str, testcase: DataDrivenTestCase, build_cache: bo
169169

170170
def run_check(self, server: Server, sources: list[BuildSource]) -> list[str]:
171171
response = server.check(sources, export_types=True, is_tty=False, terminal_width=-1)
172-
out = cast(str, response["out"] or response["err"])
172+
out = response["out"] or response["err"]
173+
assert isinstance(out, str)
173174
return out.splitlines()
174175

175176
def build(self, options: Options, sources: list[BuildSource]) -> list[str]:

0 commit comments

Comments
 (0)