Skip to content

Commit 3518f24

Browse files
authored
Move apply_type() to applytype.py (#17346)
Moving towards #15907 This is a pure refactoring. It was surprisingly easy, this didn't add new import cycles, because there is already (somewhat fundamental) cycle `applytype.py` <-> `subtypes.py`.
1 parent 8dd268f commit 3518f24

File tree

2 files changed

+134
-129
lines changed

2 files changed

+134
-129
lines changed

mypy/applytype.py

Lines changed: 133 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,32 @@
11
from __future__ import annotations
22

3-
from typing import Callable, Sequence
3+
from typing import Callable, Iterable, Sequence
44

55
import mypy.subtypes
66
from mypy.erasetype import erase_typevars
77
from mypy.expandtype import expand_type
8-
from mypy.nodes import Context
8+
from mypy.nodes import Context, TypeInfo
9+
from mypy.type_visitor import TypeTranslator
10+
from mypy.typeops import get_all_type_vars
911
from mypy.types import (
1012
AnyType,
1113
CallableType,
14+
Instance,
15+
Parameters,
16+
ParamSpecFlavor,
1217
ParamSpecType,
1318
PartialType,
19+
ProperType,
1420
Type,
21+
TypeAliasType,
1522
TypeVarId,
1623
TypeVarLikeType,
1724
TypeVarTupleType,
1825
TypeVarType,
1926
UninhabitedType,
2027
UnpackType,
2128
get_proper_type,
29+
remove_dups,
2230
)
2331

2432

@@ -170,3 +178,126 @@ def apply_generic_arguments(
170178
type_guard=type_guard,
171179
type_is=type_is,
172180
)
181+
182+
183+
def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> CallableType | None:
184+
"""Make free type variables generic in the type if possible.
185+
186+
This will translate the type `tp` while trying to create valid bindings for
187+
type variables `poly_tvars` while traversing the type. This follows the same rules
188+
as we do during semantic analysis phase, examples:
189+
* Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
190+
* Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
191+
* List[T] -> None (not possible)
192+
"""
193+
try:
194+
return tp.copy_modified(
195+
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
196+
ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
197+
variables=[],
198+
)
199+
except PolyTranslationError:
200+
return None
201+
202+
203+
class PolyTranslationError(Exception):
204+
pass
205+
206+
207+
class PolyTranslator(TypeTranslator):
208+
"""Make free type variables generic in the type if possible.
209+
210+
See docstring for apply_poly() for details.
211+
"""
212+
213+
def __init__(
214+
self,
215+
poly_tvars: Iterable[TypeVarLikeType],
216+
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
217+
seen_aliases: frozenset[TypeInfo] = frozenset(),
218+
) -> None:
219+
self.poly_tvars = set(poly_tvars)
220+
# This is a simplified version of TypeVarScope used during semantic analysis.
221+
self.bound_tvars = bound_tvars
222+
self.seen_aliases = seen_aliases
223+
224+
def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
225+
found_vars = []
226+
for arg in t.arg_types:
227+
for tv in get_all_type_vars(arg):
228+
if isinstance(tv, ParamSpecType):
229+
normalized: TypeVarLikeType = tv.copy_modified(
230+
flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
231+
)
232+
else:
233+
normalized = tv
234+
if normalized in self.poly_tvars and normalized not in self.bound_tvars:
235+
found_vars.append(normalized)
236+
return remove_dups(found_vars)
237+
238+
def visit_callable_type(self, t: CallableType) -> Type:
239+
found_vars = self.collect_vars(t)
240+
self.bound_tvars |= set(found_vars)
241+
result = super().visit_callable_type(t)
242+
self.bound_tvars -= set(found_vars)
243+
244+
assert isinstance(result, ProperType) and isinstance(result, CallableType)
245+
result.variables = list(result.variables) + found_vars
246+
return result
247+
248+
def visit_type_var(self, t: TypeVarType) -> Type:
249+
if t in self.poly_tvars and t not in self.bound_tvars:
250+
raise PolyTranslationError()
251+
return super().visit_type_var(t)
252+
253+
def visit_param_spec(self, t: ParamSpecType) -> Type:
254+
if t in self.poly_tvars and t not in self.bound_tvars:
255+
raise PolyTranslationError()
256+
return super().visit_param_spec(t)
257+
258+
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
259+
if t in self.poly_tvars and t not in self.bound_tvars:
260+
raise PolyTranslationError()
261+
return super().visit_type_var_tuple(t)
262+
263+
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
264+
if not t.args:
265+
return t.copy_modified()
266+
if not t.is_recursive:
267+
return get_proper_type(t).accept(self)
268+
# We can't handle polymorphic application for recursive generic aliases
269+
# without risking an infinite recursion, just give up for now.
270+
raise PolyTranslationError()
271+
272+
def visit_instance(self, t: Instance) -> Type:
273+
if t.type.has_param_spec_type:
274+
# We need this special-casing to preserve the possibility to store a
275+
# generic function in an instance type. Things like
276+
# forall T . Foo[[x: T], T]
277+
# are not really expressible in current type system, but this looks like
278+
# a useful feature, so let's keep it.
279+
param_spec_index = next(
280+
i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
281+
)
282+
p = get_proper_type(t.args[param_spec_index])
283+
if isinstance(p, Parameters):
284+
found_vars = self.collect_vars(p)
285+
self.bound_tvars |= set(found_vars)
286+
new_args = [a.accept(self) for a in t.args]
287+
self.bound_tvars -= set(found_vars)
288+
289+
repl = new_args[param_spec_index]
290+
assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
291+
repl.variables = list(repl.variables) + list(found_vars)
292+
return t.copy_modified(args=new_args)
293+
# There is the same problem with callback protocols as with aliases
294+
# (callback protocols are essentially more flexible aliases to callables).
295+
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
296+
if t.type in self.seen_aliases:
297+
raise PolyTranslationError()
298+
call = mypy.subtypes.find_member("__call__", t, t, is_operator=True)
299+
assert call is not None
300+
return call.accept(
301+
PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
302+
)
303+
return super().visit_instance(t)

mypy/checkexpr.py

Lines changed: 1 addition & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,6 @@
115115
non_method_protocol_members,
116116
)
117117
from mypy.traverser import has_await_expression
118-
from mypy.type_visitor import TypeTranslator
119118
from mypy.typeanal import (
120119
check_for_explicit_any,
121120
fix_instance,
@@ -168,7 +167,6 @@
168167
TypeOfAny,
169168
TypeType,
170169
TypeVarId,
171-
TypeVarLikeType,
172170
TypeVarTupleType,
173171
TypeVarType,
174172
UnboundType,
@@ -182,7 +180,6 @@
182180
get_proper_types,
183181
has_recursive_types,
184182
is_named_instance,
185-
remove_dups,
186183
split_with_prefix_and_suffix,
187184
)
188185
from mypy.types_utils import (
@@ -2136,7 +2133,7 @@ def infer_function_type_arguments(
21362133
)
21372134
# Try applying inferred polymorphic type if possible, e.g. Callable[[T], T] can
21382135
# be interpreted as def [T] (T) -> T, but dict[T, T] cannot be expressed.
2139-
applied = apply_poly(poly_callee_type, free_vars)
2136+
applied = applytype.apply_poly(poly_callee_type, free_vars)
21402137
if applied is not None and all(
21412138
a is not None and not isinstance(get_proper_type(a), UninhabitedType)
21422139
for a in poly_inferred_args
@@ -6220,129 +6217,6 @@ def replace_callable_return_type(c: CallableType, new_ret_type: Type) -> Callabl
62206217
return c.copy_modified(ret_type=new_ret_type)
62216218

62226219

6223-
def apply_poly(tp: CallableType, poly_tvars: Sequence[TypeVarLikeType]) -> CallableType | None:
6224-
"""Make free type variables generic in the type if possible.
6225-
6226-
This will translate the type `tp` while trying to create valid bindings for
6227-
type variables `poly_tvars` while traversing the type. This follows the same rules
6228-
as we do during semantic analysis phase, examples:
6229-
* Callable[Callable[[T], T], T] -> def [T] (def (T) -> T) -> T
6230-
* Callable[[], Callable[[T], T]] -> def () -> def [T] (T -> T)
6231-
* List[T] -> None (not possible)
6232-
"""
6233-
try:
6234-
return tp.copy_modified(
6235-
arg_types=[t.accept(PolyTranslator(poly_tvars)) for t in tp.arg_types],
6236-
ret_type=tp.ret_type.accept(PolyTranslator(poly_tvars)),
6237-
variables=[],
6238-
)
6239-
except PolyTranslationError:
6240-
return None
6241-
6242-
6243-
class PolyTranslationError(Exception):
6244-
pass
6245-
6246-
6247-
class PolyTranslator(TypeTranslator):
6248-
"""Make free type variables generic in the type if possible.
6249-
6250-
See docstring for apply_poly() for details.
6251-
"""
6252-
6253-
def __init__(
6254-
self,
6255-
poly_tvars: Iterable[TypeVarLikeType],
6256-
bound_tvars: frozenset[TypeVarLikeType] = frozenset(),
6257-
seen_aliases: frozenset[TypeInfo] = frozenset(),
6258-
) -> None:
6259-
self.poly_tvars = set(poly_tvars)
6260-
# This is a simplified version of TypeVarScope used during semantic analysis.
6261-
self.bound_tvars = bound_tvars
6262-
self.seen_aliases = seen_aliases
6263-
6264-
def collect_vars(self, t: CallableType | Parameters) -> list[TypeVarLikeType]:
6265-
found_vars = []
6266-
for arg in t.arg_types:
6267-
for tv in get_all_type_vars(arg):
6268-
if isinstance(tv, ParamSpecType):
6269-
normalized: TypeVarLikeType = tv.copy_modified(
6270-
flavor=ParamSpecFlavor.BARE, prefix=Parameters([], [], [])
6271-
)
6272-
else:
6273-
normalized = tv
6274-
if normalized in self.poly_tvars and normalized not in self.bound_tvars:
6275-
found_vars.append(normalized)
6276-
return remove_dups(found_vars)
6277-
6278-
def visit_callable_type(self, t: CallableType) -> Type:
6279-
found_vars = self.collect_vars(t)
6280-
self.bound_tvars |= set(found_vars)
6281-
result = super().visit_callable_type(t)
6282-
self.bound_tvars -= set(found_vars)
6283-
6284-
assert isinstance(result, ProperType) and isinstance(result, CallableType)
6285-
result.variables = list(result.variables) + found_vars
6286-
return result
6287-
6288-
def visit_type_var(self, t: TypeVarType) -> Type:
6289-
if t in self.poly_tvars and t not in self.bound_tvars:
6290-
raise PolyTranslationError()
6291-
return super().visit_type_var(t)
6292-
6293-
def visit_param_spec(self, t: ParamSpecType) -> Type:
6294-
if t in self.poly_tvars and t not in self.bound_tvars:
6295-
raise PolyTranslationError()
6296-
return super().visit_param_spec(t)
6297-
6298-
def visit_type_var_tuple(self, t: TypeVarTupleType) -> Type:
6299-
if t in self.poly_tvars and t not in self.bound_tvars:
6300-
raise PolyTranslationError()
6301-
return super().visit_type_var_tuple(t)
6302-
6303-
def visit_type_alias_type(self, t: TypeAliasType) -> Type:
6304-
if not t.args:
6305-
return t.copy_modified()
6306-
if not t.is_recursive:
6307-
return get_proper_type(t).accept(self)
6308-
# We can't handle polymorphic application for recursive generic aliases
6309-
# without risking an infinite recursion, just give up for now.
6310-
raise PolyTranslationError()
6311-
6312-
def visit_instance(self, t: Instance) -> Type:
6313-
if t.type.has_param_spec_type:
6314-
# We need this special-casing to preserve the possibility to store a
6315-
# generic function in an instance type. Things like
6316-
# forall T . Foo[[x: T], T]
6317-
# are not really expressible in current type system, but this looks like
6318-
# a useful feature, so let's keep it.
6319-
param_spec_index = next(
6320-
i for (i, tv) in enumerate(t.type.defn.type_vars) if isinstance(tv, ParamSpecType)
6321-
)
6322-
p = get_proper_type(t.args[param_spec_index])
6323-
if isinstance(p, Parameters):
6324-
found_vars = self.collect_vars(p)
6325-
self.bound_tvars |= set(found_vars)
6326-
new_args = [a.accept(self) for a in t.args]
6327-
self.bound_tvars -= set(found_vars)
6328-
6329-
repl = new_args[param_spec_index]
6330-
assert isinstance(repl, ProperType) and isinstance(repl, Parameters)
6331-
repl.variables = list(repl.variables) + list(found_vars)
6332-
return t.copy_modified(args=new_args)
6333-
# There is the same problem with callback protocols as with aliases
6334-
# (callback protocols are essentially more flexible aliases to callables).
6335-
if t.args and t.type.is_protocol and t.type.protocol_members == ["__call__"]:
6336-
if t.type in self.seen_aliases:
6337-
raise PolyTranslationError()
6338-
call = find_member("__call__", t, t, is_operator=True)
6339-
assert call is not None
6340-
return call.accept(
6341-
PolyTranslator(self.poly_tvars, self.bound_tvars, self.seen_aliases | {t.type})
6342-
)
6343-
return super().visit_instance(t)
6344-
6345-
63466220
class ArgInferSecondPassQuery(types.BoolTypeQuery):
63476221
"""Query whether an argument type should be inferred in the second pass.
63486222

0 commit comments

Comments
 (0)