Skip to content

Commit 171e6f8

Browse files
authored
stubgen: Fix call-based namedtuple omitted from class bases (#14680)
Fixes #9901 Fixes #13662 Fix inheriting from a call-based `collections.namedtuple` / `typing.NamedTuple` definition that was omitted from the generated stub. This automatically adds support for the call-based `NamedTuple` in general not only in class bases (Closes #13788). <details> <summary>An example before and after</summary> Input: ```python import collections import typing from collections import namedtuple from typing import NamedTuple CollectionsCall = namedtuple("CollectionsCall", ["x", "y"]) class CollectionsClass(namedtuple("CollectionsClass", ["x", "y"])): def f(self, a): pass class CollectionsDotClass(collections.namedtuple("CollectionsClass", ["x", "y"])): def f(self, a): pass TypingCall = NamedTuple("TypingCall", [("x", int | None), ("y", int)]) class TypingClass(NamedTuple): x: int | None y: str def f(self, a): pass class TypingClassWeird(NamedTuple("TypingClassWeird", [("x", int | None), ("y", str)])): z: float | None def f(self, a): pass class TypingDotClassWeird(typing.NamedTuple("TypingClassWeird", [("x", int | None), ("y", str)])): def f(self, a): pass ``` Output diff (before and after): ```diff diff --git a/before.pyi b/after.pyi index c88530e2c..95ef843b4 100644 --- a/before.pyi +++ b/after.pyi @@ -1,26 +1,29 @@ +import typing from _typeshed import Incomplete from typing_extensions import NamedTuple class CollectionsCall(NamedTuple): x: Incomplete y: Incomplete -class CollectionsClass: +class CollectionsClass(NamedTuple('CollectionsClass', [('x', Incomplete), ('y', Incomplete)])): def f(self, a) -> None: ... -class CollectionsDotClass: +class CollectionsDotClass(NamedTuple('CollectionsClass', [('x', Incomplete), ('y', Incomplete)])): def f(self, a) -> None: ... -TypingCall: Incomplete +class TypingCall(NamedTuple): + x: int | None + y: int class TypingClass(NamedTuple): x: int | None y: str def f(self, a) -> None: ... -class TypingClassWeird: +class TypingClassWeird(NamedTuple('TypingClassWeird', [('x', int | None), ('y', str)])): z: float | None def f(self, a) -> None: ... -class TypingDotClassWeird: +class TypingDotClassWeird(typing.NamedTuple('TypingClassWeird', [('x', int | None), ('y', str)])): def f(self, a) -> None: ... ``` </details>
1 parent d710fdd commit 171e6f8

File tree

2 files changed

+168
-23
lines changed

2 files changed

+168
-23
lines changed

mypy/stubgen.py

Lines changed: 99 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@
104104
TupleExpr,
105105
TypeInfo,
106106
UnaryExpr,
107-
is_StrExpr_list,
108107
)
109108
from mypy.options import Options as MypyOptions
110109
from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures
@@ -129,6 +128,7 @@
129128
from mypy.types import (
130129
OVERLOAD_NAMES,
131130
TPDICT_NAMES,
131+
TYPED_NAMEDTUPLE_NAMES,
132132
AnyType,
133133
CallableType,
134134
Instance,
@@ -400,10 +400,12 @@ def visit_str_expr(self, node: StrExpr) -> str:
400400
def visit_index_expr(self, node: IndexExpr) -> str:
401401
base = node.base.accept(self)
402402
index = node.index.accept(self)
403+
if len(index) > 2 and index.startswith("(") and index.endswith(")"):
404+
index = index[1:-1]
403405
return f"{base}[{index}]"
404406

405407
def visit_tuple_expr(self, node: TupleExpr) -> str:
406-
return ", ".join(n.accept(self) for n in node.items)
408+
return f"({', '.join(n.accept(self) for n in node.items)})"
407409

408410
def visit_list_expr(self, node: ListExpr) -> str:
409411
return f"[{', '.join(n.accept(self) for n in node.items)}]"
@@ -1010,6 +1012,37 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
10101012
elif isinstance(base, IndexExpr):
10111013
p = AliasPrinter(self)
10121014
base_types.append(base.accept(p))
1015+
elif isinstance(base, CallExpr):
1016+
# namedtuple(typename, fields), NamedTuple(typename, fields) calls can
1017+
# be used as a base class. The first argument is a string literal that
1018+
# is usually the same as the class name.
1019+
#
1020+
# Note:
1021+
# A call-based named tuple as a base class cannot be safely converted to
1022+
# a class-based NamedTuple definition because class attributes defined
1023+
# in the body of the class inheriting from the named tuple call are not
1024+
# namedtuple fields at runtime.
1025+
if self.is_namedtuple(base):
1026+
nt_fields = self._get_namedtuple_fields(base)
1027+
assert isinstance(base.args[0], StrExpr)
1028+
typename = base.args[0].value
1029+
if nt_fields is not None:
1030+
# A valid namedtuple() call, use NamedTuple() instead with
1031+
# Incomplete as field types
1032+
fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields)
1033+
base_types.append(f"NamedTuple({typename!r}, [{fields_str}])")
1034+
self.add_typing_import("NamedTuple")
1035+
else:
1036+
# Invalid namedtuple() call, cannot determine fields
1037+
base_types.append("Incomplete")
1038+
elif self.is_typed_namedtuple(base):
1039+
p = AliasPrinter(self)
1040+
base_types.append(base.accept(p))
1041+
else:
1042+
# At this point, we don't know what the base class is, so we
1043+
# just use Incomplete as the base class.
1044+
base_types.append("Incomplete")
1045+
self.import_tracker.require_name("Incomplete")
10131046
return base_types
10141047

