Skip to content

Commit dbe9a88

Browse files
authored
Make join of recursive types more robust (#13808)
Fixes #13795 Calculating tuple fallbacks on the fly creates a cycle between joins and subtyping. Although IMO this is conceptually not a right thing, it is hard to get rid of (unless we want to use unions in the fallbacks, cc @JukkaL). So instead I re-worked how `join_types()` works w.r.t. `get_proper_type()`, essentially it now follows the golden rule "Always pass on the original type if possible".
1 parent 9f39120 commit dbe9a88

File tree

6 files changed

+66
-25
lines changed

6 files changed

+66
-25
lines changed

mypy/checkexpr.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,7 @@
141141
StarType,
142142
TupleType,
143143
Type,
144+
TypeAliasType,
144145
TypedDictType,
145146
TypeOfAny,
146147
TypeType,
@@ -195,10 +196,12 @@ class TooManyUnions(Exception):
195196
"""
196197

197198

198-
def allow_fast_container_literal(t: ProperType) -> bool:
199+
def allow_fast_container_literal(t: Type) -> bool:
200+
if isinstance(t, TypeAliasType) and t.is_recursive:
201+
return False
202+
t = get_proper_type(t)
199203
return isinstance(t, Instance) or (
200-
isinstance(t, TupleType)
201-
and all(allow_fast_container_literal(get_proper_type(it)) for it in t.items)
204+
isinstance(t, TupleType) and all(allow_fast_container_literal(it) for it in t.items)
202205
)
203206

204207

@@ -4603,7 +4606,7 @@ def visit_conditional_expr(self, e: ConditionalExpr, allow_none_return: bool = F
46034606
#
46044607
# TODO: Always create a union or at least in more cases?
46054608
if isinstance(get_proper_type(self.type_context[-1]), UnionType):
4606-
res = make_simplified_union([if_type, full_context_else_type])
4609+
res: Type = make_simplified_union([if_type, full_context_else_type])
46074610
else:
46084611
res = join.join_types(if_type, else_type)
46094612

mypy/join.py

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
from __future__ import annotations
44

5+
from typing import overload
6+
57
import mypy.typeops
68
from mypy.maptype import map_instance_to_supertype
79
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT
@@ -131,7 +133,6 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
131133
best = res
132134
assert best is not None
133135
for promote in t.type._promote:
134-
promote = get_proper_type(promote)
135136
if isinstance(promote, Instance):
136137
res = self.join_instances(promote, s)
137138
if is_better(res, best):
@@ -182,17 +183,29 @@ def join_simple(declaration: Type | None, s: Type, t: Type) -> ProperType:
182183
return declaration
183184

184185

185-
def trivial_join(s: Type, t: Type) -> ProperType:
186+
def trivial_join(s: Type, t: Type) -> Type:
186187
"""Return one of types (expanded) if it is a supertype of other, otherwise top type."""
187188
if is_subtype(s, t):
188-
return get_proper_type(t)
189+
return t
189190
elif is_subtype(t, s):
190-
return get_proper_type(s)
191+
return s
191192
else:
192193
return object_or_any_from_type(get_proper_type(t))
193194

194195

195-
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> ProperType:
196+
@overload
197+
def join_types(
198+
s: ProperType, t: ProperType, instance_joiner: InstanceJoiner | None = None
199+
) -> ProperType:
200+
...
201+
202+
203+
@overload
204+
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type:
205+
...
206+
207+
208+
def join_types(s: Type, t: Type, instance_joiner: InstanceJoiner | None = None) -> Type:
196209
"""Return the least upper bound of s and t.
197210
198211
For example, the join of 'int' and 'object' is 'object'.
@@ -443,7 +456,7 @@ def visit_tuple_type(self, t: TupleType) -> ProperType:
443456
if self.s.length() == t.length():
444457
items: list[Type] = []
445458
for i in range(t.length()):
446-
items.append(self.join(t.items[i], self.s.items[i]))
459+
items.append(join_types(t.items[i], self.s.items[i]))
447460
return TupleType(items, fallback)
448461
else:
449462
return fallback
@@ -487,7 +500,7 @@ def visit_partial_type(self, t: PartialType) -> ProperType:
487500

488501
def visit_type_type(self, t: TypeType) -> ProperType:
489502
if isinstance(self.s, TypeType):
490-
return TypeType.make_normalized(self.join(t.item, self.s.item), line=t.line)
503+
return TypeType.make_normalized(join_types(t.item, self.s.item), line=t.line)
491504
elif isinstance(self.s, Instance) and self.s.type.fullname == "builtins.type":
492505
return self.s
493506
else:
@@ -496,9 +509,6 @@ def visit_type_type(self, t: TypeType) -> ProperType:
496509
def visit_type_alias_type(self, t: TypeAliasType) -> ProperType:
497510
assert False, f"This should be never called, got {t}"
498511

499-
def join(self, s: Type, t: Type) -> ProperType:
500-
return join_types(s, t)
501-
502512
def default(self, typ: Type) -> ProperType:
503513
typ = get_proper_type(typ)
504514
if isinstance(typ, Instance):
@@ -654,19 +664,19 @@ def object_or_any_from_type(typ: ProperType) -> ProperType:
654664
return AnyType(TypeOfAny.implementation_artifact)
655665

656666

657-
def join_type_list(types: list[Type]) -> ProperType:
667+
def join_type_list(types: list[Type]) -> Type:
658668
if not types:
659669
# This is a little arbitrary but reasonable. Any empty tuple should be compatible
660670
# with all variable length tuples, and this makes it possible.
661671
return UninhabitedType()
662-
joined = get_proper_type(types[0])
672+
joined = types[0]
663673
for t in types[1:]:
664674
joined = join_types(joined, t)
665675
return joined
666676

667677

668-
def unpack_callback_protocol(t: Instance) -> Type | None:
678+
def unpack_callback_protocol(t: Instance) -> ProperType | None:
669679
assert t.type.is_protocol
670680
if t.type.protocol_members == ["__call__"]:
671-
return find_member("__call__", t, t, is_operator=True)
681+
return get_proper_type(find_member("__call__", t, t, is_operator=True))
672682
return None

mypy/nodes.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2542,9 +2542,9 @@ class PromoteExpr(Expression):
25422542

25432543
__slots__ = ("type",)
25442544

2545-
type: mypy.types.Type
2545+
type: mypy.types.ProperType
25462546

2547-
def __init__(self, type: mypy.types.Type) -> None:
2547+
def __init__(self, type: mypy.types.ProperType) -> None:
25482548
super().__init__()
25492549
self.type = type
25502550

@@ -2769,7 +2769,7 @@ class is generic then it will be a type constructor of higher kind.
27692769
# even though it's not a subclass in Python. The non-standard
27702770
# `@_promote` decorator introduces this, and there are also
27712771
# several builtin examples, in particular `int` -> `float`.
2772-
_promote: list[mypy.types.Type]
2772+
_promote: list[mypy.types.ProperType]
27732773

27742774
# This is used for promoting native integer types such as 'i64' to
27752775
# 'int'. (_promote is used for the other direction.) This only
@@ -3100,7 +3100,12 @@ def deserialize(cls, data: JsonDict) -> TypeInfo:
31003100
ti.type_vars = data["type_vars"]
31013101
ti.has_param_spec_type = data["has_param_spec_type"]
31023102
ti.bases = [mypy.types.Instance.deserialize(b) for b in data["bases"]]
3103-
ti._promote = [mypy.types.deserialize_type(p) for p in data["_promote"]]
3103+
_promote = []
3104+
for p in data["_promote"]:
3105+
t = mypy.types.deserialize_type(p)
3106+
assert isinstance(t, mypy.types.ProperType)
3107+
_promote.append(t)
3108+
ti._promote = _promote
31043109
ti.declared_metaclass = (
31053110
None
31063111
if data["declared_metaclass"] is None

mypy/semanal.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4945,6 +4945,7 @@ def visit_conditional_expr(self, expr: ConditionalExpr) -> None:
49454945
def visit__promote_expr(self, expr: PromoteExpr) -> None:
49464946
analyzed = self.anal_type(expr.type)
49474947
if analyzed is not None:
4948+
assert isinstance(analyzed, ProperType), "Cannot use type aliases for promotions"
49484949
expr.type = analyzed
49494950

49504951
def visit_yield_expr(self, e: YieldExpr) -> None:

mypy/semanal_classprop.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
Var,
2323
)
2424
from mypy.options import Options
25-
from mypy.types import Instance, Type
25+
from mypy.types import Instance, ProperType
2626

2727
# Hard coded type promotions (shared between all Python versions).
2828
# These add extra ad-hoc edges to the subtyping relation. For example,
@@ -155,7 +155,7 @@ def add_type_promotion(
155155
This includes things like 'int' being compatible with 'float'.
156156
"""
157157
defn = info.defn
158-
promote_targets: list[Type] = []
158+
promote_targets: list[ProperType] = []
159159
for decorator in defn.decorators:
160160
if isinstance(decorator, CallExpr):
161161
analyzed = decorator.analyzed

test-data/unit/check-recursive-types.test

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -532,7 +532,14 @@ m: A
532532
s: str = n.x # E: Incompatible types in assignment (expression has type "Tuple[A, int]", variable has type "str")
533533
reveal_type(m[0]) # N: Revealed type is "builtins.str"
534534
lst = [m, n]
535-
reveal_type(lst[0]) # N: Revealed type is "Tuple[builtins.object, builtins.object]"
535+
536+
# Unfortunately, join of two recursive types is not very precise.
537+
reveal_type(lst[0]) # N: Revealed type is "builtins.object"
538+
539+
# These just should not crash
540+
lst1 = [m]
541+
lst2 = [m, m]
542+
lst3 = [m, m, m]
536543
[builtins fixtures/tuple.pyi]
537544

538545
[case testMutuallyRecursiveNamedTuplesClasses]
@@ -786,3 +793,18 @@ class B:
786793
y: B.Foo
787794
reveal_type(y) # N: Revealed type is "typing.Sequence[typing.Sequence[...]]"
788795
[builtins fixtures/tuple.pyi]
796+
797+
[case testNoCrashOnRecursiveTupleFallback]
798+
from typing import Union, Tuple
799+
800+
Tree1 = Union[str, Tuple[Tree1]]
801+
Tree2 = Union[str, Tuple[Tree2, Tree2]]
802+
Tree3 = Union[str, Tuple[Tree3, Tree3, Tree3]]
803+
804+
def test1() -> Tree1:
805+
return 42 # E: Incompatible return value type (got "int", expected "Union[str, Tuple[Tree1]]")
806+
def test2() -> Tree2:
807+
return 42 # E: Incompatible return value type (got "int", expected "Union[str, Tuple[Tree2, Tree2]]")
808+
def test3() -> Tree3:
809+
return 42 # E: Incompatible return value type (got "int", expected "Union[str, Tuple[Tree3, Tree3, Tree3]]")
810+
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)