Skip to content

Commit 4b3722f

Browse files
authored
Improve the signatures of expand_type and expand_type_by_instance (#14879)
By adding another overload, `CallableType -> CallableType`, we can avoid the need for several `cast`s across the code base.
1 parent 106d57e commit 4b3722f

File tree

5 files changed

+19
-9
lines changed

5 files changed

+19
-9
lines changed

mypy/checker.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1753,8 +1753,7 @@ def expand_typevars(
17531753
result: list[tuple[FuncItem, CallableType]] = []
17541754
for substitutions in itertools.product(*subst):
17551755
mapping = dict(substitutions)
1756-
expanded = cast(CallableType, expand_type(typ, mapping))
1757-
result.append((expand_func(defn, mapping), expanded))
1756+
result.append((expand_func(defn, mapping), expand_type(typ, mapping)))
17581757
return result
17591758
else:
17601759
return [(defn, typ)]
@@ -7111,7 +7110,6 @@ def overload_can_never_match(signature: CallableType, other: CallableType) -> bo
71117110
exp_signature = expand_type(
71127111
signature, {tvar.id: erase_def_to_union_or_bound(tvar) for tvar in signature.variables}
71137112
)
7114-
assert isinstance(exp_signature, CallableType)
71157113
return is_callable_compatible(
71167114
exp_signature, other, is_compat=is_more_precise, ignore_return=True
71177115
)

mypy/checkexpr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5518,7 +5518,7 @@ def merge_typevars_in_callables_by_name(
55185518
variables.append(tv)
55195519
rename[tv.id] = unique_typevars[name]
55205520

5521-
target = cast(CallableType, expand_type(target, rename))
5521+
target = expand_type(target, rename)
55225522
output.append(target)
55235523

55245524
return output, variables

mypy/checkmember.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1150,7 +1150,7 @@ class B(A[str]): pass
11501150
t = freshen_all_functions_type_vars(t)
11511151
t = bind_self(t, original_type, is_classmethod=True)
11521152
assert isuper is not None
1153-
t = cast(CallableType, expand_type_by_instance(t, isuper))
1153+
t = expand_type_by_instance(t, isuper)
11541154
freeze_all_type_vars(t)
11551155
return t.copy_modified(variables=list(tvars) + list(t.variables))
11561156
elif isinstance(t, Overloaded):

mypy/expandtype.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@
4747
)
4848

4949

50+
@overload
51+
def expand_type(
52+
typ: CallableType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
53+
) -> CallableType:
54+
...
55+
56+
5057
@overload
5158
def expand_type(
5259
typ: ProperType, env: Mapping[TypeVarId, Type], allow_erased_callables: bool = ...
@@ -70,6 +77,11 @@ def expand_type(
7077
return typ.accept(ExpandTypeVisitor(env, allow_erased_callables))
7178

7279

80+
@overload
81+
def expand_type_by_instance(typ: CallableType, instance: Instance) -> CallableType:
82+
...
83+
84+
7385
@overload
7486
def expand_type_by_instance(typ: ProperType, instance: Instance) -> ProperType:
7587
...
@@ -133,7 +145,7 @@ def freshen_function_type_vars(callee: F) -> F:
133145
tv = ParamSpecType.new_unification_variable(v)
134146
tvs.append(tv)
135147
tvmap[v.id] = tv
136-
fresh = cast(CallableType, expand_type(callee, tvmap)).copy_modified(variables=tvs)
148+
fresh = expand_type(callee, tvmap).copy_modified(variables=tvs)
137149
return cast(F, fresh)
138150
else:
139151
assert isinstance(callee, Overloaded)
@@ -346,7 +358,7 @@ def interpolate_args_for_unpack(
346358
)
347359
return (arg_names, arg_kinds, arg_types)
348360

349-
def visit_callable_type(self, t: CallableType) -> Type:
361+
def visit_callable_type(self, t: CallableType) -> CallableType:
350362
param_spec = t.param_spec()
351363
if param_spec is not None:
352364
repl = get_proper_type(self.variables.get(param_spec.id))

mypy/server/astdiff.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ class level -- these are handled at attribute level (say, 'mod.Cls.method'
5252

5353
from __future__ import annotations
5454

55-
from typing import Sequence, Tuple, Union, cast
55+
from typing import Sequence, Tuple, Union
5656
from typing_extensions import TypeAlias as _TypeAlias
5757

5858
from mypy.expandtype import expand_type
@@ -442,7 +442,7 @@ def normalize_callable_variables(self, typ: CallableType) -> CallableType:
442442
tv = v.copy_modified(id=tid)
443443
tvs.append(tv)
444444
tvmap[v.id] = tv
445-
return cast(CallableType, expand_type(typ, tvmap)).copy_modified(variables=tvs)
445+
return expand_type(typ, tvmap).copy_modified(variables=tvs)
446446

447447
def visit_tuple_type(self, typ: TupleType) -> SnapshotItem:
448448
return ("TupleType", snapshot_types(typ.items))

0 commit comments

Comments
 (0)