Skip to content

Commit 7832e1f

Browse files
authored
Speed up make_simplified_union, fix recursive tuple crash (#15128)
Fixes #15192 The following code optimises make_simplified_union in the common case that there are exact duplicates in the union. In this regard, this is similar to #15104 There's a behaviour change in one unit test. I think it's good? We'll see what mypy_primer has to say. To get this to work, I needed to use partial tuple fallbacks in a couple places. These could cause crashes anyway. There were some interesting things going on with recursive type aliases and type state assumptions This is about a 25% speedup on the pydantic codebase and about a 2% speedup on self check (measured with uncompiled mypy)
1 parent bfc1a76 commit 7832e1f

File tree

4 files changed

+91
-95
lines changed

4 files changed

+91
-95
lines changed

mypy/subtypes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,7 @@ def visit_instance(self, left: Instance) -> bool:
435435
# dynamic base classes correctly, see #5456.
436436
return not isinstance(self.right, NoneType)
437437
right = self.right
438-
if isinstance(right, TupleType) and mypy.typeops.tuple_fallback(right).type.is_enum:
438+
if isinstance(right, TupleType) and right.partial_fallback.type.is_enum:
439439
return self._is_subtype(left, mypy.typeops.tuple_fallback(right))
440440
if isinstance(right, Instance):
441441
if type_state.is_cached_subtype_check(self._subtype_kind, left, right):
@@ -749,7 +749,9 @@ def visit_tuple_type(self, left: TupleType) -> bool:
749749
# for isinstance(x, tuple), though it's unclear why.
750750
return True
751751
return all(self._is_subtype(li, iter_type) for li in left.items)
752-
elif self._is_subtype(mypy.typeops.tuple_fallback(left), right):
752+
elif self._is_subtype(left.partial_fallback, right) and self._is_subtype(
753+
mypy.typeops.tuple_fallback(left), right
754+
):
753755
return True
754756
return False
755757
elif isinstance(right, TupleType):

mypy/test/testtypes.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -611,10 +611,7 @@ def test_simplified_union_with_mixed_str_literals(self) -> None:
611611
[fx.lit_str1, fx.lit_str2, fx.lit_str3_inst],
612612
UnionType([fx.lit_str1, fx.lit_str2, fx.lit_str3_inst]),
613613
)
614-
self.assert_simplified_union(
615-
[fx.lit_str1, fx.lit_str1, fx.lit_str1_inst],
616-
UnionType([fx.lit_str1, fx.lit_str1_inst]),
617-
)
614+
self.assert_simplified_union([fx.lit_str1, fx.lit_str1, fx.lit_str1_inst], fx.lit_str1)
618615

619616
def assert_simplified_union(self, original: list[Type], union: Type) -> None:
620617
assert_equal(make_simplified_union(original), union)

mypy/typeops.py