10151048
def visit_block(self, o: Block) -> None:
@@ -1022,8 +1055,11 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
10221055
foundl = []
10231056

10241057
for lvalue in o.lvalues:
1025-
if isinstance(lvalue, NameExpr) and self.is_namedtuple(o.rvalue):
1026-
assert isinstance(o.rvalue, CallExpr)
1058+
if (
1059+
isinstance(lvalue, NameExpr)
1060+
and isinstance(o.rvalue, CallExpr)
1061+
and (self.is_namedtuple(o.rvalue) or self.is_typed_namedtuple(o.rvalue))
1062+
):
10271063
self.process_namedtuple(lvalue, o.rvalue)
10281064
continue
10291065
if (
@@ -1069,37 +1105,79 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
10691105
if all(foundl):
10701106
self._state = VAR
10711107

1072-
def is_namedtuple(self, expr: Expression) -> bool:
1073-
if not isinstance(expr, CallExpr):
1074-
return False
1108+
def is_namedtuple(self, expr: CallExpr) -> bool:
10751109
callee = expr.callee
1076-
return (isinstance(callee, NameExpr) and callee.name.endswith("namedtuple")) or (
1077-
isinstance(callee, MemberExpr) and callee.name == "namedtuple"
1110+
return (
1111+
isinstance(callee, NameExpr)
1112+
and (self.refers_to_fullname(callee.name, "collections.namedtuple"))
1113+
) or (
1114+
isinstance(callee, MemberExpr)
1115+
and isinstance(callee.expr, NameExpr)
1116+
and f"{callee.expr.name}.{callee.name}" == "collections.namedtuple"
10781117
)
10791118

1119+
def is_typed_namedtuple(self, expr: CallExpr) -> bool:
1120+
callee = expr.callee
1121+
return (
1122+
isinstance(callee, NameExpr)
1123+
and self.refers_to_fullname(callee.name, TYPED_NAMEDTUPLE_NAMES)
1124+
) or (
1125+
isinstance(callee, MemberExpr)
1126+
and isinstance(callee.expr, NameExpr)
1127+
and f"{callee.expr.name}.{callee.name}" in TYPED_NAMEDTUPLE_NAMES
1128+
)
1129+
1130+
def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None:
1131+
if self.is_namedtuple(call):
1132+
fields_arg = call.args[1]
1133+
if isinstance(fields_arg, StrExpr):
1134+
field_names = fields_arg.value.replace(",", " ").split()
1135+
elif isinstance(fields_arg, (ListExpr, TupleExpr)):
1136+
field_names = []
1137+
for field in fields_arg.items:
1138+
if not isinstance(field, StrExpr):
1139+
return None
1140+
field_names.append(field.value)
1141+
else:
1142+
return None # Invalid namedtuple fields type
1143+
if field_names:
1144+
self.import_tracker.require_name("Incomplete")
1145+
return [(field_name, "Incomplete") for field_name in field_names]
1146+
elif self.is_typed_namedtuple(call):
1147+
fields_arg = call.args[1]
1148+
if not isinstance(fields_arg, (ListExpr, TupleExpr)):
1149+
return None
1150+
fields: list[tuple[str, str]] = []
1151+
b = AliasPrinter(self)
1152+
for field in fields_arg.items:
1153+
if not (isinstance(field, TupleExpr) and len(field.items) == 2):
1154+
return None
1155+
field_name, field_type = field.items
1156+
if not isinstance(field_name, StrExpr):
1157+
return None
1158+
fields.append((field_name.value, field_type.accept(b)))
1159+
return fields
1160+
else:
1161+
return None # Not a named tuple call
1162+
10801163
def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None:
10811164
if self._state != EMPTY:
10821165
self.add("\n")
1083-
if isinstance(rvalue.args[1], StrExpr):
1084-
items = rvalue.args[1].value.replace(",", " ").split()
1085-
elif isinstance(rvalue.args[1], (ListExpr, TupleExpr)):
1086-
list_items = rvalue.args[1].items
1087-
assert is_StrExpr_list(list_items)
1088-
items = [item.value for item in list_items]
1089-
else:
1166+
fields = self._get_namedtuple_fields(rvalue)
1167+
if fields is None:
10901168
self.add(f"{self._indent}{lvalue.name}: Incomplete")
10911169
self.import_tracker.require_name("Incomplete")
10921170
return
10931171
self.import_tracker.require_name("NamedTuple")
10941172
self.add(f"{self._indent}class {lvalue.name}(NamedTuple):")
1095-
if not items:
1173+
if len(fields) == 0:
10961174
self.add(" ...\n")
1175+
self._state = EMPTY_CLASS
10971176
else:
1098-
self.import_tracker.require_name("Incomplete")
10991177
self.add("\n")
1100-
for item in items:
1101-
self.add(f"{self._indent} {item}: Incomplete\n")
1102-
self._state = CLASS
1178+
for f_name, f_type in fields:
1179+
self.add(f"{self._indent} {f_name}: {f_type}\n")
1180+
self._state = CLASS
11031181

11041182
def is_typeddict(self, expr: CallExpr) -> bool:
11051183
callee = expr.callee

test-data/unit/stubgen.test

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,9 @@ class A:
641641
def _bar(cls) -> None: ...
642642

643643
[case testNamedtuple]
644-
import collections, x
644+
import collections, typing, x
645645
X = collections.namedtuple('X', ['a', 'b'])
646+
Y = typing.NamedTuple('Y', [('a', int), ('b', str)])
646647
[out]
647648
from _typeshed import Incomplete
648649
from typing import NamedTuple
@@ -651,14 +652,21 @@ class X(NamedTuple):
651652
a: Incomplete
652653
b: Incomplete
653654

655+
class Y(NamedTuple):
656+
a: int
657+
b: str
658+
654659
[case testEmptyNamedtuple]
655-
import collections
660+
import collections, typing
656661
X = collections.namedtuple('X', [])
662+
Y = typing.NamedTuple('Y', [])
657663
[out]
658664
from typing import NamedTuple
659665

660666
class X(NamedTuple): ...
661667

668+
class Y(NamedTuple): ...
669+
662670
[case testNamedtupleAltSyntax]
663671
from collections import namedtuple, xx
664672
X = namedtuple('X', 'a b')
@@ -697,8 +705,10 @@ class X(NamedTuple):
697705

698706
[case testNamedtupleWithUnderscore]
699707
from collections import namedtuple as _namedtuple
708+
from typing import NamedTuple as _NamedTuple
700709
def f(): ...
701710
X = _namedtuple('X', 'a b')
711+
Y = _NamedTuple('Y', [('a', int), ('b', str)])
702712
def g(): ...
703713
[out]
704714
from _typeshed import Incomplete
@@ -710,6 +720,10 @@ class X(NamedTuple):
710720
a: Incomplete
711721
b: Incomplete
712722

723+
class Y(NamedTuple):
724+
a: int
725+
b: str
726+
713727
def g() -> None: ...
714728

715729
[case testNamedtupleBaseClass]
@@ -728,10 +742,14 @@ class Y(_X): ...
728742

729743
[case testNamedtupleAltSyntaxFieldsTuples]
730744
from collections import namedtuple, xx
745+
from typing import NamedTuple
731746
X = namedtuple('X', ())
732747
Y = namedtuple('Y', ('a',))
733748
Z = namedtuple('Z', ('a', 'b', 'c', 'd', 'e'))
734749
xx
750+
R = NamedTuple('R', ())
751+
S = NamedTuple('S', (('a', int),))
752+
T = NamedTuple('T', (('a', int), ('b', str)))
735753
[out]
736754
from _typeshed import Incomplete
737755
from typing import NamedTuple
@@ -748,13 +766,62 @@ class Z(NamedTuple):
748766
d: Incomplete
749767
e: Incomplete
750768

769+
class R(NamedTuple): ...
770+
771+
class S(NamedTuple):
772+
a: int
773+
774+
class T(NamedTuple):
775+
a: int
776+
b: str
777+
751778
[case testDynamicNamedTuple]
752779
from collections import namedtuple
780+
from typing import NamedTuple
753781
N = namedtuple('N', ['x', 'y'] + ['z'])
782+
M = NamedTuple('M', [('x', int), ('y', str)] + [('z', float)])
783+
class X(namedtuple('X', ['a', 'b'] + ['c'])): ...
754784
[out]
755785
from _typeshed import Incomplete
756786

757787
N: Incomplete
788+
M: Incomplete
789+
class X(Incomplete): ...
790+
791+
[case testNamedTupleInClassBases]
792+
import collections, typing
793+
from collections import namedtuple
794+
from typing import NamedTuple
795+
class X(namedtuple('X', ['a', 'b'])): ...
796+
class Y(NamedTuple('Y', [('a', int), ('b', str)])): ...
797+
class R(collections.namedtuple('R', ['a', 'b'])): ...
798+
class S(typing.NamedTuple('S', [('a', int), ('b', str)])): ...
799+
[out]
800+
import typing
801+
from _typeshed import Incomplete
802+
from typing import NamedTuple
803+
804+
class X(NamedTuple('X', [('a', Incomplete), ('b', Incomplete)])): ...
805+
class Y(NamedTuple('Y', [('a', int), ('b', str)])): ...
806+
class R(NamedTuple('R', [('a', Incomplete), ('b', Incomplete)])): ...
807+
class S(typing.NamedTuple('S', [('a', int), ('b', str)])): ...
808+
809+
[case testNotNamedTuple]
810+
from not_collections import namedtuple
811+
from not_typing import NamedTuple
812+
from collections import notnamedtuple
813+
from typing import NotNamedTuple
814+
X = namedtuple('X', ['a', 'b'])
815+
Y = notnamedtuple('Y', ['a', 'b'])
816+
Z = NamedTuple('Z', [('a', int), ('b', str)])
817+
W = NotNamedTuple('W', [('a', int), ('b', str)])
818+
[out]
819+
from _typeshed import Incomplete
820+
821+
X: Incomplete
822+
Y: Incomplete
823+
Z: Incomplete
824+
W: Incomplete
758825

759826
[case testArbitraryBaseClass]
760827
import x

0 commit comments

Comments
 (0)