diff --git a/mypy/checkexpr.py b/mypy/checkexpr.py index 8223ccfe4ca0..1eac644b79f5 100644 --- a/mypy/checkexpr.py +++ b/mypy/checkexpr.py @@ -18,6 +18,12 @@ from mypy.checker_shared import ExpressionCheckerSharedApi from mypy.checkmember import analyze_member_access, has_operator from mypy.checkstrformat import StringFormatterChecker +from mypy.constraints import ( + SUBTYPE_OF, + Constraint, + infer_constraints, + infer_constraints_for_callable, +) from mypy.erasetype import erase_type, remove_instance_last_known_values, replace_meta_vars from mypy.errors import ErrorWatcher, report_internal_error from mypy.expandtype import ( @@ -26,7 +32,7 @@ freshen_all_functions_type_vars, freshen_function_type_vars, ) -from mypy.infer import ArgumentInferContext, infer_function_type_arguments, infer_type_arguments +from mypy.infer import ArgumentInferContext, infer_function_type_arguments from mypy.literals import literal from mypy.maptype import map_instance_to_supertype from mypy.meet import is_overlapping_types, narrow_declared_type @@ -110,6 +116,7 @@ Plugin, ) from mypy.semanal_enum import ENUM_BASES +from mypy.solve import solve_constraints from mypy.state import state from mypy.subtypes import ( find_member, @@ -191,12 +198,7 @@ is_named_instance, split_with_prefix_and_suffix, ) -from mypy.types_utils import ( - is_generic_instance, - is_overlapping_none, - is_self_type_like, - remove_optional, -) +from mypy.types_utils import is_generic_instance, is_self_type_like, remove_optional from mypy.typestate import type_state from mypy.typevars import fill_typevars from mypy.util import split_module_names @@ -1774,18 +1776,6 @@ def check_callable_call( isinstance(v, (ParamSpecType, TypeVarTupleType)) for v in callee.variables ) callee = freshen_function_type_vars(callee) - callee = self.infer_function_type_arguments_using_context(callee, context) - if need_refresh: - # Argument kinds etc. may have changed due to - # ParamSpec or TypeVarTuple variables being replaced with an arbitrary - # number of arguments; recalculate actual-to-formal map - formal_to_actual = map_actuals_to_formals( - arg_kinds, - arg_names, - callee.arg_kinds, - callee.arg_names, - lambda i: self.accept(args[i]), - ) callee = self.infer_function_type_arguments( callee, args, arg_kinds, arg_names, formal_to_actual, need_refresh, context ) @@ -2000,9 +1990,9 @@ def infer_arg_types_in_context( assert all(tp is not None for tp in res) return cast(list[Type], res) - def infer_function_type_arguments_using_context( - self, callable: CallableType, error_context: Context - ) -> CallableType: + def infer_constraints_from_context( + self, callee: CallableType, error_context: Context + ) -> list[Constraint]: """Unify callable return type to type context to infer type vars. For example, if the return type is set[t] where 't' is a type variable @@ -2011,23 +2001,23 @@ def infer_function_type_arguments_using_context( """ ctx = self.type_context[-1] if not ctx: - return callable + return [] # The return type may have references to type metavariables that # we are inferring right now. We must consider them as indeterminate # and they are not potential results; thus we replace them with the # special ErasedType type. On the other hand, class type variables are # valid results. - erased_ctx = replace_meta_vars(ctx, ErasedType()) - ret_type = callable.ret_type - if is_overlapping_none(ret_type) and is_overlapping_none(ctx): - # If both the context and the return type are optional, unwrap the optional, - # since in 99% cases this is what a user expects. In other words, we replace - # Optional[T] <: Optional[int] - # with - # T <: int - # while the former would infer T <: Optional[int]. - ret_type = remove_optional(ret_type) - erased_ctx = remove_optional(erased_ctx) + erased_ctx = get_proper_type(replace_meta_vars(ctx, ErasedType())) + proper_ret = get_proper_type(callee.ret_type) + if isinstance(proper_ret, UnionType) and isinstance(erased_ctx, UnionType): + # If both the context and the return type are unions, we simplify shared items + # e.g. T | None <: int | None => T <: int + # since the former would infer T <: int | None. + # whereas the latter would infer the more precise T <: int. + new_ret = [val for val in proper_ret.items if val not in erased_ctx.items] + new_ctx = [val for val in erased_ctx.items if val not in proper_ret.items] + proper_ret = make_simplified_union(new_ret) + erased_ctx = make_simplified_union(new_ctx) # # TODO: Instead of this hack and the one below, we need to use outer and # inner contexts at the same time. This is however not easy because of two @@ -2038,7 +2028,6 @@ def infer_function_type_arguments_using_context( # variables in an expression are inferred at the same time. # (And this is hard, also we need to be careful with lambdas that require # two passes.) - proper_ret = get_proper_type(ret_type) if ( isinstance(proper_ret, TypeVarType) or isinstance(proper_ret, UnionType) @@ -2068,22 +2057,9 @@ def infer_function_type_arguments_using_context( # TODO: we may want to add similar exception if all arguments are lambdas, since # in this case external context is almost everything we have. if not is_generic_instance(ctx) and not is_literal_type_like(ctx): - return callable.copy_modified() - args = infer_type_arguments( - callable.variables, ret_type, erased_ctx, skip_unsatisfied=True - ) - # Only substitute non-Uninhabited and non-erased types. - new_args: list[Type | None] = [] - for arg in args: - if has_uninhabited_component(arg) or has_erased_component(arg): - new_args.append(None) - else: - new_args.append(arg) - # Don't show errors after we have only used the outer context for inference. - # We will use argument context to infer more variables. - return self.apply_generic_arguments( - callable, new_args, error_context, skip_unsatisfied=True - ) + return [] + constraints = infer_constraints(proper_ret, erased_ctx, SUBTYPE_OF) + return constraints def infer_function_type_arguments( self, @@ -2122,15 +2098,122 @@ def infer_function_type_arguments( else: pass1_args.append(arg) - inferred_args, _ = infer_function_type_arguments( - callee_type, - pass1_args, - arg_kinds, - arg_names, - formal_to_actual, - context=self.argument_infer_context(), - strict=self.chk.in_checked_function(), - ) + if True: # NEW CODE + # compute the inner constraints + _inner_constraints = infer_constraints_for_callable( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + ) + # HACK: convert "Literal?" constraints to their non-literal versions. + inner_constraints: list[Constraint] = [] + for constraint in _inner_constraints: + target = get_proper_type(constraint.target) + inner_constraints.append( + Constraint( + constraint.original_type_var, + constraint.op, + ( + target.copy_modified(last_known_value=None) + if isinstance(target, Instance) + else target + ), + ) + ) + + # compute the outer solution + outer_constraints = self.infer_constraints_from_context(callee_type, context) + outer_solution = solve_constraints( + callee_type.variables, + outer_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + outer_args = [ + None if has_uninhabited_component(arg) or has_erased_component(arg) else arg + for arg in outer_solution[0] + ] + outer_solution = (outer_args, outer_solution[1]) + outer_callee = self.apply_generic_arguments( + callee_type, outer_solution[0], context, skip_unsatisfied=True + ) + outer_ret_type = get_proper_type(outer_callee.ret_type) + + # compute the joint solution using both inner and outer constraints. + # NOTE: The order of constraints is important here! + # solve(outer + inner) and solve(inner + outer) may yield different results. + # we need to use outer first. + joint_constraints = outer_constraints + inner_constraints + joint_solution = solve_constraints( + callee_type.variables, + joint_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + joint_args = [ + None if has_uninhabited_component(arg) or has_erased_component(arg) else arg + for arg in joint_solution[0] + ] + joint_solution = (joint_args, joint_solution[1]) + joint_callee = self.apply_generic_arguments( + callee_type, joint_solution[0], context, skip_unsatisfied=True + ) + joint_ret_type = get_proper_type(joint_callee.ret_type) + + if ( # determine which solution to take + # joint constraints failed to produce a complete solution + None in joint_solution[0] + # If the outer solution is more concrete than the joint solution, prefer the outer solution. + or is_subtype(outer_ret_type, joint_ret_type) + ): + use_joint = False + else: + use_joint = True + + if use_joint: + inferred_args = joint_solution[0] + else: + # If we cannot use the joint solution, fallback to outer_solution + inferred_args = outer_solution[0] + # Don't show errors after we have only used the outer context for inference. + # We will use argument context to infer more variables. + callee_type = self.apply_generic_arguments( + callee_type, inferred_args, context, skip_unsatisfied=True + ) + if need_refresh: + # Argument kinds etc. may have changed due to + # ParamSpec or TypeVarTuple variables being replaced with an arbitrary + # number of arguments; recalculate actual-to-formal map + formal_to_actual = map_actuals_to_formals( + arg_kinds, + arg_names, + callee_type.arg_kinds, + callee_type.arg_names, + lambda i: self.accept(args[i]), + ) + + # ??? QUESTION: Do we need to recompute arg_types and pass1_args here??? + # recompute and apply inner solution. + inner_constraints = infer_constraints_for_callable( + callee_type, + pass1_args, + arg_kinds, + arg_names, + formal_to_actual, + context=self.argument_infer_context(), + ) + inner_solution = solve_constraints( + callee_type.variables, + inner_constraints, + strict=self.chk.in_checked_function(), + allow_polymorphic=False, + ) + inferred_args = inner_solution[0] + else: # END NEW CODE + pass if 2 in arg_pass_nums: # Second pass of type inference. diff --git a/mypy/constraints.py b/mypy/constraints.py index 9eeea3cb2c26..dee959539927 100644 --- a/mypy/constraints.py +++ b/mypy/constraints.py @@ -77,11 +77,13 @@ class Constraint: """ type_var: TypeVarId + original_type_var: TypeVarLikeType op = 0 # SUBTYPE_OF or SUPERTYPE_OF target: Type def __init__(self, type_var: TypeVarLikeType, op: int, target: Type) -> None: self.type_var = type_var.id + self.original_type_var = type_var self.op = op # TODO: should we add "assert not isinstance(target, UnpackType)"? # UnpackType is a synthetic type, and is never valid as a constraint target. diff --git a/test-data/unit/check-expressions.test b/test-data/unit/check-expressions.test index 33271a3cc04c..fc9a11f86c82 100644 --- a/test-data/unit/check-expressions.test +++ b/test-data/unit/check-expressions.test @@ -308,6 +308,24 @@ main:5: error: Unsupported operand types for ^ ("A" and "A") main:6: error: Unsupported operand types for << ("A" and "B") main:7: error: Unsupported operand types for >> ("A" and "A") +[case testBinaryOperatorContext] +from typing import TypeVar, Generic, Iterable, Iterator, Union + +T = TypeVar("T") +S = TypeVar("S") + +class Vec(Generic[T]): + def __init__(self, iterable: Iterable[T], /) -> None: ... + def __iter__(self) -> Iterator[T]: yield from [] + def __add__(self, value: "Vec[S]", /) -> "Vec[Union[S, T]]": return Vec([]) + +def fmt(arg: Iterable[Union[int, str]]) -> None: ... + +l1: Vec[int] = Vec([1]) +l2: Vec[int] = Vec([1]) +fmt(l1 + l2) +[builtins fixtures/list.pyi] + [case testBooleanAndOr] a: A b: bool diff --git a/test-data/unit/check-generics.test b/test-data/unit/check-generics.test index 0be9d918c69f..ed1acbddc8d3 100644 --- a/test-data/unit/check-generics.test +++ b/test-data/unit/check-generics.test @@ -2998,7 +2998,7 @@ def lift(f: F[T]) -> F[Optional[T]]: ... def g(x: T) -> T: return x -reveal_type(lift(g)) # N: Revealed type is "def [T] (Union[T`1, None]) -> Union[T`1, None]" +reveal_type(lift(g)) # N: Revealed type is "__main__.F[Union[T`-1, None]]" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericSplitOrder] @@ -3198,11 +3198,11 @@ def dec(f: Callable[P, Callable[[T], S]]) -> Callable[Concatenate[T, P], S]: ... def id() -> Callable[[U], U]: ... def either(x: U) -> Callable[[U], U]: ... def pair(x: U) -> Callable[[V], Tuple[V, U]]: ... -reveal_type(dec(id)) # N: Revealed type is "def [T] (T`3) -> T`3" -reveal_type(dec(either)) # N: Revealed type is "def [T] (T`6, x: T`6) -> T`6" -reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`9, x: U`-1) -> tuple[T`9, U`-1]" +reveal_type(dec(id)) # N: Revealed type is "def (U`-1) -> U`-1" +reveal_type(dec(either)) # N: Revealed type is "def [T] (T`7, x: T`7) -> T`7" +reveal_type(dec(pair)) # N: Revealed type is "def [T, U] (T`10, x: U`-1) -> tuple[T`10, U`-1]" # This is counter-intuitive but looks correct, dec matches itself only if P can be empty -reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`13, f: def () -> def (T`13) -> S`14) -> S`14" +reveal_type(dec(dec)) # N: Revealed type is "def [T, S] (T`14, f: def () -> def (T`14) -> S`15) -> S`15" [builtins fixtures/list.pyi] [case testInferenceAgainstGenericParamSpecVsParamSpec] diff --git a/test-data/unit/check-literal.test b/test-data/unit/check-literal.test index 3c9290b8dbbb..b30a56bb5fb6 100644 --- a/test-data/unit/check-literal.test +++ b/test-data/unit/check-literal.test @@ -1206,6 +1206,13 @@ reveal_type(a) # N: Revealed type is "builtins.dict[builtins.str, builtins.int] [builtins fixtures/dict.pyi] [out] +[case testLiteralMappingContext] +from typing import Mapping, Literal + +x: Mapping[str, Literal["sum", "mean", "max", "min"]] = {"x": "sum"} + +[builtins fixtures/dict.pyi] + [case testLiteralInferredInOverloadContextBasic] from typing import Literal, overload diff --git a/test-data/unit/check-recursive-types.test b/test-data/unit/check-recursive-types.test index 7ed5ea53c27e..2b22bef99433 100644 --- a/test-data/unit/check-recursive-types.test +++ b/test-data/unit/check-recursive-types.test @@ -285,21 +285,33 @@ if isinstance(b[0], Sequence): [case testRecursiveAliasWithRecursiveInstance] from typing import Sequence, Union, TypeVar -class A: ... T = TypeVar("T") Nested = Sequence[Union[T, Nested[T]]] +def join(a: T, b: T) -> T: ... + +class A: ... class B(Sequence[B]): ... a: Nested[A] aa: Nested[A] b: B + a = b # OK +reveal_type(a) # N: Revealed type is "__main__.B" + a = [[b]] # OK +reveal_type(a) # N: Revealed type is "builtins.list[builtins.list[__main__.B]]" + b = aa # E: Incompatible types in assignment (expression has type "Nested[A]", variable has type "B") +reveal_type(b) # N: Revealed type is "__main__.B" + +reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[typing.Sequence[__main__.B]]" +reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[typing.Sequence[__main__.B]]" + +def test(a: Nested[A], b: B) -> None: + reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" + reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" -def join(a: T, b: T) -> T: ... -reveal_type(join(a, b)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" -reveal_type(join(b, a)) # N: Revealed type is "typing.Sequence[Union[__main__.A, typing.Sequence[Union[__main__.A, ...]]]]" [builtins fixtures/isinstancelist.pyi] [case testRecursiveAliasWithRecursiveInstanceInference] diff --git a/test-data/unit/check-varargs.test b/test-data/unit/check-varargs.test index 680021a166f2..def03f5f3ec1 100644 --- a/test-data/unit/check-varargs.test +++ b/test-data/unit/check-varargs.test @@ -629,9 +629,9 @@ from typing import TypeVar T = TypeVar('T') def f(*args: T) -> T: ... -reveal_type(f(*(1, None))) # N: Revealed type is "Union[Literal[1]?, None]" -reveal_type(f(1, *(None, 1))) # N: Revealed type is "Union[Literal[1]?, None]" -reveal_type(f(1, *(1, None))) # N: Revealed type is "Union[Literal[1]?, None]" +reveal_type(f(*(1, None))) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(f(1, *(None, 1))) # N: Revealed type is "Union[builtins.int, None]" +reveal_type(f(1, *(1, None))) # N: Revealed type is "Union[builtins.int, None]" [builtins fixtures/tuple.pyi]