Lines changed: 70 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -385,25 +385,6 @@ def callable_corresponding_argument(
385385
return by_name if by_name is not None else by_pos
386386

387387

388-
def simple_literal_value_key(t: ProperType) -> tuple[str, ...] | None:
389-
"""Return a hashable description of simple literal type.
390-
391-
Return None if not a simple literal type.
392-
393-
The return value can be used to simplify away duplicate types in
394-
unions by comparing keys for equality. For now enum, string or
395-
Instance with string last_known_value are supported.
396-
"""
397-
if isinstance(t, LiteralType):
398-
if t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str":
399-
assert isinstance(t.value, str)
400-
return "literal", t.value, t.fallback.type.fullname
401-
if isinstance(t, Instance):
402-
if t.last_known_value is not None and isinstance(t.last_known_value.value, str):
403-
return "instance", t.last_known_value.value, t.type.fullname
404-
return None
405-
406-
407388
def simple_literal_type(t: ProperType | None) -> Instance | None:
408389
"""Extract the underlying fallback Instance type for a simple Literal"""
409390
if isinstance(t, Instance) and t.last_known_value is not None:
@@ -414,7 +395,6 @@ def simple_literal_type(t: ProperType | None) -> Instance | None:
414395

415396

416397
def is_simple_literal(t: ProperType) -> bool:
417-
"""Fast way to check if simple_literal_value_key() would return a non-None value."""
418398
if isinstance(t, LiteralType):
419399
return t.fallback.type.is_enum or t.fallback.type.fullname == "builtins.str"
420400
if isinstance(t, Instance):
@@ -500,68 +480,80 @@ def make_simplified_union(
500480
def _remove_redundant_union_items(items: list[Type], keep_erased: bool) -> list[Type]:
501481
from mypy.subtypes import is_proper_subtype
502482

503-
removed: set[int] = set()
504-
seen: set[tuple[str, ...]] = set()
505-
506-
# NB: having a separate fast path for Union of Literal and slow path for other things
507-
# would arguably be cleaner, however it breaks down when simplifying the Union of two
508-
# different enum types as try_expanding_sum_type_to_union works recursively and will
509-
# trigger intermediate simplifications that would render the fast path useless
510-
for i, item in enumerate(items):
511-
proper_item = get_proper_type(item)
512-
if i in removed:
513-
continue
514-
# Avoid slow nested for loop for Union of Literal of strings/enums (issue #9169)
515-
k = simple_literal_value_key(proper_item)
516-
if k is not None:
517-
if k in seen:
518-
removed.add(i)
483+
# The first pass through this loop, we check if later items are subtypes of earlier items.
484+
# The second pass through this loop, we check if earlier items are subtypes of later items
485+
# (by reversing the remaining items)
486+
for _direction in range(2):
487+
new_items: list[Type] = []
488+
# seen is a map from a type to its index in new_items
489+
seen: dict[ProperType, int] = {}
490+
unduplicated_literal_fallbacks: set[Instance] | None = None
491+
for ti in items:
492+
proper_ti = get_proper_type(ti)
493+
494+
# UninhabitedType is always redundant
495+
if isinstance(proper_ti, UninhabitedType):
519496
continue
520497

521-
# NB: one would naively expect that it would be safe to skip the slow path
522-
# always for literals. One would be sorely mistaken. Indeed, some simplifications
523-
# such as that of None/Optional when strict optional is false, do require that we
524-
# proceed with the slow path. Thankfully, all literals will have the same subtype
525-
# relationship to non-literal types, so we only need to do that walk for the first
526-
# literal, which keeps the fast path fast even in the presence of a mixture of
527-
# literals and other types.
528-
safe_skip = len(seen) > 0
529-
seen.add(k)
530-
if safe_skip:
531-
continue
532-
533-
# Keep track of the truthiness info for deleted subtypes which can be relevant
534-
cbt = cbf = False
535-
for j, tj in enumerate(items):
536-
proper_tj = get_proper_type(tj)
537-
if (
538-
i == j
539-
# avoid further checks if this item was already marked redundant.
540-
or j in removed
541-
# if the current item is a simple literal then this simplification loop can
542-
# safely skip all other simple literals as two literals will only ever be
543-
# subtypes of each other if they are equal, which is already handled above.
544-
# However, if the current item is not a literal, it might plausibly be a
545-
# supertype of other literals in the union, so we must check them again.
546-
# This is an important optimization as is_proper_subtype is pretty expensive.
547-
or (k is not None and is_simple_literal(proper_tj))
548-
):
549-
continue
550-
# actual redundancy checks (XXX?)
551-
if is_redundant_literal_instance(proper_item, proper_tj) and is_proper_subtype(
552-
tj, item, keep_erased_types=keep_erased, ignore_promotions=True
498+
duplicate_index = -1
499+
# Quickly check if we've seen this type
500+
if proper_ti in seen:
501+
duplicate_index = seen[proper_ti]
502+
elif (
503+
isinstance(proper_ti, LiteralType)
504+
and unduplicated_literal_fallbacks is not None
505+
and proper_ti.fallback in unduplicated_literal_fallbacks
553506
):
554-
# We found a redundant item in the union.
555-
removed.add(j)
556-
cbt = cbt or tj.can_be_true
557-
cbf = cbf or tj.can_be_false
558-
# if deleted subtypes had more general truthiness, use that
559-
if not item.can_be_true and cbt:
560-
items[i] = true_or_false(item)
561-
elif not item.can_be_false and cbf:
562-
items[i] = true_or_false(item)
507+
# This is an optimisation for unions with many LiteralType
508+
# We've already checked for exact duplicates. This means that any super type of
509+
# the LiteralType must be a super type of its fallback. If we've gone through
510+
# the expensive loop below and found no super type for a previous LiteralType
511+
# with the same fallback, we can skip doing that work again and just add the type
512+
# to new_items
513+
pass
514+
else:
515+
# If not, check if we've seen a supertype of this type
516+
for j, tj in enumerate(new_items):
517+
tj = get_proper_type(tj)
518+
# If tj is an Instance with a last_known_value, do not remove proper_ti
519+
# (unless it's an instance with the same last_known_value)
520+
if (
521+
isinstance(tj, Instance)
522+
and tj.last_known_value is not None
523+
and not (
524+
isinstance(proper_ti, Instance)
525+
and tj.last_known_value == proper_ti.last_known_value
526+
)
527+
):
528+
continue
529+
530+
if is_proper_subtype(
531+
proper_ti, tj, keep_erased_types=keep_erased, ignore_promotions=True
532+
):
533+
duplicate_index = j
534+
break
535+
if duplicate_index != -1:
536+
# If deleted subtypes had more general truthiness, use that
537+
orig_item = new_items[duplicate_index]
538+
if not orig_item.can_be_true and ti.can_be_true:
539+
new_items[duplicate_index] = true_or_false(orig_item)
540+
elif not orig_item.can_be_false and ti.can_be_false:
541+
new_items[duplicate_index] = true_or_false(orig_item)
542+
else:
543+
# We have a non-duplicate item, add it to new_items
544+
seen[proper_ti] = len(new_items)
545+
new_items.append(ti)
546+
if isinstance(proper_ti, LiteralType):
547+
if unduplicated_literal_fallbacks is None:
548+
unduplicated_literal_fallbacks = set()
549+
unduplicated_literal_fallbacks.add(proper_ti.fallback)
563550

564-
return [items[i] for i in range(len(items)) if i not in removed]
551+
items = new_items
552+
if len(items) <= 1:
553+
break
554+
items.reverse()
555+
556+
return items
565557

566558

567559
def _get_type_special_method_bool_ret_type(t: Type) -> Type | None:
@@ -992,17 +984,6 @@ def custom_special_method(typ: Type, name: str, check_all: bool = False) -> bool
992984
return False
993985

994986

995-
def is_redundant_literal_instance(general: ProperType, specific: ProperType) -> bool:
996-
if not isinstance(general, Instance) or general.last_known_value is None:
997-
return True
998-
if isinstance(specific, Instance) and specific.last_known_value == general.last_known_value:
999-
return True
1000-
if isinstance(specific, UninhabitedType):
1001-
return True
1002-
1003-
return False
1004-
1005-
1006987
def separate_union_literals(t: UnionType) -> tuple[Sequence[LiteralType], Sequence[Type]]:
1007988
"""Separate literals from other members in a union type."""
1008989
literal_items = []

test-data/unit/check-type-aliases.test

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,3 +1043,19 @@ class C(Generic[T]):
10431043
def test(cls) -> None:
10441044
cls.attr
10451045
[builtins fixtures/classmethod.pyi]
1046+
1047+
[case testRecursiveAliasTuple]
1048+
from typing_extensions import Literal, TypeAlias
1049+
from typing import Tuple, Union
1050+
1051+
Expr: TypeAlias = Union[
1052+
Tuple[Literal[123], int],
1053+
Tuple[Literal[456], "Expr"],
1054+
]
1055+
1056+
def eval(e: Expr) -> int:
1057+
if e[0] == 123:
1058+
return e[1]
1059+
elif e[0] == 456:
1060+
return -eval(e[1])
1061+
[builtins fixtures/dict.pyi]

0 commit comments

Comments
 (0)