From 2de93ae9eaeef8f0fb6de920518de59515bb947e Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Thu, 3 Jul 2025 14:36:00 +0200 Subject: [PATCH 01/12] inital commit --- mypy/infer.py | 4 ++++ test-data/unit/check-expressions.test | 22 ++++++++++++++++++---- test-data/unit/fixtures/list.pyi | 5 +++-- 3 files changed, 25 insertions(+), 6 deletions(-) diff --git a/mypy/infer.py b/mypy/infer.py index cdc43797d3b1..6fcba279a000 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -73,4 +73,8 @@ def infer_type_arguments( # Like infer_function_type_arguments, but only match a single type # against a generic type. constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) + + # for tp in type_vars: + # constraints.append(Constraint(tp, SUPERTYPE_OF, UninhabitedType())) + return solve_constraints(type_vars, constraints, skip_unsatisfied=skip_unsatisfied)[0] diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 33271a3cc04c..41cc3cf8ddde 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -308,6 +308,24 @@ main:5: error: Unsupported operand types for ^ ("A" and "A") main:6: error: Unsupported operand types for << ("A" and "B") main:7: error: Unsupported operand types for >> ("A" and "A") +[case testBinaryOperatorContext] +from typing import TypeVar, Generic, Iterable, Iterator, Union + +T = TypeVar("T") +S = TypeVar("S") + +class Vec(Generic[T]): + def __init__(self, iterable: Iterable[T], /) -> None: ... + def __iter__(self) -> Iterator[T]: yield from [] + def __add__(self, value: "Vec[S]", /) -> "Vec[Union[S, T]]": return Vec([]) + +def fmt(arg: Iterable[Union[int, str]]) -> None: ... + +l1: Vec[int] = Vec([1]) +l2: Vec[int] = Vec([1]) +fmt(l1 + l2) +[builtins fixtures/list.pyi] + [case testBooleanAndOr] a: A b: bool @@ -460,10 +478,6 @@ class A: def __contains__(self, x: 'A') -> str: pass [builtins fixtures/bool.pyi] -[case testInWithInvalidArgs] -a = 1 in ([1] + ['x']) # E: List item 0 has incompatible type "str"; expected "int" -[builtins fixtures/list.pyi] - [case testEq] a: A b: bool diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi index 3dcdf18b2faa..032abfc6beed 100644 --- a/test-data/unit/fixtures/list.pyi +++ b/test-data/unit/fixtures/list.pyi @@ -1,8 +1,9 @@ # Builtins stub used in list-related test cases. -from typing import TypeVar, Generic, Iterable, Iterator, Sequence, overload +from typing import TypeVar, Generic, Iterable, Iterator, Sequence, overload, Union T = TypeVar('T') +_S = TypeVar("_S") class object: def __init__(self) -> None: pass @@ -19,7 +20,7 @@ class list(Sequence[T]): def __iter__(self) -> Iterator[T]: pass def __len__(self) -> int: pass def __contains__(self, item: object) -> bool: pass - def __add__(self, x: list[T]) -> list[T]: pass + def __add__(self, x: list[_S]) -> list[Union[_S, T]]: pass def __mul__(self, x: int) -> list[T]: pass def __getitem__(self, x: int) -> T: pass def __setitem__(self, x: int, v: T) -> None: pass From e4238c708a5aaea3cf824130ef6e70936a6f064a Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Thu, 3 Jul 2025 21:24:41 +0200 Subject: [PATCH 02/12] 11 failures remain --- mypy/checkexpr.py | 242 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 229 insertions(+), 13 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8223ccfe4ca0..46b3a46612e1 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -18,6 +18,7 @@ from mypy.checker_shared import ExpressionCheckerSharedApi from mypy.checkmember import analyze_member_access, has_operator from mypy.checkstrformat import StringFormatterChecker +from mypy.constraints import SUBTYPE_OF, Constraint, infer_constraints from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars from mypy.errors import ErrorWatcher, report_internal_error from mypy.expandtype import ( @@ -26,7 +27,13 @@ freshen_all_functions_type_vars, freshen_function_type_vars, ) -from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments +from mypy.infer import ( + ArgumentInferContext, + infer_constraints_for_callable, + infer_function_type_arguments, + infer_type_arguments, + solve_constraints, +) from mypy.literals import literal from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_types, narrow_declared_type @@ -1774,18 +1781,18 @@ def check_callable_call( isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables ) callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context(callee, context) - if need_refresh: - # Argument kinds etc. may have changed due to - # ParamSpec or TypeVarTuple variables being replaced with an arbitrary - # number of arguments; recalculate actual-to-formal map - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) + # callee = self.infer_function_type_arguments_using_context(callee, context) + # if need_refresh: + # # Argument kinds etc. may have changed due to + # # ParamSpec or TypeVarTuple variables being replaced with an arbitrary + # # number of arguments; recalculate actual-to-formal map + # formal_to_actual = map_actuals_to_formals( + # arg_kinds, + # arg_names, + # callee.arg_kinds, + # callee.arg_names, + # lambda i: self.accept(args[i]), + # ) callee = self.infer_function_type_arguments( callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context ) @@ -2085,6 +2092,88 @@ def infer_function_type_arguments_using_context( callable, new_args, error_context, skip_unsatisfied=True ) + def _infer_constraints_from_context( + self, callee: CallableType, error_context: Context + ) -> list[Constraint]: + """Unify callable return type to type context to infer type vars. + + For example, if the return type is set[t] where 't' is a type variable + of callable, and if the context is set[int], return callable modified + by substituting 't' with 'int'. + """ + ctx = self.type_context[-1] + if not ctx: + return [] + # The return type may have references to type metavariables that + # we are inferring right now. We must consider them as indeterminate + # and they are not potential results; thus we replace them with the + # special ErasedType type. On the other hand, class type variables are + # valid results. + erased_ctx = replace_meta_vars(ctx, ErasedType()) + ret_type = callee.ret_type + if is_overlapping_none(ret_type) and is_overlapping_none(ctx): + # If both the context and the return type are optional, unwrap the optional, + # since in 99% cases this is what a user expects. In other words, we replace + # Optional[T] <: Optional[int] + # with + # T <: int + # while the former would infer T <: Optional[int]. + ret_type = remove_optional(ret_type) + erased_ctx = remove_optional(erased_ctx) + # + # TODO: Instead of this hack and the one below, we need to use outer and + # inner contexts at the same time. This is however not easy because of two + # reasons: + # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables + # on both sides. (This is not too hard.) + # * We need to update all the inference "infrastructure", so that all + # variables in an expression are inferred at the same time. + # (And this is hard, also we need to be careful with lambdas that require + # two passes.) + proper_ret = get_proper_type(ret_type) + if ( + isinstance(proper_ret, TypeVarType) + or isinstance(proper_ret, UnionType) + and all(isinstance(get_proper_type(u), TypeVarType) for u in proper_ret.items) + ): + # Another special case: the return type is a type variable. If it's unrestricted, + # we could infer a too general type for the type variable if we use context, + # and this could result in confusing and spurious type errors elsewhere. + # + # So we give up and just use function arguments for type inference, with just two + # exceptions: + # + # 1. If the context is a generic instance type, actually use it as context, as + # this *seems* to usually be the reasonable thing to do. + # + # See also github issues #462 and #360. + # + # 2. If the context is some literal type, we want to "propagate" that information + # down so that we infer a more precise type for literal expressions. For example, + # the expression `3` normally has an inferred type of `builtins.int`: but if it's + # in a literal context like below, we want it to infer `Literal[3]` instead. + # + # def expects_literal(x: Literal[3]) -> None: pass + # def identity(x: T) -> T: return x + # + # expects_literal(identity(3)) # Should type-check + # TODO: we may want to add similar exception if all arguments are lambdas, since + # in this case external context is almost everything we have. + if not is_generic_instance(ctx) and not is_literal_type_like(ctx): + return [] + + constraints = infer_constraints(ret_type, erased_ctx, SUBTYPE_OF) + return constraints + + def _filter_args(self, args: list[Type | None]) -> list[Type | None]: + new_args: list[Type | None] = [] + for arg in args: + if has_uninhabited_component(arg) or has_erased_component(arg): + new_args.append(None) + else: + new_args.append(arg) + return new_args + def infer_function_type_arguments( self, callee_type: CallableType, @@ -2131,6 +2220,133 @@ def infer_function_type_arguments( context=self.argument_infer_context(), strict=self.chk.in_checked_function(), ) + old_inferred_args = inferred_args + new_inferred_args = None + + if True: # NEW CODE + extra_constraints = self._infer_constraints_from_context(callee_type, context) + + # outer_solution + _outer_solution = solve_constraints( + callee_type.variables, + extra_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + outer_solution = (self._filter_args(_outer_solution[0]), _outer_solution[1]) + + # inner solution + constraints = infer_constraints_for_callable( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + ) + inner_solution = solve_constraints( + callee_type.variables, + constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + + joint_solution = solve_constraints( + callee_type.variables, + extra_constraints + constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + + # check if we can use the joint solution, otherwise fallback to outer_solution + for var1, var2 in zip( + outer_solution[0], joint_solution[0] + ): # tuple[Type | None, Type | None] + if var2 is None and var1 is not None: + # using both constraints did not find a solution for this variable + # so we fallback to outer_solution, apply the solution, and then recompute the inner part + use_joint = False + break + else: + use_joint = True + + use_joint = True + use_outer = True + use_inner = True + + for outer_tp, inner_tp, joint_tp in zip( + outer_solution[0], inner_solution[0], joint_solution[0] + ): + if joint_tp is None and outer_tp is not None: + use_joint = False + if has_erased_component(joint_tp) and not has_erased_component(inner_tp): + # If the joint solution is erased, but outer is not, we use outer. + use_joint = False + if has_erased_component(outer_tp) and not has_erased_component(inner_tp): + use_outer = False + if has_erased_component(inner_tp): + use_inner = False + + if use_joint: + new_inferred_args = joint_solution[0] + # inferred_args = [ + # # Usually, joint_tp <: outer_tp (since superset of constraints), + # # fixes some cases where we would get `Literal[4]?` rather than `Literal[4]` + # (outer_tp if is_subtype(outer_tp, joint_tp) else joint_tp) + # for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0]) + # ] + elif use_outer: + # If we cannot use the joint solution, fallback to outer_solution + new_inferred_args = outer_solution[0] + + # Only substitute non-Uninhabited and non-erased types. + new_args: list[Type | None] = [] + for arg in new_inferred_args: + if has_uninhabited_component(arg) or has_erased_component(arg): + new_args.append(None) + else: + new_args.append(arg) + # Don't show errors after we have only used the outer context for inference. + # We will use argument context to infer more variables. + callee_type = self.apply_generic_arguments( + callee_type, new_args, context, skip_unsatisfied=True + ) + if need_refresh: + # Argument kinds etc. may have changed due to + # ParamSpec or TypeVarTuple variables being replaced with an arbitrary + # number of arguments; recalculate actual-to-formal map + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee_type.arg_kinds, + callee_type.arg_names, + lambda i: self.accept(args[i]), + ) + new_inferred_args, _ = infer_function_type_arguments( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + strict=self.chk.in_checked_function(), + ) + elif use_inner: + new_inferred_args = inner_solution[0] + else: + raise RuntimeError("No solution found for function type arguments") + else: # OLD CODE + pass + + if True: # USE NEW CODE + inferred_args = new_inferred_args + else: # USE OLD CODE + inferred_args = old_inferred_args + + # show me + _1 = new_inferred_args + _2 = old_inferred_args + _3 = inferred_args if 2 in arg_pass_nums: # Second pass of type inference. From db7ed358ba0f10951b0ff5b217a5c793c1d64ba8 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Fri, 4 Jul 2025 15:26:16 +0200 Subject: [PATCH 03/12] 3 test remain --- mypy/checkexpr.py | 153 ++++++++++++++++++++++------ test-data/unit/check-functions.test | 2 +- test-data/unit/check-generics.test | 4 +- 3 files changed, 125 insertions(+), 34 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 46b3a46612e1..99aab8ccfcbf 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -121,6 +121,7 @@ from mypy.subtypes import ( find_member, is_equivalent, + is_proper_subtype, is_same_type, is_subtype, non_method_protocol_members, @@ -2093,7 +2094,7 @@ def infer_function_type_arguments_using_context( ) def _infer_constraints_from_context( - self, callee: CallableType, error_context: Context + self, callee: CallableType, error_context: Context, erase: bool = True ) -> list[Constraint]: """Unify callable return type to type context to infer type vars. @@ -2109,7 +2110,10 @@ def _infer_constraints_from_context( # and they are not potential results; thus we replace them with the # special ErasedType type. On the other hand, class type variables are # valid results. - erased_ctx = replace_meta_vars(ctx, ErasedType()) + if erase: + erased_ctx = replace_meta_vars(ctx, ErasedType()) + else: + erased_ctx = ctx ret_type = callee.ret_type if is_overlapping_none(ret_type) and is_overlapping_none(ctx): # If both the context and the return type are optional, unwrap the optional, @@ -2168,10 +2172,16 @@ def _infer_constraints_from_context( def _filter_args(self, args: list[Type | None]) -> list[Type | None]: new_args: list[Type | None] = [] for arg in args: - if has_uninhabited_component(arg) or has_erased_component(arg): + if arg is None: new_args.append(None) + continue else: + arg = replace_meta_vars(arg, ErasedType()) new_args.append(arg) + # if has_erased_component(arg) or has_uninhabited_component(arg): + # new_args.append(None) + # else: + # new_args.append(arg) return new_args def infer_function_type_arguments( @@ -2224,18 +2234,6 @@ def infer_function_type_arguments( new_inferred_args = None if True: # NEW CODE - extra_constraints = self._infer_constraints_from_context(callee_type, context) - - # outer_solution - _outer_solution = solve_constraints( - callee_type.variables, - extra_constraints, - strict=self.chk.in_checked_function(), - allow_polymorphic=False, - ) - outer_solution = (self._filter_args(_outer_solution[0]), _outer_solution[1]) - - # inner solution constraints = infer_constraints_for_callable( callee_type, pass1_args, @@ -2244,38 +2242,84 @@ def infer_function_type_arguments( formal_to_actual, context=self.argument_infer_context(), ) - inner_solution = solve_constraints( + + extra_constraints = self._infer_constraints_from_context( + callee_type, context, erase=False + ) + erased_constraints = self._infer_constraints_from_context( + callee_type, context, erase=True + ) + + _outer_solution = solve_constraints( + callee_type.variables, + extra_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + + _inner_solution = solve_constraints( callee_type.variables, constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) + # NOTE: The order of constraints is important here! + # solve(outer + inner) and solve(inner + outer) may yield different results. + _joint_solution = solve_constraints( + callee_type.variables, + constraints + extra_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) - joint_solution = solve_constraints( + _reverse_joint_solution = solve_constraints( callee_type.variables, extra_constraints + constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) - # check if we can use the joint solution, otherwise fallback to outer_solution - for var1, var2 in zip( - outer_solution[0], joint_solution[0] - ): # tuple[Type | None, Type | None] - if var2 is None and var1 is not None: - # using both constraints did not find a solution for this variable - # so we fallback to outer_solution, apply the solution, and then recompute the inner part - use_joint = False - break - else: - use_joint = True + _erased_outer_solution = solve_constraints( + callee_type.variables, + erased_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + + _erased_joint_solution = solve_constraints( + callee_type.variables, + constraints + erased_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + + _erased_reverse_joint_solution = solve_constraints( + callee_type.variables, + erased_constraints + constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + + # Now, we select the solution to use. + # Note: Since joint uses both outer and inner constraints, + # and solution discovered by joint is also a solution for outer and inner. + # therefore, we can pick either inner or outer as a substitute for joint, + # and then try to solve again using only the inner constraints. + # joint_solution = (self._filter_args(_joint_solution[0]), _joint_solution[1]) + # reverse_joint_solution = (self._filter_args(_reverse_joint_solution[0]), _reverse_joint_solution[1]) + outer_solution = (self._filter_args(_outer_solution[0]), _outer_solution[1]) + inner_solution = (self._filter_args(_inner_solution[0]), _inner_solution[1]) + joint_solution = _joint_solution + reverse_joint_solution = _reverse_joint_solution + + target_solution = _erased_reverse_joint_solution use_joint = True use_outer = True use_inner = True - + # check if we can use the joint solution, otherwise fallback to outer_solution for outer_tp, inner_tp, joint_tp in zip( - outer_solution[0], inner_solution[0], joint_solution[0] + outer_solution[0], inner_solution[0], target_solution[0] ): if joint_tp is None and outer_tp is not None: use_joint = False @@ -2287,8 +2331,55 @@ def infer_function_type_arguments( if has_erased_component(inner_tp): use_inner = False + combined_solution = [] + for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]): + if ( + outer_tp is not None + and joint_tp is not None + and is_proper_subtype(outer_tp, joint_tp) + ): + # If outer is a subtype of joint, we can use joint. + combined_solution.append(outer_tp) + else: + # Otherwise, we use joint. + combined_solution.append(joint_tp) + + _num = arg_pass_nums + _c0 = constraints + _c1 = extra_constraints + _c2 = erased_constraints + + _x0 = _outer_solution[0] + _x2 = _inner_solution[0] + _x3 = _joint_solution[0] + _x4 = _reverse_joint_solution[0] + + _e0 = _erased_outer_solution[0] + _e2 = _erased_joint_solution[0] + _e3 = _erased_reverse_joint_solution[0] + + _s1 = outer_solution[0] + _s2 = inner_solution[0] + _s3 = joint_solution[0] + _s4 = reverse_joint_solution[0] + _s5 = combined_solution + + _u0 = use_inner, use_outer, use_joint + + # if the outer solution is more concrete than the joint solution, use the outer solution (2 step) + if all( + (joint_tp is None and outer_tp is None) + or ( + (joint_tp is not None and outer_tp is not None) + and is_subtype(outer_tp, joint_tp) + ) + for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]) + ): + use_joint = False + use_outer = True + if use_joint: - new_inferred_args = joint_solution[0] + new_inferred_args = target_solution[0] # inferred_args = [ # # Usually, joint_tp <: outer_tp (since superset of constraints), # # fixes some cases where we would get `Literal[4]?` rather than `Literal[4]` diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 7fa34a398ea0..f8887bf410bc 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ... def g(x: T, y: S) -> Union[T, S]: ... x = [f, g] -reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`4, y: S`5) -> Union[T`4, S`5]]" +reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`16, y: S`17) -> Union[T`16, S`17]]" [builtins fixtures/list.pyi] [case testTypeVariableClashErrorMessage] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 0be9d918c69f..ebfc97bdb8e2 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def id(__x: U) -> U: ... fs = [id, id, id] -reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`11) -> builtins.list[S`11]" -reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`13) -> builtins.list[S`13]" +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`23) -> builtins.list[S`23]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`25) -> builtins.list[S`25]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCurry] From 40f8ca8617249afbd3e06626f46a32cf9cf2ab2c Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Fri, 4 Jul 2025 16:19:50 +0200 Subject: [PATCH 04/12] 1 check remaining --- mypy/checkexpr.py | 27 +++++++++++------------ test-data/unit/check-generics.test | 4 ++-- test-data/unit/check-recursive-types.test | 20 +++++++++++++---- 3 files changed, 31 insertions(+), 20 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 99aab8ccfcbf..e6771d4ab56f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2344,6 +2344,18 @@ def infer_function_type_arguments( # Otherwise, we use joint. combined_solution.append(joint_tp) + # if the outer solution is more concrete than the joint solution, use the outer solution (2 step) + if all( + (joint_tp is None and outer_tp is None) + or ( + (joint_tp is not None and outer_tp is not None) + and is_subtype(outer_tp, joint_tp) + ) + for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]) + ): + use_joint = False + use_outer = True + _num = arg_pass_nums _c0 = constraints _c1 = extra_constraints @@ -2362,22 +2374,9 @@ def infer_function_type_arguments( _s2 = inner_solution[0] _s3 = joint_solution[0] _s4 = reverse_joint_solution[0] - _s5 = combined_solution - + _t0 = target_solution[0] _u0 = use_inner, use_outer, use_joint - # if the outer solution is more concrete than the joint solution, use the outer solution (2 step) - if all( - (joint_tp is None and outer_tp is None) - or ( - (joint_tp is not None and outer_tp is not None) - and is_subtype(outer_tp, joint_tp) - ) - for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]) - ): - use_joint = False - use_outer = True - if use_joint: new_inferred_args = target_solution[0] # inferred_args = [ diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index ebfc97bdb8e2..89c47f2c56cf 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def id(__x: U) -> U: ... fs = [id, id, id] -reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`23) -> builtins.list[S`23]" -reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`25) -> builtins.list[S`25]" +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`35) -> builtins.list[S`35]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`37) -> builtins.list[S`37]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCurry] diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index 7ed5ea53c27e..2b22bef99433 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -285,21 +285,33 @@ if isinstance(b[0], Sequence): [case testRecursiveAliasWithRecursiveInstance] from typing import Sequence, Union, TypeVar -class A: ... T = TypeVar("T") Nested = Sequence[Union[T, Nested[T]]] +def join(a: T, b: T) -> T: ... + +class A: ... class B(Sequence[B]): ... a: Nested[A] aa: Nested[A] b: B + a = b # OK +reveal_type(a) # N: Revealed type is "__main__.B" + a = [[b]] # OK +reveal_type(a) # N: Revealed type is "builtins.list[builtins.list[__main__.B]]" + b = aa # E: Incompatible types in assignment (expression has type "Nested[A]", variable has type "B") +reveal_type(b) # N: Revealed type is "__main__.B" + +reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[typing.Sequence[__main__.B]]" +reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[typing.Sequence[__main__.B]]" + +def test(a: Nested[A], b: B) -> None: + reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" + reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" -def join(a: T, b: T) -> T: ... -reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" -reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" [builtins fixtures/isinstancelist.pyi] [case testRecursiveAliasWithRecursiveInstanceInference] From d7145d944ff1224a4967624f7bf184e7a28ae750 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Sat, 5 Jul 2025 23:41:32 +0200 Subject: [PATCH 05/12] checks pass --- mypy/checkexpr.py | 112 ++++++++++++++++++++++++++-- test-data/unit/check-functions.test | 2 +- test-data/unit/check-generics.test | 4 +- 3 files changed, 107 insertions(+), 11 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e6771d4ab56f..cab0506975e8 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2134,6 +2134,14 @@ def _infer_constraints_from_context( # variables in an expression are inferred at the same time. # (And this is hard, also we need to be careful with lambdas that require # two passes.) + # ret_as_union = make_simplified_union([ret_type]) + # erased_ctx_as_union = make_simplified_union([ctx]) + # if isinstance(ret_as_union, UnionType) and isinstance(erased_ctx_as_union, UnionType): + # new_ret = [val for val in ret_as_union.items if val not in erased_ctx_as_union.items] + # new_ctx = [val for val in erased_ctx_as_union.items if val not in ret_as_union.items] + # ret_type = make_simplified_union(new_ret) + # erased_ctx = make_simplified_union(new_ctx) + proper_ret = get_proper_type(ret_type) if ( isinstance(proper_ret, TypeVarType) @@ -2184,6 +2192,33 @@ def _filter_args(self, args: list[Type | None]) -> list[Type | None]: # new_args.append(arg) return new_args + def intersect_solutions(self, sol1: list[Type | None], sol2: list[Type | None]): + # first, ensure that the None-patterns agree + assert len(sol1) == len(sol2) + + virtual_vars = [] + constraints = [] + + for i, (tp1, tp2) in enumerate(zip(sol1, sol2)): + new_id = TypeVarId.new(-1) + name = f"V{i}" + new_tvar = TypeVarType( + name, + name, + new_id, + values=[], + upper_bound=self.object_type(), + default=AnyType(TypeOfAny.from_omitted_generics), + ) + virtual_vars.append(new_tvar) + if tp1 is not None: + c1 = Constraint(new_tvar, SUBTYPE_OF, tp1) + constraints.append(c1) + if tp2 is not None: + c2 = Constraint(new_tvar, SUBTYPE_OF, tp2) + constraints.append(c2) + return virtual_vars, constraints + def infer_function_type_arguments( self, callee_type: CallableType, @@ -2311,9 +2346,37 @@ def infer_function_type_arguments( inner_solution = (self._filter_args(_inner_solution[0]), _inner_solution[1]) joint_solution = _joint_solution reverse_joint_solution = _reverse_joint_solution - target_solution = _erased_reverse_joint_solution + if True: # compute the outer and target return types. + # Only substitute non-Uninhabited and non-erased types. + new_args: list[Type | None] = [] + for arg in outer_solution[0]: + if has_uninhabited_component(arg) or has_erased_component(arg): + new_args.append(None) + else: + new_args.append(arg) + # Don't show errors after we have only used the outer context for inference. + # We will use argument context to infer more variables. + outer_callee = self.apply_generic_arguments( + callee_type, new_args, context, skip_unsatisfied=True + ) + outer_ret_type = get_proper_type(outer_callee.ret_type) + + # Only substitute non-Uninhabited and non-erased types. + new_args: list[Type | None] = [] + for arg in target_solution[0]: + if has_uninhabited_component(arg) or has_erased_component(arg): + new_args.append(None) + else: + new_args.append(arg) + # Don't show errors after we have only used the outer context for inference. + # We will use argument context to infer more variables. + target_callee = self.apply_generic_arguments( + callee_type, new_args, context, skip_unsatisfied=True + ) + target_ret_type = get_proper_type(target_callee.ret_type) + use_joint = True use_outer = True use_inner = True @@ -2344,18 +2407,46 @@ def infer_function_type_arguments( # Otherwise, we use joint. combined_solution.append(joint_tp) + # new_vars, new_constraints = self.intersect_solutions(outer_solution[0], target_solution[0]) + # intersected_solution = solve_constraints( + # new_vars, + # new_constraints, + # strict=self.chk.in_checked_function(), + # allow_polymorphic=False, + # ) + # if the outer solution is more concrete than the joint solution, use the outer solution (2 step) - if all( - (joint_tp is None and outer_tp is None) - or ( - (joint_tp is not None and outer_tp is not None) - and is_subtype(outer_tp, joint_tp) - ) - for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]) + # if all( + # (joint_tp is None and outer_tp is None) + # or ( + # (joint_tp is not None and outer_tp is not None) + # and ( + # is_subtype(outer_tp, joint_tp) + # or ( + # isinstance(outer_tp, UnionType) + # and any(is_subtype(val, joint_tp) for val in outer_tp.items) + # ) + # ) + # ) + # for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]) + # ): + # use_joint = False + # use_outer = True + + # if the outer solution is more concrete than the joint solution, use the outer solution (2 step) + if is_subtype(outer_ret_type, target_ret_type) or ( + isinstance(outer_ret_type, UnionType) + and any(is_subtype(val, target_ret_type) for val in outer_ret_type.items) ): use_joint = False use_outer = True + # what if the outer context is a union type? + # we may have a case like: + # outer : int | Literal["foo"] + # inner: Literal["foo"]? (which gets translated into str later) + # here, we would want `Literal["foo"]` to be used as the solution, + _num = arg_pass_nums _c0 = constraints _c1 = extra_constraints @@ -2375,6 +2466,11 @@ def infer_function_type_arguments( _s3 = joint_solution[0] _s4 = reverse_joint_solution[0] _t0 = target_solution[0] + + _r1 = outer_ret_type + _r2 = target_ret_type + _y1 = outer_callee + _y2 = target_callee _u0 = use_inner, use_outer, use_joint if use_joint: diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index f8887bf410bc..8a49497e227b 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ... def g(x: T, y: S) -> Union[T, S]: ... x = [f, g] -reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`16, y: S`17) -> Union[T`16, S`17]]" +reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`14, y: S`15) -> Union[T`14, S`15]]" [builtins fixtures/list.pyi] [case testTypeVariableClashErrorMessage] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 89c47f2c56cf..b6f6352f270a 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def id(__x: U) -> U: ... fs = [id, id, id] -reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`35) -> builtins.list[S`35]" -reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`37) -> builtins.list[S`37]" +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`31) -> builtins.list[S`31]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`33) -> builtins.list[S`33]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCurry] From ce4956d85c3daa526885d628e9b01b3cc4242d5d Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Mon, 7 Jul 2025 11:23:03 +0200 Subject: [PATCH 06/12] some simplifications --- mypy/checkexpr.py | 273 ++++++++++++---------------- test-data/unit/check-functions.test | 2 +- test-data/unit/check-generics.test | 4 +- 3 files changed, 118 insertions(+), 161 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index cab0506975e8..e60baf3765df 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -121,7 +121,6 @@ from mypy.subtypes import ( find_member, is_equivalent, - is_proper_subtype, is_same_type, is_subtype, non_method_protocol_members, @@ -2094,7 +2093,7 @@ def infer_function_type_arguments_using_context( ) def _infer_constraints_from_context( - self, callee: CallableType, error_context: Context, erase: bool = True + self, callee: CallableType, error_context: Context ) -> list[Constraint]: """Unify callable return type to type context to infer type vars. @@ -2105,47 +2104,35 @@ def _infer_constraints_from_context( ctx = self.type_context[-1] if not ctx: return [] - # The return type may have references to type metavariables that - # we are inferring right now. We must consider them as indeterminate - # and they are not potential results; thus we replace them with the - # special ErasedType type. On the other hand, class type variables are - # valid results. - if erase: - erased_ctx = replace_meta_vars(ctx, ErasedType()) - else: - erased_ctx = ctx + # if is_overlapping_none(ret_type) and is_overlapping_none(ctx): + # # If both the context and the return type are optional, unwrap the optional, + # # since in 99% cases this is what a user expects. In other words, we replace + # # Optional[T] <: Optional[int] + # # with + # # T <: int + # # while the former would infer T <: Optional[int]. + # ret_type = remove_optional(ret_type) + # erased_ctx = remove_optional(erased_ctx) + # # + # # TODO: Instead of this hack and the one below, we need to use outer and + # # inner contexts at the same time. This is however not easy because of two + # # reasons: + # # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables + # # on both sides. (This is not too hard.) + # # * We need to update all the inference "infrastructure", so that all + # # variables in an expression are inferred at the same time. + # # (And this is hard, also we need to be careful with lambdas that require + # # two passes.) ret_type = callee.ret_type - if is_overlapping_none(ret_type) and is_overlapping_none(ctx): - # If both the context and the return type are optional, unwrap the optional, - # since in 99% cases this is what a user expects. In other words, we replace - # Optional[T] <: Optional[int] - # with - # T <: int - # while the former would infer T <: Optional[int]. - ret_type = remove_optional(ret_type) - erased_ctx = remove_optional(erased_ctx) - # - # TODO: Instead of this hack and the one below, we need to use outer and - # inner contexts at the same time. This is however not easy because of two - # reasons: - # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables - # on both sides. (This is not too hard.) - # * We need to update all the inference "infrastructure", so that all - # variables in an expression are inferred at the same time. - # (And this is hard, also we need to be careful with lambdas that require - # two passes.) - # ret_as_union = make_simplified_union([ret_type]) - # erased_ctx_as_union = make_simplified_union([ctx]) - # if isinstance(ret_as_union, UnionType) and isinstance(erased_ctx_as_union, UnionType): - # new_ret = [val for val in ret_as_union.items if val not in erased_ctx_as_union.items] - # new_ctx = [val for val in erased_ctx_as_union.items if val not in ret_as_union.items] - # ret_type = make_simplified_union(new_ret) - # erased_ctx = make_simplified_union(new_ctx) + if isinstance(ret_type, UnionType) and isinstance(ctx, UnionType): + new_ret = [val for val in ret_type.items if val not in ctx.items] + new_ctx = [val for val in ctx.items if val not in ret_type.items] + ret_type = make_simplified_union(new_ret) + ctx = make_simplified_union(new_ctx) proper_ret = get_proper_type(ret_type) - if ( - isinstance(proper_ret, TypeVarType) - or isinstance(proper_ret, UnionType) + if isinstance(proper_ret, TypeVarType) or ( + isinstance(proper_ret, UnionType) and all(isinstance(get_proper_type(u), TypeVarType) for u in proper_ret.items) ): # Another special case: the return type is a type variable. If it's unrestricted, @@ -2174,6 +2161,12 @@ def _infer_constraints_from_context( if not is_generic_instance(ctx) and not is_literal_type_like(ctx): return [] + # The return type may have references to type metavariables that + # we are inferring right now. We must consider them as indeterminate + # and they are not potential results; thus we replace them with the + # special ErasedType type. On the other hand, class type variables are + # valid results. + erased_ctx = replace_meta_vars(ctx, ErasedType()) constraints = infer_constraints(ret_type, erased_ctx, SUBTYPE_OF) return constraints @@ -2278,12 +2271,7 @@ def infer_function_type_arguments( context=self.argument_infer_context(), ) - extra_constraints = self._infer_constraints_from_context( - callee_type, context, erase=False - ) - erased_constraints = self._infer_constraints_from_context( - callee_type, context, erase=True - ) + extra_constraints = self._infer_constraints_from_context(callee_type, context) _outer_solution = solve_constraints( callee_type.variables, @@ -2314,27 +2302,6 @@ def infer_function_type_arguments( allow_polymorphic=False, ) - _erased_outer_solution = solve_constraints( - callee_type.variables, - erased_constraints, - strict=self.chk.in_checked_function(), - allow_polymorphic=False, - ) - - _erased_joint_solution = solve_constraints( - callee_type.variables, - constraints + erased_constraints, - strict=self.chk.in_checked_function(), - allow_polymorphic=False, - ) - - _erased_reverse_joint_solution = solve_constraints( - callee_type.variables, - erased_constraints + constraints, - strict=self.chk.in_checked_function(), - allow_polymorphic=False, - ) - # Now, we select the solution to use. # Note: Since joint uses both outer and inner constraints, # and solution discovered by joint is also a solution for outer and inner. @@ -2342,104 +2309,99 @@ def infer_function_type_arguments( # and then try to solve again using only the inner constraints. # joint_solution = (self._filter_args(_joint_solution[0]), _joint_solution[1]) # reverse_joint_solution = (self._filter_args(_reverse_joint_solution[0]), _reverse_joint_solution[1]) - outer_solution = (self._filter_args(_outer_solution[0]), _outer_solution[1]) - inner_solution = (self._filter_args(_inner_solution[0]), _inner_solution[1]) + outer_solution = _outer_solution + inner_solution = _inner_solution joint_solution = _joint_solution reverse_joint_solution = _reverse_joint_solution - target_solution = _erased_reverse_joint_solution + target_solution = _reverse_joint_solution if True: # compute the outer and target return types. - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in outer_solution[0]: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - # Don't show errors after we have only used the outer context for inference. - # We will use argument context to infer more variables. - outer_callee = self.apply_generic_arguments( - callee_type, new_args, context, skip_unsatisfied=True - ) - outer_ret_type = get_proper_type(outer_callee.ret_type) + if True: + outer_callee = self.apply_generic_arguments( + callee_type, outer_solution[0], context, skip_unsatisfied=True + ) + outer_ret_type = get_proper_type(outer_callee.ret_type) - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in target_solution[0]: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - # Don't show errors after we have only used the outer context for inference. - # We will use argument context to infer more variables. - target_callee = self.apply_generic_arguments( - callee_type, new_args, context, skip_unsatisfied=True - ) - target_ret_type = get_proper_type(target_callee.ret_type) + target_callee = self.apply_generic_arguments( + callee_type, target_solution[0], context, skip_unsatisfied=True + ) + target_ret_type = get_proper_type(target_callee.ret_type) + else: + # Only substitute non-Uninhabited and non-erased types. + new_args: list[Type | None] = [] + for arg in outer_solution[0]: + if has_uninhabited_component(arg) or has_erased_component(arg): + new_args.append(None) + else: + new_args.append(arg) + # Don't show errors after we have only used the outer context for inference. + # We will use argument context to infer more variables. + outer_callee = self.apply_generic_arguments( + callee_type, new_args, context, skip_unsatisfied=True + ) + outer_ret_type = get_proper_type(outer_callee.ret_type) + + # Only substitute non-Uninhabited and non-erased types. + new_args: list[Type | None] = [] + for arg in target_solution[0]: + if has_uninhabited_component(arg) or has_erased_component(arg): + new_args.append(None) + else: + new_args.append(arg) + # Don't show errors after we have only used the outer context for inference. + # We will use argument context to infer more variables. + target_callee = self.apply_generic_arguments( + callee_type, new_args, context, skip_unsatisfied=True + ) + target_ret_type = get_proper_type(target_callee.ret_type) use_joint = True use_outer = True use_inner = True # check if we can use the joint solution, otherwise fallback to outer_solution - for outer_tp, inner_tp, joint_tp in zip( - outer_solution[0], inner_solution[0], target_solution[0] - ): - if joint_tp is None and outer_tp is not None: - use_joint = False - if has_erased_component(joint_tp) and not has_erased_component(inner_tp): - # If the joint solution is erased, but outer is not, we use outer. - use_joint = False - if has_erased_component(outer_tp) and not has_erased_component(inner_tp): - use_outer = False - if has_erased_component(inner_tp): - use_inner = False - - combined_solution = [] - for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]): - if ( - outer_tp is not None - and joint_tp is not None - and is_proper_subtype(outer_tp, joint_tp) - ): - # If outer is a subtype of joint, we can use joint. - combined_solution.append(outer_tp) - else: - # Otherwise, we use joint. - combined_solution.append(joint_tp) - - # new_vars, new_constraints = self.intersect_solutions(outer_solution[0], target_solution[0]) - # intersected_solution = solve_constraints( - # new_vars, - # new_constraints, - # strict=self.chk.in_checked_function(), - # allow_polymorphic=False, - # ) - - # if the outer solution is more concrete than the joint solution, use the outer solution (2 step) - # if all( - # (joint_tp is None and outer_tp is None) - # or ( - # (joint_tp is not None and outer_tp is not None) - # and ( - # is_subtype(outer_tp, joint_tp) - # or ( - # isinstance(outer_tp, UnionType) - # and any(is_subtype(val, joint_tp) for val in outer_tp.items) - # ) - # ) - # ) - # for outer_tp, joint_tp in zip(outer_solution[0], target_solution[0]) + # for outer_tp, inner_tp, joint_tp in zip( + # outer_solution[0], inner_solution[0], target_solution[0] # ): - # use_joint = False - # use_outer = True + # if joint_tp is None and outer_tp is not None: + # use_joint = False + # if has_erased_component(joint_tp) and not has_erased_component(inner_tp): + # # If the joint solution is erased, but outer is not, we use outer. + # use_joint = False + # if has_erased_component(outer_tp) and not has_erased_component(inner_tp): + # use_outer = False + # if has_erased_component(inner_tp): + # use_inner = False + + if any(tp is None for tp in inner_solution[0]): + use_inner = False + if any(tp is None for tp in outer_solution[0]): + use_outer = False + if any(tp is None for tp in joint_solution[0]): + use_joint = False - # if the outer solution is more concrete than the joint solution, use the outer solution (2 step) - if is_subtype(outer_ret_type, target_ret_type) or ( - isinstance(outer_ret_type, UnionType) - and any(is_subtype(val, target_ret_type) for val in outer_ret_type.items) + if ( + # if the joint failed to solve use the outer solution instead. + # any(joint_tp is None and outer_tp is not None for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0])) + # If the outer solution is more concrete than the joint solution, use the outer solution. + # This also applies if the outer solution is a union type where at least one member + # is a subtype of the target return type. + is_subtype(outer_ret_type, target_ret_type) + or ( + isinstance(outer_ret_type, UnionType) + and any(is_subtype(val, target_ret_type) for val in outer_ret_type.items) + ) ): use_joint = False - use_outer = True + # use_outer = True + + if use_joint: + target_solution = reverse_joint_solution + elif use_outer: + target_solution = outer_solution + elif use_inner: + target_solution = inner_solution + else: + raise AssertionError # what if the outer context is a union type? # we may have a case like: @@ -2450,17 +2412,12 @@ def infer_function_type_arguments( _num = arg_pass_nums _c0 = constraints _c1 = extra_constraints - _c2 = erased_constraints _x0 = _outer_solution[0] _x2 = _inner_solution[0] _x3 = _joint_solution[0] _x4 = _reverse_joint_solution[0] - _e0 = _erased_outer_solution[0] - _e2 = _erased_joint_solution[0] - _e3 = _erased_reverse_joint_solution[0] - _s1 = outer_solution[0] _s2 = inner_solution[0] _s3 = joint_solution[0] @@ -2471,7 +2428,7 @@ def infer_function_type_arguments( _r2 = target_ret_type _y1 = outer_callee _y2 = target_callee - _u0 = use_inner, use_outer, use_joint + _u0 = use_outer, use_joint if use_joint: new_inferred_args = target_solution[0] @@ -2483,7 +2440,7 @@ def infer_function_type_arguments( # ] elif use_outer: # If we cannot use the joint solution, fallback to outer_solution - new_inferred_args = outer_solution[0] + new_inferred_args = target_solution[0] # Only substitute non-Uninhabited and non-erased types. new_args: list[Type | None] = [] diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 8a49497e227b..8d053c786be8 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ... def g(x: T, y: S) -> Union[T, S]: ... x = [f, g] -reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`14, y: S`15) -> Union[T`14, S`15]]" +reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`12, y: S`13) -> Union[T`12, S`13]]" [builtins fixtures/list.pyi] [case testTypeVariableClashErrorMessage] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index b6f6352f270a..73cbe11ae48e 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def id(__x: U) -> U: ... fs = [id, id, id] -reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`31) -> builtins.list[S`31]" -reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`33) -> builtins.list[S`33]" +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`27) -> builtins.list[S`27]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`29) -> builtins.list[S`29]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCurry] From e704f6e740aaa153b00f010d8d2c781c12e83189 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Tue, 8 Jul 2025 11:16:39 +0200 Subject: [PATCH 07/12] revert fixture modification --- mypy/checkexpr.py | 202 ++++++++----------------------- mypy/infer.py | 10 +- test-data/unit/fixtures/list.pyi | 5 +- 3 files changed, 58 insertions(+), 159 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e60baf3765df..7fbcdf145dad 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -1781,18 +1781,6 @@ def check_callable_call( isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables ) callee = freshen_function_type_vars(callee) - # callee = self.infer_function_type_arguments_using_context(callee, context) - # if need_refresh: - # # Argument kinds etc. may have changed due to - # # ParamSpec or TypeVarTuple variables being replaced with an arbitrary - # # number of arguments; recalculate actual-to-formal map - # formal_to_actual = map_actuals_to_formals( - # arg_kinds, - # arg_names, - # callee.arg_kinds, - # callee.arg_names, - # lambda i: self.accept(args[i]), - # ) callee = self.infer_function_type_arguments( callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context ) @@ -2125,6 +2113,11 @@ def _infer_constraints_from_context( # # two passes.) ret_type = callee.ret_type if isinstance(ret_type, UnionType) and isinstance(ctx, UnionType): + # If both the context and the return type are unions, we simplify shared items + # e.g. T | None <: int | None => T <: int + # since the former would infer T <: int | None. + # whereas the latter would infer the more precise T <: int. + new_ret = [val for val in ret_type.items if val not in ctx.items] new_ctx = [val for val in ctx.items if val not in ret_type.items] ret_type = make_simplified_union(new_ret) @@ -2170,48 +2163,6 @@ def _infer_constraints_from_context( constraints = infer_constraints(ret_type, erased_ctx, SUBTYPE_OF) return constraints - def _filter_args(self, args: list[Type | None]) -> list[Type | None]: - new_args: list[Type | None] = [] - for arg in args: - if arg is None: - new_args.append(None) - continue - else: - arg = replace_meta_vars(arg, ErasedType()) - new_args.append(arg) - # if has_erased_component(arg) or has_uninhabited_component(arg): - # new_args.append(None) - # else: - # new_args.append(arg) - return new_args - - def intersect_solutions(self, sol1: list[Type | None], sol2: list[Type | None]): - # first, ensure that the None-patterns agree - assert len(sol1) == len(sol2) - - virtual_vars = [] - constraints = [] - - for i, (tp1, tp2) in enumerate(zip(sol1, sol2)): - new_id = TypeVarId.new(-1) - name = f"V{i}" - new_tvar = TypeVarType( - name, - name, - new_id, - values=[], - upper_bound=self.object_type(), - default=AnyType(TypeOfAny.from_omitted_generics), - ) - virtual_vars.append(new_tvar) - if tp1 is not None: - c1 = Constraint(new_tvar, SUBTYPE_OF, tp1) - constraints.append(c1) - if tp2 is not None: - c2 = Constraint(new_tvar, SUBTYPE_OF, tp2) - constraints.append(c2) - return virtual_vars, constraints - def infer_function_type_arguments( self, callee_type: CallableType, @@ -2258,11 +2209,9 @@ def infer_function_type_arguments( context=self.argument_infer_context(), strict=self.chk.in_checked_function(), ) - old_inferred_args = inferred_args - new_inferred_args = None if True: # NEW CODE - constraints = infer_constraints_for_callable( + inner_constraints = infer_constraints_for_callable( callee_type, pass1_args, arg_kinds, @@ -2271,49 +2220,48 @@ def infer_function_type_arguments( context=self.argument_infer_context(), ) - extra_constraints = self._infer_constraints_from_context(callee_type, context) + outer_constraints = self._infer_constraints_from_context(callee_type, context) - _outer_solution = solve_constraints( + outer_solution = solve_constraints( callee_type.variables, - extra_constraints, + outer_constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) - _inner_solution = solve_constraints( + inner_solution = solve_constraints( callee_type.variables, - constraints, + inner_constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) # NOTE: The order of constraints is important here! # solve(outer + inner) and solve(inner + outer) may yield different results. - _joint_solution = solve_constraints( + + joint_solution = solve_constraints( callee_type.variables, - constraints + extra_constraints, + outer_constraints + inner_constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) - _reverse_joint_solution = solve_constraints( + reverse_joint_solution = solve_constraints( callee_type.variables, - extra_constraints + constraints, + inner_constraints + outer_constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) + target_solution = joint_solution + # Now, we select the solution to use. # Note: Since joint uses both outer and inner constraints, # and solution discovered by joint is also a solution for outer and inner. # therefore, we can pick either inner or outer as a substitute for joint, # and then try to solve again using only the inner constraints. - # joint_solution = (self._filter_args(_joint_solution[0]), _joint_solution[1]) - # reverse_joint_solution = (self._filter_args(_reverse_joint_solution[0]), _reverse_joint_solution[1]) - outer_solution = _outer_solution - inner_solution = _inner_solution - joint_solution = _joint_solution - reverse_joint_solution = _reverse_joint_solution - target_solution = _reverse_joint_solution + use_joint = True + use_outer = True + use_inner = True if True: # compute the outer and target return types. if True: @@ -2355,44 +2303,17 @@ def infer_function_type_arguments( ) target_ret_type = get_proper_type(target_callee.ret_type) - use_joint = True - use_outer = True - use_inner = True - # check if we can use the joint solution, otherwise fallback to outer_solution - # for outer_tp, inner_tp, joint_tp in zip( - # outer_solution[0], inner_solution[0], target_solution[0] - # ): - # if joint_tp is None and outer_tp is not None: - # use_joint = False - # if has_erased_component(joint_tp) and not has_erased_component(inner_tp): - # # If the joint solution is erased, but outer is not, we use outer. - # use_joint = False - # if has_erased_component(outer_tp) and not has_erased_component(inner_tp): - # use_outer = False - # if has_erased_component(inner_tp): - # use_inner = False - - if any(tp is None for tp in inner_solution[0]): - use_inner = False - if any(tp is None for tp in outer_solution[0]): - use_outer = False - if any(tp is None for tp in joint_solution[0]): - use_joint = False - if ( - # if the joint failed to solve use the outer solution instead. - # any(joint_tp is None and outer_tp is not None for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0])) + # joint constraints failed to produce a complete solution + None in joint_solution[0] # If the outer solution is more concrete than the joint solution, use the outer solution. - # This also applies if the outer solution is a union type where at least one member - # is a subtype of the target return type. - is_subtype(outer_ret_type, target_ret_type) - or ( + or is_subtype(outer_ret_type, target_ret_type) + or ( # HACK to fix testLiteralAndGenericWithUnion isinstance(outer_ret_type, UnionType) and any(is_subtype(val, target_ret_type) for val in outer_ret_type.items) ) ): use_joint = False - # use_outer = True if use_joint: target_solution = reverse_joint_solution @@ -2403,52 +2324,38 @@ def infer_function_type_arguments( else: raise AssertionError - # what if the outer context is a union type? - # we may have a case like: - # outer : int | Literal["foo"] - # inner: Literal["foo"]? (which gets translated into str later) - # here, we would want `Literal["foo"]` to be used as the solution, - - _num = arg_pass_nums - _c0 = constraints - _c1 = extra_constraints - - _x0 = _outer_solution[0] - _x2 = _inner_solution[0] - _x3 = _joint_solution[0] - _x4 = _reverse_joint_solution[0] - - _s1 = outer_solution[0] - _s2 = inner_solution[0] - _s3 = joint_solution[0] - _s4 = reverse_joint_solution[0] - _t0 = target_solution[0] - - _r1 = outer_ret_type - _r2 = target_ret_type - _y1 = outer_callee - _y2 = target_callee - _u0 = use_outer, use_joint + if __debug__: + _num = arg_pass_nums + _c0 = inner_constraints + _c1 = outer_constraints + + _s1 = outer_solution[0] + _s2 = inner_solution[0] + _s3 = joint_solution[0] + _s4 = reverse_joint_solution[0] + _t0 = target_solution[0] + + _r1 = outer_ret_type + _r2 = target_ret_type + _y1 = outer_callee + _y2 = target_callee + _u0 = use_outer, use_joint if use_joint: - new_inferred_args = target_solution[0] - # inferred_args = [ - # # Usually, joint_tp <: outer_tp (since superset of constraints), - # # fixes some cases where we would get `Literal[4]?` rather than `Literal[4]` - # (outer_tp if is_subtype(outer_tp, joint_tp) else joint_tp) - # for outer_tp, joint_tp in zip(outer_solution[0], joint_solution[0]) - # ] + # inferred_args = target_solution[0] + pass elif use_outer: # If we cannot use the joint solution, fallback to outer_solution - new_inferred_args = target_solution[0] + inferred_args = target_solution[0] # Only substitute non-Uninhabited and non-erased types. new_args: list[Type | None] = [] - for arg in new_inferred_args: + for arg in inferred_args: if has_uninhabited_component(arg) or has_erased_component(arg): new_args.append(None) else: new_args.append(arg) + # Don't show errors after we have only used the outer context for inference. # We will use argument context to infer more variables. callee_type = self.apply_generic_arguments( @@ -2465,7 +2372,7 @@ def infer_function_type_arguments( callee_type.arg_names, lambda i: self.accept(args[i]), ) - new_inferred_args, _ = infer_function_type_arguments( + inferred_args, _ = infer_function_type_arguments( callee_type, pass1_args, arg_kinds, @@ -2475,21 +2382,10 @@ def infer_function_type_arguments( strict=self.chk.in_checked_function(), ) elif use_inner: - new_inferred_args = inner_solution[0] + # inferred_args = inner_solution[0] + pass else: raise RuntimeError("No solution found for function type arguments") - else: # OLD CODE - pass - - if True: # USE NEW CODE - inferred_args = new_inferred_args - else: # USE OLD CODE - inferred_args = old_inferred_args - - # show me - _1 = new_inferred_args - _2 = old_inferred_args - _3 = inferred_args if 2 in arg_pass_nums: # Second pass of type inference. diff --git a/mypy/infer.py b/mypy/infer.py index 6fcba279a000..0c3860a1a2fc 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -13,7 +13,7 @@ ) from mypy.nodes import ArgKind from mypy.solve import solve_constraints -from mypy.types import CallableType, Instance, Type, TypeVarLikeType +from mypy.types import CallableType, Instance, Type, TypeVarLikeType, UninhabitedType class ArgumentInferContext(NamedTuple): @@ -63,6 +63,9 @@ def infer_function_type_arguments( return solve_constraints(type_vars, constraints, strict, allow_polymorphic) +from mypy.constraints import Constraint + + def infer_type_arguments( type_vars: Sequence[TypeVarLikeType], template: Type, @@ -74,7 +77,8 @@ def infer_type_arguments( # against a generic type. constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) - # for tp in type_vars: - # constraints.append(Constraint(tp, SUPERTYPE_OF, UninhabitedType())) + # Not needed?! + for tp in type_vars: + constraints.append(Constraint(tp, SUPERTYPE_OF, UninhabitedType())) return solve_constraints(type_vars, constraints, skip_unsatisfied=skip_unsatisfied)[0] diff --git a/test-data/unit/fixtures/list.pyi b/test-data/unit/fixtures/list.pyi index 032abfc6beed..3dcdf18b2faa 100644 --- a/test-data/unit/fixtures/list.pyi +++ b/test-data/unit/fixtures/list.pyi @@ -1,9 +1,8 @@ # Builtins stub used in list-related test cases. -from typing import TypeVar, Generic, Iterable, Iterator, Sequence, overload, Union +from typing import TypeVar, Generic, Iterable, Iterator, Sequence, overload T = TypeVar('T') -_S = TypeVar("_S") class object: def __init__(self) -> None: pass @@ -20,7 +19,7 @@ class list(Sequence[T]): def __iter__(self) -> Iterator[T]: pass def __len__(self) -> int: pass def __contains__(self, item: object) -> bool: pass - def __add__(self, x: list[_S]) -> list[Union[_S, T]]: pass + def __add__(self, x: list[T]) -> list[T]: pass def __mul__(self, x: int) -> list[T]: pass def __getitem__(self, x: int) -> T: pass def __setitem__(self, x: int, v: T) -> None: pass From 26e8da26e52b13fa85ac9b10537dbf44e86f9360 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Tue, 8 Jul 2025 11:34:02 +0200 Subject: [PATCH 08/12] simplify --- mypy/checkexpr.py | 352 ++++++++------------------ mypy/infer.py | 10 +- test-data/unit/check-expressions.test | 4 + test-data/unit/check-functions.test | 2 +- test-data/unit/check-generics.test | 4 +- 5 files changed, 112 insertions(+), 260 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 7fbcdf145dad..f9374ff5be6f 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -18,7 +18,12 @@ from mypy.checker_shared import ExpressionCheckerSharedApi from mypy.checkmember import analyze_member_access, has_operator from mypy.checkstrformat import StringFormatterChecker -from mypy.constraints import SUBTYPE_OF, Constraint, infer_constraints +from mypy.constraints import ( + SUBTYPE_OF, + Constraint, + infer_constraints, + infer_constraints_for_callable, +) from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars from mypy.errors import ErrorWatcher, report_internal_error from mypy.expandtype import ( @@ -27,13 +32,7 @@ freshen_all_functions_type_vars, freshen_function_type_vars, ) -from mypy.infer import ( - ArgumentInferContext, - infer_constraints_for_callable, - infer_function_type_arguments, - infer_type_arguments, - solve_constraints, -) +from mypy.infer import ArgumentInferContext, infer_function_type_arguments from mypy.literals import literal from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_types, narrow_declared_type @@ -117,6 +116,7 @@ Plugin, ) from mypy.semanal_enum import ENUM_BASES +from mypy.solve import solve_constraints from mypy.state import state from mypy.subtypes import ( find_member, @@ -198,12 +198,7 @@ is_named_instance, split_with_prefix_and_suffix, ) -from mypy.types_utils import ( - is_generic_instance, - is_overlapping_none, - is_self_type_like, - remove_optional, -) +from mypy.types_utils import is_generic_instance, is_self_type_like, remove_optional from mypy.typestate import type_state from mypy.typevars import fill_typevars from mypy.util import split_module_names @@ -1995,9 +1990,9 @@ def infer_arg_types_in_context( assert all(tp is not None for tp in res) return cast(list[Type], res) - def infer_function_type_arguments_using_context( - self, callable: CallableType, error_context: Context - ) -> CallableType: + def infer_constraints_from_context( + self, callee: CallableType, error_context: Context + ) -> list[Constraint]: """Unify callable return type to type context to infer type vars. For example, if the return type is set[t] where 't' is a type variable @@ -2006,23 +2001,23 @@ def infer_function_type_arguments_using_context( """ ctx = self.type_context[-1] if not ctx: - return callable + return [] # The return type may have references to type metavariables that # we are inferring right now. We must consider them as indeterminate # and they are not potential results; thus we replace them with the # special ErasedType type. On the other hand, class type variables are # valid results. - erased_ctx = replace_meta_vars(ctx, ErasedType()) - ret_type = callable.ret_type - if is_overlapping_none(ret_type) and is_overlapping_none(ctx): - # If both the context and the return type are optional, unwrap the optional, - # since in 99% cases this is what a user expects. In other words, we replace - # Optional[T] <: Optional[int] - # with - # T <: int - # while the former would infer T <: Optional[int]. - ret_type = remove_optional(ret_type) - erased_ctx = remove_optional(erased_ctx) + erased_ctx = get_proper_type(replace_meta_vars(ctx, ErasedType())) + proper_ret = get_proper_type(callee.ret_type) + if isinstance(proper_ret, UnionType) and isinstance(erased_ctx, UnionType): + # If both the context and the return type are unions, we simplify shared items + # e.g. T | None <: int | None => T <: int + # since the former would infer T <: int | None. + # whereas the latter would infer the more precise T <: int. + new_ret = [val for val in proper_ret.items if val not in erased_ctx.items] + new_ctx = [val for val in erased_ctx.items if val not in proper_ret.items] + proper_ret = make_simplified_union(new_ret) + erased_ctx = make_simplified_union(new_ctx) # # TODO: Instead of this hack and the one below, we need to use outer and # inner contexts at the same time. This is however not easy because of two @@ -2033,100 +2028,10 @@ def infer_function_type_arguments_using_context( # variables in an expression are inferred at the same time. # (And this is hard, also we need to be careful with lambdas that require # two passes.) - proper_ret = get_proper_type(ret_type) if ( isinstance(proper_ret, TypeVarType) or isinstance(proper_ret, UnionType) and all(isinstance(get_proper_type(u), TypeVarType) for u in proper_ret.items) - ): - # Another special case: the return type is a type variable. If it's unrestricted, - # we could infer a too general type for the type variable if we use context, - # and this could result in confusing and spurious type errors elsewhere. - # - # So we give up and just use function arguments for type inference, with just two - # exceptions: - # - # 1. If the context is a generic instance type, actually use it as context, as - # this *seems* to usually be the reasonable thing to do. - # - # See also github issues #462 and #360. - # - # 2. If the context is some literal type, we want to "propagate" that information - # down so that we infer a more precise type for literal expressions. For example, - # the expression `3` normally has an inferred type of `builtins.int`: but if it's - # in a literal context like below, we want it to infer `Literal[3]` instead. - # - # def expects_literal(x: Literal[3]) -> None: pass - # def identity(x: T) -> T: return x - # - # expects_literal(identity(3)) # Should type-check - # TODO: we may want to add similar exception if all arguments are lambdas, since - # in this case external context is almost everything we have. - if not is_generic_instance(ctx) and not is_literal_type_like(ctx): - return callable.copy_modified() - args = infer_type_arguments( - callable.variables, ret_type, erased_ctx, skip_unsatisfied=True - ) - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in args: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - # Don't show errors after we have only used the outer context for inference. - # We will use argument context to infer more variables. - return self.apply_generic_arguments( - callable, new_args, error_context, skip_unsatisfied=True - ) - - def _infer_constraints_from_context( - self, callee: CallableType, error_context: Context - ) -> list[Constraint]: - """Unify callable return type to type context to infer type vars. - - For example, if the return type is set[t] where 't' is a type variable - of callable, and if the context is set[int], return callable modified - by substituting 't' with 'int'. - """ - ctx = self.type_context[-1] - if not ctx: - return [] - # if is_overlapping_none(ret_type) and is_overlapping_none(ctx): - # # If both the context and the return type are optional, unwrap the optional, - # # since in 99% cases this is what a user expects. In other words, we replace - # # Optional[T] <: Optional[int] - # # with - # # T <: int - # # while the former would infer T <: Optional[int]. - # ret_type = remove_optional(ret_type) - # erased_ctx = remove_optional(erased_ctx) - # # - # # TODO: Instead of this hack and the one below, we need to use outer and - # # inner contexts at the same time. This is however not easy because of two - # # reasons: - # # * We need to support constraints like [1 <: 2, 2 <: X], i.e. with variables - # # on both sides. (This is not too hard.) - # # * We need to update all the inference "infrastructure", so that all - # # variables in an expression are inferred at the same time. - # # (And this is hard, also we need to be careful with lambdas that require - # # two passes.) - ret_type = callee.ret_type - if isinstance(ret_type, UnionType) and isinstance(ctx, UnionType): - # If both the context and the return type are unions, we simplify shared items - # e.g. T | None <: int | None => T <: int - # since the former would infer T <: int | None. - # whereas the latter would infer the more precise T <: int. - - new_ret = [val for val in ret_type.items if val not in ctx.items] - new_ctx = [val for val in ctx.items if val not in ret_type.items] - ret_type = make_simplified_union(new_ret) - ctx = make_simplified_union(new_ctx) - - proper_ret = get_proper_type(ret_type) - if isinstance(proper_ret, TypeVarType) or ( - isinstance(proper_ret, UnionType) - and all(isinstance(get_proper_type(u), TypeVarType) for u in proper_ret.items) ): # Another special case: the return type is a type variable. If it's unrestricted, # we could infer a too general type for the type variable if we use context, @@ -2153,14 +2058,7 @@ def _infer_constraints_from_context( # in this case external context is almost everything we have. if not is_generic_instance(ctx) and not is_literal_type_like(ctx): return [] - - # The return type may have references to type metavariables that - # we are inferring right now. We must consider them as indeterminate - # and they are not potential results; thus we replace them with the - # special ErasedType type. On the other hand, class type variables are - # valid results. - erased_ctx = replace_meta_vars(ctx, ErasedType()) - constraints = infer_constraints(ret_type, erased_ctx, SUBTYPE_OF) + constraints = infer_constraints(proper_ret, erased_ctx, SUBTYPE_OF) return constraints def infer_function_type_arguments( @@ -2200,17 +2098,25 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - inferred_args, _ = infer_function_type_arguments( - callee_type, - pass1_args, - arg_kinds, - arg_names, - formal_to_actual, - context=self.argument_infer_context(), - strict=self.chk.in_checked_function(), - ) - if True: # NEW CODE + outer_constraints = self.infer_constraints_from_context(callee_type, context) + outer_solution = solve_constraints( + callee_type.variables, + outer_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + outer_args = [ + None if has_uninhabited_component(arg) or has_erased_component(arg) else arg + for arg in outer_solution[0] + ] + outer_solution = (outer_args, outer_solution[1]) + outer_callee = self.apply_generic_arguments( + callee_type, outer_solution[0], context, skip_unsatisfied=True + ) + outer_ret_type = get_proper_type(outer_callee.ret_type) + + # NOTE: inner solution not needed inner_constraints = infer_constraints_for_callable( callee_type, pass1_args, @@ -2219,147 +2125,89 @@ def infer_function_type_arguments( formal_to_actual, context=self.argument_infer_context(), ) - - outer_constraints = self._infer_constraints_from_context(callee_type, context) - - outer_solution = solve_constraints( - callee_type.variables, - outer_constraints, - strict=self.chk.in_checked_function(), - allow_polymorphic=False, - ) - inner_solution = solve_constraints( callee_type.variables, inner_constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) + inner_args = [ + None if has_uninhabited_component(arg) or has_erased_component(arg) else arg + for arg in inner_solution[0] + ] + inner_solution = (inner_args, inner_solution[1]) + # inner_callee = self.apply_generic_arguments( + # callee_type, + # inner_solution[0], # no filtering here + # context, + # skip_unsatisfied=True, + # ) + # inner_ret_type = get_proper_type(inner_callee.ret_type) + # NOTE: The order of constraints is important here! # solve(outer + inner) and solve(inner + outer) may yield different results. - + # we need to use outer first. joint_solution = solve_constraints( callee_type.variables, outer_constraints + inner_constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) - - reverse_joint_solution = solve_constraints( - callee_type.variables, - inner_constraints + outer_constraints, - strict=self.chk.in_checked_function(), - allow_polymorphic=False, + joint_args = [ + None if has_uninhabited_component(arg) or has_erased_component(arg) else arg + for arg in joint_solution[0] + ] + joint_solution = (joint_args, joint_solution[1]) + joint_callee = self.apply_generic_arguments( + callee_type, joint_solution[0], context, skip_unsatisfied=True ) + joint_ret_type = get_proper_type(joint_callee.ret_type) - target_solution = joint_solution - - # Now, we select the solution to use. - # Note: Since joint uses both outer and inner constraints, - # and solution discovered by joint is also a solution for outer and inner. - # therefore, we can pick either inner or outer as a substitute for joint, - # and then try to solve again using only the inner constraints. + # Now, we select which solution to use. use_joint = True - use_outer = True - use_inner = True - - if True: # compute the outer and target return types. - if True: - outer_callee = self.apply_generic_arguments( - callee_type, outer_solution[0], context, skip_unsatisfied=True - ) - outer_ret_type = get_proper_type(outer_callee.ret_type) - - target_callee = self.apply_generic_arguments( - callee_type, target_solution[0], context, skip_unsatisfied=True - ) - target_ret_type = get_proper_type(target_callee.ret_type) - else: - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in outer_solution[0]: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - # Don't show errors after we have only used the outer context for inference. - # We will use argument context to infer more variables. - outer_callee = self.apply_generic_arguments( - callee_type, new_args, context, skip_unsatisfied=True - ) - outer_ret_type = get_proper_type(outer_callee.ret_type) - - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in target_solution[0]: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - # Don't show errors after we have only used the outer context for inference. - # We will use argument context to infer more variables. - target_callee = self.apply_generic_arguments( - callee_type, new_args, context, skip_unsatisfied=True - ) - target_ret_type = get_proper_type(target_callee.ret_type) + use_outer = False + use_inner = False + + # NOTE: inner solution not needed + # if ( + # # joint constraints failed to produce a complete solution + # None in joint_solution[0] + # # If the inner solution is more concrete than the joint solution, prefer the inner solution. + # or is_subtype(inner_ret_type, joint_ret_type) + # or ( # HACK to fix testLiteralAndGenericWithUnion + # isinstance(inner_ret_type, UnionType) + # and any(is_subtype(val, joint_ret_type) for val in inner_ret_type.items) + # ) + # ): + # use_joint = False + # use_outer = False + # use_inner = True if ( # joint constraints failed to produce a complete solution None in joint_solution[0] - # If the outer solution is more concrete than the joint solution, use the outer solution. - or is_subtype(outer_ret_type, target_ret_type) + # If the outer solution is more concrete than the joint solution, prefer the outer solution. + or is_subtype(outer_ret_type, joint_ret_type) or ( # HACK to fix testLiteralAndGenericWithUnion isinstance(outer_ret_type, UnionType) - and any(is_subtype(val, target_ret_type) for val in outer_ret_type.items) + and any(is_subtype(val, joint_ret_type) for val in outer_ret_type.items) ) ): use_joint = False + use_outer = True + use_inner = False if use_joint: - target_solution = reverse_joint_solution - elif use_outer: - target_solution = outer_solution + inferred_args = joint_solution[0] elif use_inner: - target_solution = inner_solution - else: - raise AssertionError - - if __debug__: - _num = arg_pass_nums - _c0 = inner_constraints - _c1 = outer_constraints - - _s1 = outer_solution[0] - _s2 = inner_solution[0] - _s3 = joint_solution[0] - _s4 = reverse_joint_solution[0] - _t0 = target_solution[0] - - _r1 = outer_ret_type - _r2 = target_ret_type - _y1 = outer_callee - _y2 = target_callee - _u0 = use_outer, use_joint - - if use_joint: - # inferred_args = target_solution[0] - pass + inferred_args = inner_solution[0] elif use_outer: # If we cannot use the joint solution, fallback to outer_solution - inferred_args = target_solution[0] - - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in inferred_args: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - + inferred_args = outer_solution[0] # Don't show errors after we have only used the outer context for inference. # We will use argument context to infer more variables. callee_type = self.apply_generic_arguments( - callee_type, new_args, context, skip_unsatisfied=True + callee_type, inferred_args, context, skip_unsatisfied=True ) if need_refresh: # Argument kinds etc. may have changed due to @@ -2372,20 +2220,28 @@ def infer_function_type_arguments( callee_type.arg_names, lambda i: self.accept(args[i]), ) - inferred_args, _ = infer_function_type_arguments( + + # ??? QUESTION: Do we need to recompute arg_types and pass1_args here??? + # recompute and apply inner solution. + inner_constraints = infer_constraints_for_callable( callee_type, pass1_args, arg_kinds, arg_names, formal_to_actual, context=self.argument_infer_context(), + ) + inner_solution = solve_constraints( + callee_type.variables, + inner_constraints, strict=self.chk.in_checked_function(), + allow_polymorphic=False, ) - elif use_inner: - # inferred_args = inner_solution[0] - pass + inferred_args = inner_solution[0] else: raise RuntimeError("No solution found for function type arguments") + else: # END NEW CODE + pass if 2 in arg_pass_nums: # Second pass of type inference. diff --git a/mypy/infer.py b/mypy/infer.py index 0c3860a1a2fc..cdc43797d3b1 100644 --- a/mypy/infer.py +++ b/mypy/infer.py @@ -13,7 +13,7 @@ ) from mypy.nodes import ArgKind from mypy.solve import solve_constraints -from mypy.types import CallableType, Instance, Type, TypeVarLikeType, UninhabitedType +from mypy.types import CallableType, Instance, Type, TypeVarLikeType class ArgumentInferContext(NamedTuple): @@ -63,9 +63,6 @@ def infer_function_type_arguments( return solve_constraints(type_vars, constraints, strict, allow_polymorphic) -from mypy.constraints import Constraint - - def infer_type_arguments( type_vars: Sequence[TypeVarLikeType], template: Type, @@ -76,9 +73,4 @@ def infer_type_arguments( # Like infer_function_type_arguments, but only match a single type # against a generic type. constraints = infer_constraints(template, actual, SUPERTYPE_OF if is_supertype else SUBTYPE_OF) - - # Not needed?! - for tp in type_vars: - constraints.append(Constraint(tp, SUPERTYPE_OF, UninhabitedType())) - return solve_constraints(type_vars, constraints, skip_unsatisfied=skip_unsatisfied)[0] diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 41cc3cf8ddde..fc9a11f86c82 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -478,6 +478,10 @@ class A: def __contains__(self, x: 'A') -> str: pass [builtins fixtures/bool.pyi] +[case testInWithInvalidArgs] +a = 1 in ([1] + ['x']) # E: List item 0 has incompatible type "str"; expected "int" +[builtins fixtures/list.pyi] + [case testEq] a: A b: bool diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 8d053c786be8..daa4db062c15 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ... def g(x: T, y: S) -> Union[T, S]: ... x = [f, g] -reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`12, y: S`13) -> Union[T`12, S`13]]" +reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`6, y: S`7) -> Union[T`6, S`7]]" [builtins fixtures/list.pyi] [case testTypeVariableClashErrorMessage] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 73cbe11ae48e..012d53223295 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def id(__x: U) -> U: ... fs = [id, id, id] -reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`27) -> builtins.list[S`27]" -reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`29) -> builtins.list[S`29]" +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`15) -> builtins.list[S`15]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`17) -> builtins.list[S`17]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCurry] From 925a3eab29f597404a1d5943a11c2fae9e0c89f5 Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Tue, 8 Jul 2025 13:47:05 +0200 Subject: [PATCH 09/12] Simplified away early inner_solution computation --- mypy/checkexpr.py | 72 +++++++---------------------- test-data/unit/check-functions.test | 2 +- test-data/unit/check-generics.test | 4 +- 3 files changed, 19 insertions(+), 59 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index f9374ff5be6f..e08b3ef12715 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2099,6 +2099,17 @@ def infer_function_type_arguments( pass1_args.append(arg) if True: # NEW CODE + # compute the inner constraints + inner_constraints = infer_constraints_for_callable( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + ) + + # compute the outer solution outer_constraints = self.infer_constraints_from_context(callee_type, context) outer_solution = solve_constraints( callee_type.variables, @@ -2116,34 +2127,7 @@ def infer_function_type_arguments( ) outer_ret_type = get_proper_type(outer_callee.ret_type) - # NOTE: inner solution not needed - inner_constraints = infer_constraints_for_callable( - callee_type, - pass1_args, - arg_kinds, - arg_names, - formal_to_actual, - context=self.argument_infer_context(), - ) - inner_solution = solve_constraints( - callee_type.variables, - inner_constraints, - strict=self.chk.in_checked_function(), - allow_polymorphic=False, - ) - inner_args = [ - None if has_uninhabited_component(arg) or has_erased_component(arg) else arg - for arg in inner_solution[0] - ] - inner_solution = (inner_args, inner_solution[1]) - # inner_callee = self.apply_generic_arguments( - # callee_type, - # inner_solution[0], # no filtering here - # context, - # skip_unsatisfied=True, - # ) - # inner_ret_type = get_proper_type(inner_callee.ret_type) - + # compute the joint solution using both inner and outer constraints. # NOTE: The order of constraints is important here! # solve(outer + inner) and solve(inner + outer) may yield different results. # we need to use outer first. @@ -2163,27 +2147,7 @@ def infer_function_type_arguments( ) joint_ret_type = get_proper_type(joint_callee.ret_type) - # Now, we select which solution to use. - use_joint = True - use_outer = False - use_inner = False - - # NOTE: inner solution not needed - # if ( - # # joint constraints failed to produce a complete solution - # None in joint_solution[0] - # # If the inner solution is more concrete than the joint solution, prefer the inner solution. - # or is_subtype(inner_ret_type, joint_ret_type) - # or ( # HACK to fix testLiteralAndGenericWithUnion - # isinstance(inner_ret_type, UnionType) - # and any(is_subtype(val, joint_ret_type) for val in inner_ret_type.items) - # ) - # ): - # use_joint = False - # use_outer = False - # use_inner = True - - if ( + if ( # determine which solution to take # joint constraints failed to produce a complete solution None in joint_solution[0] # If the outer solution is more concrete than the joint solution, prefer the outer solution. @@ -2194,14 +2158,12 @@ def infer_function_type_arguments( ) ): use_joint = False - use_outer = True - use_inner = False + else: + use_joint = True if use_joint: inferred_args = joint_solution[0] - elif use_inner: - inferred_args = inner_solution[0] - elif use_outer: + else: # If we cannot use the joint solution, fallback to outer_solution inferred_args = outer_solution[0] # Don't show errors after we have only used the outer context for inference. @@ -2238,8 +2200,6 @@ def infer_function_type_arguments( allow_polymorphic=False, ) inferred_args = inner_solution[0] - else: - raise RuntimeError("No solution found for function type arguments") else: # END NEW CODE pass diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index daa4db062c15..7fa34a398ea0 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -3384,7 +3384,7 @@ def f(x: T, y: S) -> Union[T, S]: ... def g(x: T, y: S) -> Union[T, S]: ... x = [f, g] -reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`6, y: S`7) -> Union[T`6, S`7]]" +reveal_type(x) # N: Revealed type is "builtins.list[def [T, S] (x: T`4, y: S`5) -> Union[T`4, S`5]]" [builtins fixtures/list.pyi] [case testTypeVariableClashErrorMessage] diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 012d53223295..0be9d918c69f 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2929,8 +2929,8 @@ def mix(fs: List[Callable[[S], T]]) -> Callable[[S], List[T]]: def id(__x: U) -> U: ... fs = [id, id, id] -reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`15) -> builtins.list[S`15]" -reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`17) -> builtins.list[S`17]" +reveal_type(mix(fs)) # N: Revealed type is "def [S] (S`11) -> builtins.list[S`11]" +reveal_type(mix([id, id, id])) # N: Revealed type is "def [S] (S`13) -> builtins.list[S`13]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericCurry] From 76f357154c0d0f6b6e6fb649b8da2aba5e3eac2a Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Wed, 9 Jul 2025 12:09:42 +0200 Subject: [PATCH 10/12] added testLiteralMappingContext --- test-data/unit/check-literal.test | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 3c9290b8dbbb..b30a56bb5fb6 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1206,6 +1206,13 @@ reveal_type(a) # N: Revealed type is "builtins.dict[builtins.str, builtins.int] [builtins fixtures/dict.pyi] [out] +[case testLiteralMappingContext] +from typing import Mapping, Literal + +x: Mapping[str, Literal["sum", "mean", "max", "min"]] = {"x": "sum"} + +[builtins fixtures/dict.pyi] + [case testLiteralInferredInOverloadContextBasic] from typing import Literal, overload From 55b3bac64230e2502a278922e6c5ef95b4c3aadb Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Wed, 9 Jul 2025 13:14:43 +0200 Subject: [PATCH 11/12] replace one hack with another --- mypy/checkexpr.py | 20 +++++++++++++++----- mypy/constraints.py | 2 ++ test-data/unit/check-generics.test | 10 +++++----- test-data/unit/check-varargs.test | 6 +++--- 4 files changed, 25 insertions(+), 13 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index e08b3ef12715..afd7ab359018 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2108,6 +2108,19 @@ def infer_function_type_arguments( formal_to_actual, context=self.argument_infer_context(), ) + # HACK: convert "Literal?" constraints to their non-literal versions. + inner_constraints = [ + Constraint( + c.original_type_var, + c.op, + ( + c.target.copy_modified(last_known_value=None) + if isinstance(c.target, Instance) + else c.target + ), + ) + for c in inner_constraints + ] # compute the outer solution outer_constraints = self.infer_constraints_from_context(callee_type, context) @@ -2131,9 +2144,10 @@ def infer_function_type_arguments( # NOTE: The order of constraints is important here! # solve(outer + inner) and solve(inner + outer) may yield different results. # we need to use outer first. + joint_constraints = outer_constraints + inner_constraints joint_solution = solve_constraints( callee_type.variables, - outer_constraints + inner_constraints, + joint_constraints, strict=self.chk.in_checked_function(), allow_polymorphic=False, ) @@ -2152,10 +2166,6 @@ def infer_function_type_arguments( None in joint_solution[0] # If the outer solution is more concrete than the joint solution, prefer the outer solution. or is_subtype(outer_ret_type, joint_ret_type) - or ( # HACK to fix testLiteralAndGenericWithUnion - isinstance(outer_ret_type, UnionType) - and any(is_subtype(val, joint_ret_type) for val in outer_ret_type.items) - ) ): use_joint = False else: diff --git a/mypy/constraints.py b/mypy/constraints.py index 9eeea3cb2c26..dee959539927 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -77,11 +77,13 @@ class Constraint: """ type_var: TypeVarId + original_type_var: TypeVarLikeType op = 0 # SUBTYPE_OF or SUPERTYPE_OF target: Type def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None: self.type_var = type_var.id + self.original_type_var = type_var self.op = op # TODO: should we add "assert not isinstance(target, UnpackType)"? # UnpackType is a synthetic type, and is never valid as a constraint target. diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 0be9d918c69f..ed1acbddc8d3 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2998,7 +2998,7 @@ def lift(f: F[T]) -> F[Optional[T]]: ... def g(x: T) -> T: return x -reveal_type(lift(g)) # N: Revealed type is "def [T] (Union[T`1, None]) -> Union[T`1, None]" +reveal_type(lift(g)) # N: Revealed type is "__main__.F[Union[T`-1, None]]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericSplitOrder] @@ -3198,11 +3198,11 @@ def dec(f: Callable[P, Callable[[T], S]]) -> Callable[Concatenate[T, P], S]: ... def id() -> Callable[[U], U]: ... def either(x: U) -> Callable[[U], U]: ... def pair(x: U) -> Callable[[V], Tuple[V, U]]: ... -reveal_type(dec(id)) # N: Revealed type is "def [T] (T`3) -> T`3" -reveal_type(dec(either)) # N: Revealed type is "def [T] (T`6, x: T`6) -> T`6" -reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`9, x: U`-1) -> tuple[T`9, U`-1]" +reveal_type(dec(id)) # N: Revealed type is "def (U`-1) -> U`-1" +reveal_type(dec(either)) # N: Revealed type is "def [T] (T`7, x: T`7) -> T`7" +reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`10, x: U`-1) -> tuple[T`10, U`-1]" # This is counter-intuitive but looks correct, dec matches itself only if P can be empty -reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`13, f: def () -> def (T`13) -> S`14) -> S`14" +reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`14, f: def () -> def (T`14) -> S`15) -> S`15" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericParamSpecVsParamSpec] diff --git a/test-data/unit/check-varargs.test b/test-data/unit/check-varargs.test index 680021a166f2..def03f5f3ec1 100644 --- a/test-data/unit/check-varargs.test +++ b/test-data/unit/check-varargs.test @@ -629,9 +629,9 @@ from typing import TypeVar T = TypeVar('T') def f(*args: T) -> T: ... -reveal_type(f(*(1, None))) # N: Revealed type is "Union[Literal[1]?, None]" -reveal_type(f(1, *(None, 1))) # N: Revealed type is "Union[Literal[1]?, None]" -reveal_type(f(1, *(1, None))) # N: Revealed type is "Union[Literal[1]?, None]" +reveal_type(f(*(1, None))) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(f(1, *(None, 1))) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(f(1, *(1, None))) # N: Revealed type is "Union[builtins.int, None]" [builtins fixtures/tuple.pyi] From d7a5d080a679e691ce0ae869b0960396cacbc7ff Mon Sep 17 00:00:00 2001 From: Randolf Scholz Date: Wed, 9 Jul 2025 13:18:30 +0200 Subject: [PATCH 12/12] fixed use get_proper_type before isinstance check --- mypy/checkexpr.py | 26 ++++++++++++++------------ 1 file changed, 14 insertions(+), 12 deletions(-) diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index afd7ab359018..1eac644b79f5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -2100,7 +2100,7 @@ def infer_function_type_arguments( if True: # NEW CODE # compute the inner constraints - inner_constraints = infer_constraints_for_callable( + _inner_constraints = infer_constraints_for_callable( callee_type, pass1_args, arg_kinds, @@ -2109,18 +2109,20 @@ def infer_function_type_arguments( context=self.argument_infer_context(), ) # HACK: convert "Literal?" constraints to their non-literal versions. - inner_constraints = [ - Constraint( - c.original_type_var, - c.op, - ( - c.target.copy_modified(last_known_value=None) - if isinstance(c.target, Instance) - else c.target - ), + inner_constraints: list[Constraint] = [] + for constraint in _inner_constraints: + target = get_proper_type(constraint.target) + inner_constraints.append( + Constraint( + constraint.original_type_var, + constraint.op, + ( + target.copy_modified(last_known_value=None) + if isinstance(target, Instance) + else target + ), + ) ) - for c in inner_constraints - ] # compute the outer solution outer_constraints = self.infer_constraints_from_context(callee_type, context)