Skip to content

Commit 0291ec9

Browse files
authored
Better support for variadic calls and indexing (#16131)
This improves support for two features that were supported but only partially: variadic calls, and variadic indexing. Some notes: * I did not add dedicated support for slicing of tuples with homogeneous variadic items (except for cases covered by TypeVarTuple support, i.e. those not involving splitting of variadic item). This is tricky and it is not clear what cases people actually want. I left a TODO for this. * I prohibit multiple variadic items in a call expression. Technically, we can support some situations involving these, but this is tricky, and prohibiting this would be in the "spirit" of the PEP, IMO. * I may have still missed some cases for the calls, since there are so many options. If you have ideas for additional test cases, please let me know. * It was necessary to fix overload ambiguity logic to make some tests pass. This goes beyond TypeVarTuple support, but I think this is a correct change.
1 parent d25d680 commit 0291ec9

File tree

8 files changed

+306
-57
lines changed

8 files changed

+306
-57
lines changed

mypy/checkexpr.py

Lines changed: 132 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1640,6 +1640,27 @@ def check_callable_call(
16401640
callee.type_object().name, abstract_attributes, context
16411641
)
16421642

1643+
var_arg = callee.var_arg()
1644+
if var_arg and isinstance(var_arg.typ, UnpackType):
1645+
# It is hard to support multiple variadic unpacks (except for old-style *args: int),
1646+
# fail gracefully to avoid crashes later.
1647+
seen_unpack = False
1648+
for arg, arg_kind in zip(args, arg_kinds):
1649+
if arg_kind != ARG_STAR:
1650+
continue
1651+
arg_type = get_proper_type(self.accept(arg))
1652+
if not isinstance(arg_type, TupleType) or any(
1653+
isinstance(t, UnpackType) for t in arg_type.items
1654+
):
1655+
if seen_unpack:
1656+
self.msg.fail(
1657+
"Passing multiple variadic unpacks in a call is not supported",
1658+
context,
1659+
code=codes.CALL_ARG,
1660+
)
1661+
return AnyType(TypeOfAny.from_error), callee
1662+
seen_unpack = True
1663+
16431664
formal_to_actual = map_actuals_to_formals(
16441665
arg_kinds,
16451666
arg_names,
@@ -2405,7 +2426,7 @@ def check_argument_types(
24052426
]
24062427
actual_kinds = [nodes.ARG_STAR] + [nodes.ARG_POS] * (len(actuals) - 1)
24072428

2408-
# TODO: can we really assert this? What if formal is just plain Unpack[Ts]?
2429+
# If we got here, the callee was previously inferred to have a suffix.
24092430
assert isinstance(orig_callee_arg_type, UnpackType)
24102431
assert isinstance(orig_callee_arg_type.type, ProperType) and isinstance(
24112432
orig_callee_arg_type.type, TupleType
@@ -2431,22 +2452,29 @@ def check_argument_types(
24312452
inner_unpack = unpacked_type.items[inner_unpack_index]
24322453
assert isinstance(inner_unpack, UnpackType)
24332454
inner_unpacked_type = get_proper_type(inner_unpack.type)
2434-
# We assume heterogenous tuples are desugared earlier
2435-
assert isinstance(inner_unpacked_type, Instance)
2436-
assert inner_unpacked_type.type.fullname == "builtins.tuple"
2437-
callee_arg_types = (
2438-
unpacked_type.items[:inner_unpack_index]
2439-
+ [inner_unpacked_type.args[0]]
2440-
* (len(actuals) - len(unpacked_type.items) + 1)
2441-
+ unpacked_type.items[inner_unpack_index + 1 :]
2442-
)
2443-
callee_arg_kinds = [ARG_POS] * len(actuals)
2455+
if isinstance(inner_unpacked_type, TypeVarTupleType):
2456+
# This branch mimics the expanded_tuple case above but for
2457+
# the case where caller passed a single * unpacked tuple argument.
2458+
callee_arg_types = unpacked_type.items
2459+
callee_arg_kinds = [
2460+
ARG_POS if i != inner_unpack_index else ARG_STAR
2461+
for i in range(len(unpacked_type.items))
2462+
]
2463+
else:
2464+
# We assume heterogeneous tuples are desugared earlier.
2465+
assert isinstance(inner_unpacked_type, Instance)
2466+
assert inner_unpacked_type.type.fullname == "builtins.tuple"
2467+
callee_arg_types = (
2468+
unpacked_type.items[:inner_unpack_index]
2469+
+ [inner_unpacked_type.args[0]]
2470+
* (len(actuals) - len(unpacked_type.items) + 1)
2471+
+ unpacked_type.items[inner_unpack_index + 1 :]
2472+
)
2473+
callee_arg_kinds = [ARG_POS] * len(actuals)
24442474
elif isinstance(unpacked_type, TypeVarTupleType):
24452475
callee_arg_types = [orig_callee_arg_type]
24462476
callee_arg_kinds = [ARG_STAR]
24472477
else:
2448-
# TODO: Any and Never can appear in Unpack (as a result of user error),
2449-
# fail gracefully here and elsewhere (and/or normalize them away).
24502478
assert isinstance(unpacked_type, Instance)
24512479
assert unpacked_type.type.fullname == "builtins.tuple"
24522480
callee_arg_types = [unpacked_type.args[0]] * len(actuals)
@@ -2458,8 +2486,10 @@ def check_argument_types(
24582486
assert len(actual_types) == len(actuals) == len(actual_kinds)
24592487

24602488
if len(callee_arg_types) != len(actual_types):
2461-
# TODO: Improve error message
2462-
self.chk.fail("Invalid number of arguments", context)
2489+
if len(actual_types) > len(callee_arg_types):
2490+
self.chk.msg.too_many_arguments(callee, context)
2491+
else:
2492+
self.chk.msg.too_few_arguments(callee, context, None)
24632493
continue
24642494

24652495
assert len(callee_arg_types) == len(actual_types)
@@ -2764,11 +2794,17 @@ def infer_overload_return_type(
27642794
)
27652795
is_match = not w.has_new_errors()
27662796
if is_match:
2767-
# Return early if possible; otherwise record info so we can
2797+
# Return early if possible; otherwise record info, so we can
27682798
# check for ambiguity due to 'Any' below.
27692799
if not args_contain_any:
27702800
return ret_type, infer_type
2771-
matches.append(typ)
2801+
p_infer_type = get_proper_type(infer_type)
2802+
if isinstance(p_infer_type, CallableType):
2803+
# Prefer inferred types if possible, this will avoid false triggers for
2804+
# Any-ambiguity caused by arguments with Any passed to generic overloads.
2805+
matches.append(p_infer_type)
2806+
else:
2807+
matches.append(typ)
27722808
return_types.append(ret_type)
27732809
inferred_types.append(infer_type)
27742810
type_maps.append(m)
@@ -4109,6 +4145,12 @@ def visit_index_with_type(
41094145
# Visit the index, just to make sure we have a type for it available
41104146
self.accept(index)
41114147

4148+
if isinstance(left_type, TupleType) and any(
4149+
isinstance(it, UnpackType) for it in left_type.items
4150+
):
4151+
# Normalize variadic tuples for consistency.
4152+
left_type = expand_type(left_type, {})
4153+
41124154
if isinstance(left_type, UnionType):
41134155
original_type = original_type or left_type
41144156
# Don't combine literal types, since we may need them for type narrowing.
@@ -4129,12 +4171,15 @@ def visit_index_with_type(
41294171
if ns is not None:
41304172
out = []
41314173
for n in ns:
4132-
if n < 0:
4133-
n += len(left_type.items)
4134-
if 0 <= n < len(left_type.items):
4135-
out.append(left_type.items[n])
4174+
item = self.visit_tuple_index_helper(left_type, n)
4175+
if item is not None:
4176+
out.append(item)
41364177
else:
41374178
self.chk.fail(message_registry.TUPLE_INDEX_OUT_OF_RANGE, e)
4179+
if any(isinstance(t, UnpackType) for t in left_type.items):
4180+
self.chk.note(
4181+
f"Variadic tuple can have length {left_type.length() - 1}", e
4182+
)
41384183
return AnyType(TypeOfAny.from_error)
41394184
return make_simplified_union(out)
41404185
else:
@@ -4158,6 +4203,46 @@ def visit_index_with_type(
41584203
e.method_type = method_type
41594204
return result
41604205

4206+
def visit_tuple_index_helper(self, left: TupleType, n: int) -> Type | None:
4207+
unpack_index = find_unpack_in_list(left.items)
4208+
if unpack_index is None:
4209+
if n < 0:
4210+
n += len(left.items)
4211+
if 0 <= n < len(left.items):
4212+
return left.items[n]
4213+
return None
4214+
unpack = left.items[unpack_index]
4215+
assert isinstance(unpack, UnpackType)
4216+
unpacked = get_proper_type(unpack.type)
4217+
if isinstance(unpacked, TypeVarTupleType):
4218+
# Usually we say that TypeVarTuple can't be split, be in case of
4219+
# indexing it seems benign to just return the fallback item, similar
4220+
# to what we do when indexing a regular TypeVar.
4221+
middle = unpacked.tuple_fallback.args[0]
4222+
else:
4223+
assert isinstance(unpacked, Instance)
4224+
assert unpacked.type.fullname == "builtins.tuple"
4225+
middle = unpacked.args[0]
4226+
if n >= 0:
4227+
if n < unpack_index:
4228+
return left.items[n]
4229+
if n >= len(left.items) - 1:
4230+
# For tuple[int, *tuple[str, ...], int] we allow either index 0 or 1,
4231+
# since variadic item may have zero items.
4232+
return None
4233+
return UnionType.make_union(
4234+
[middle] + left.items[unpack_index + 1 : n + 2], left.line, left.column
4235+
)
4236+
n += len(left.items)
4237+
if n <= 0:
4238+
# Similar to above, we only allow -1, and -2 for tuple[int, *tuple[str, ...], int]
4239+
return None
4240+
if n > unpack_index:
4241+
return left.items[n]
4242+
return UnionType.make_union(
4243+
left.items[n - 1 : unpack_index] + [middle], left.line, left.column
4244+
)
4245+
41614246
def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Type:
41624247
begin: Sequence[int | None] = [None]
41634248
end: Sequence[int | None] = [None]
@@ -4183,7 +4268,11 @@ def visit_tuple_slice_helper(self, left_type: TupleType, slic: SliceExpr) -> Typ
41834268

41844269
items: list[Type] = []
41854270
for b, e, s in itertools.product(begin, end, stride):
4186-
items.append(left_type.slice(b, e, s))
4271+
item = left_type.slice(b, e, s)
4272+
if item is None:
4273+
self.chk.fail(message_registry.AMBIGUOUS_SLICE_OF_VARIADIC_TUPLE, slic)
4274+
return AnyType(TypeOfAny.from_error)
4275+
items.append(item)
41874276
return make_simplified_union(items)
41884277

41894278
def try_getting_int_literals(self, index: Expression) -> list[int] | None:
@@ -4192,7 +4281,7 @@ def try_getting_int_literals(self, index: Expression) -> list[int] | None:
41924281
Otherwise, returns None.
41934282
41944283
Specifically, this function is guaranteed to return a list with
4195-
one or more ints if one one the following is true:
4284+
one or more ints if one the following is true:
41964285
41974286
1. 'expr' is a IntExpr or a UnaryExpr backed by an IntExpr
41984287
2. 'typ' is a LiteralType containing an int
@@ -4223,11 +4312,30 @@ def try_getting_int_literals(self, index: Expression) -> list[int] | None:
42234312
def nonliteral_tuple_index_helper(self, left_type: TupleType, index: Expression) -> Type:
42244313
self.check_method_call_by_name("__getitem__", left_type, [index], [ARG_POS], context=index)
42254314
# We could return the return type from above, but unions are often better than the join
4226-
union = make_simplified_union(left_type.items)
4315+
union = self.union_tuple_fallback_item(left_type)
42274316
if isinstance(index, SliceExpr):
42284317
return self.chk.named_generic_type("builtins.tuple", [union])
42294318
return union
42304319

4320+
def union_tuple_fallback_item(self, left_type: TupleType) -> Type:
4321+
# TODO: this duplicates logic in typeops.tuple_fallback().
4322+
items = []
4323+
for item in left_type.items:
4324+
if isinstance(item, UnpackType):
4325+
unpacked_type = get_proper_type(item.type)
4326+
if isinstance(unpacked_type, TypeVarTupleType):
4327+
unpacked_type = get_proper_type(unpacked_type.upper_bound)
4328+
if (
4329+
isinstance(unpacked_type, Instance)
4330+
and unpacked_type.type.fullname == "builtins.tuple"
4331+
):
4332+
items.append(unpacked_type.args[0])
4333+
else:
4334+
raise NotImplementedError
4335+
else:
4336+
items.append(item)
4337+
return make_simplified_union(items)
4338+
42314339
def visit_typeddict_index_expr(
42324340
self, td_type: TypedDictType, index: Expression, setitem: bool = False
42334341
) -> Type:

mypy/constraints.py

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -137,25 +137,38 @@ def infer_constraints_for_callable(
137137
unpack_type = callee.arg_types[i]
138138
assert isinstance(unpack_type, UnpackType)
139139

140-
# In this case we are binding all of the actuals to *args
140+
# In this case we are binding all the actuals to *args,
141141
# and we want a constraint that the typevar tuple being unpacked
142142
# is equal to a type list of all the actuals.
143143
actual_types = []
144+
145+
unpacked_type = get_proper_type(unpack_type.type)
146+
if isinstance(unpacked_type, TypeVarTupleType):
147+
tuple_instance = unpacked_type.tuple_fallback
148+
elif isinstance(unpacked_type, TupleType):
149+
tuple_instance = unpacked_type.partial_fallback
150+
else:
151+
assert False, "mypy bug: unhandled constraint inference case"
152+
144153
for actual in actuals:
145154
actual_arg_type = arg_types[actual]
146155
if actual_arg_type is None:
147156
continue
148157

149-
actual_types.append(
150-
mapper.expand_actual_type(
151-
actual_arg_type,
152-
arg_kinds[actual],
153-
callee.arg_names[i],
154-
callee.arg_kinds[i],
155-
)
158+
expanded_actual = mapper.expand_actual_type(
159+
actual_arg_type, arg_kinds[actual], callee.arg_names[i], callee.arg_kinds[i]
156160
)
157161

158-
unpacked_type = get_proper_type(unpack_type.type)
162+
if arg_kinds[actual] != ARG_STAR or isinstance(
163+
get_proper_type(actual_arg_type), TupleType
164+
):
165+
actual_types.append(expanded_actual)
166+
else:
167+
# If we are expanding an iterable inside * actual, append a homogeneous item instead
168+
actual_types.append(
169+
UnpackType(tuple_instance.copy_modified(args=[expanded_actual]))
170+
)
171+
159172
if isinstance(unpacked_type, TypeVarTupleType):
160173
constraints.append(
161174
Constraint(

mypy/erasetype.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,9 @@ def visit_instance(self, t: Instance) -> ProperType:
8282
# Valid erasure for *Ts is *tuple[Any, ...], not just Any.
8383
if isinstance(tv, TypeVarTupleType):
8484
args.append(
85-
tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])
85+
UnpackType(
86+
tv.tuple_fallback.copy_modified(args=[AnyType(TypeOfAny.special_form)])
87+
)
8688
)
8789
else:
8890
args.append(AnyType(TypeOfAny.special_form))

mypy/message_registry.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ def with_additional_msg(self, info: str) -> ErrorMessage:
8383
INCOMPATIBLE_TYPES_IN_CAPTURE: Final = ErrorMessage("Incompatible types in capture pattern")
8484
MUST_HAVE_NONE_RETURN_TYPE: Final = ErrorMessage('The return type of "{}" must be None')
8585
TUPLE_INDEX_OUT_OF_RANGE: Final = ErrorMessage("Tuple index out of range")
86+
AMBIGUOUS_SLICE_OF_VARIADIC_TUPLE: Final = ErrorMessage("Ambiguous slice of a variadic tuple")
8687
INVALID_SLICE_INDEX: Final = ErrorMessage("Slice index must be an integer, SupportsIndex or None")
8788
CANNOT_INFER_LAMBDA_TYPE: Final = ErrorMessage("Cannot infer type of lambda")
8889
CANNOT_ACCESS_INIT: Final = (

mypy/types.py

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2416,14 +2416,53 @@ def copy_modified(
24162416
items = self.items
24172417
return TupleType(items, fallback, self.line, self.column)
24182418

2419-
def slice(self, begin: int | None, end: int | None, stride: int | None) -> TupleType:
2420-
return TupleType(
2421-
self.items[begin:end:stride],
2422-
self.partial_fallback,
2423-
self.line,
2424-
self.column,
2425-
self.implicit,
2426-
)
2419+
def slice(self, begin: int | None, end: int | None, stride: int | None) -> TupleType | None:
2420+
if any(isinstance(t, UnpackType) for t in self.items):
2421+
total = len(self.items)
2422+
unpack_index = find_unpack_in_list(self.items)
2423+
assert unpack_index is not None
2424+
if begin is None and end is None:
2425+
# We special-case this to support reversing variadic tuples.
2426+
# General support for slicing is tricky, so we handle only simple cases.
2427+
if stride == -1:
2428+
slice_items = self.items[::-1]
2429+
elif stride is None or stride == 1:
2430+
slice_items = self.items
2431+
else:
2432+
return None
2433+
elif (begin is None or unpack_index >= begin >= 0) and (
2434+
end is not None and unpack_index >= end >= 0
2435+
):
2436+
# Start and end are in the prefix, everything works in this case.
2437+
slice_items = self.items[begin:end:stride]
2438+
elif (begin is not None and unpack_index - total < begin < 0) and (
2439+
end is None or unpack_index - total < end < 0
2440+
):
2441+
# Start and end are in the suffix, everything works in this case.
2442+
slice_items = self.items[begin:end:stride]
2443+
elif (begin is None or unpack_index >= begin >= 0) and (
2444+
end is None or unpack_index - total < end < 0
2445+
):
2446+
# Start in the prefix, end in the suffix, we can support only trivial strides.
2447+
if stride is None or stride == 1:
2448+
slice_items = self.items[begin:end:stride]
2449+
else:
2450+
return None
2451+
elif (begin is not None and unpack_index - total < begin < 0) and (
2452+
end is not None and unpack_index >= end >= 0
2453+
):
2454+
# Start in the suffix, end in the prefix, we can support only trivial strides.
2455+
if stride is None or stride == -1:
2456+
slice_items = self.items[begin:end:stride]
2457+
else:
2458+
return None
2459+
else:
2460+
# TODO: there some additional cases we can support for homogeneous variadic
2461+
# items, we can "eat away" finite number of items.
2462+
return None
2463+
else:
2464+
slice_items = self.items[begin:end:stride]
2465+
return TupleType(slice_items, self.partial_fallback, self.line, self.column, self.implicit)
24272466

24282467

24292468
class TypedDictType(ProperType):

test-data/unit/check-overloading.test

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6501,8 +6501,7 @@ eggs = lambda: 'eggs'
65016501
reveal_type(func(eggs)) # N: Revealed type is "def (builtins.str) -> builtins.str"
65026502

65036503
spam: Callable[..., str] = lambda x, y: 'baz'
6504-
reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> Any"
6505-
6504+
reveal_type(func(spam)) # N: Revealed type is "def (*Any, **Any) -> builtins.str"
65066505
[builtins fixtures/paramspec.pyi]
65076506

65086507
[case testGenericOverloadOverlapWithType]

test-data/unit/check-tuples.test

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,7 +1678,6 @@ def zip(*i: Iterable[Any]) -> Iterator[Tuple[Any, ...]]: ...
16781678
def zip(i): ...
16791679

16801680
def g(t: Tuple):
1681-
# Ideally, we'd infer that these are iterators of tuples
1682-
reveal_type(zip(*t)) # N: Revealed type is "typing.Iterator[Any]"
1683-
reveal_type(zip(t)) # N: Revealed type is "typing.Iterator[Any]"
1681+
reveal_type(zip(*t)) # N: Revealed type is "typing.Iterator[builtins.tuple[Any, ...]]"
1682+
reveal_type(zip(t)) # N: Revealed type is "typing.Iterator[Tuple[Any]]"
16841683
[builtins fixtures/tuple.pyi]

0 commit comments

Comments
 (0)