104
104
TupleExpr ,
105
105
TypeInfo ,
106
106
UnaryExpr ,
107
- is_StrExpr_list ,
108
107
)
109
108
from mypy .options import Options as MypyOptions
110
109
from mypy .stubdoc import Sig , find_unique_signatures , parse_all_signatures
129
128
from mypy .types import (
130
129
OVERLOAD_NAMES ,
131
130
TPDICT_NAMES ,
131
+ TYPED_NAMEDTUPLE_NAMES ,
132
132
AnyType ,
133
133
CallableType ,
134
134
Instance ,
@@ -400,10 +400,12 @@ def visit_str_expr(self, node: StrExpr) -> str:
400
400
def visit_index_expr (self , node : IndexExpr ) -> str :
401
401
base = node .base .accept (self )
402
402
index = node .index .accept (self )
403
+ if len (index ) > 2 and index .startswith ("(" ) and index .endswith (")" ):
404
+ index = index [1 :- 1 ]
403
405
return f"{ base } [{ index } ]"
404
406
405
407
def visit_tuple_expr (self , node : TupleExpr ) -> str :
406
- return ", " .join (n .accept (self ) for n in node .items )
408
+ return f"( { ', ' .join (n .accept (self ) for n in node .items )} )"
407
409
408
410
def visit_list_expr (self , node : ListExpr ) -> str :
409
411
return f"[{ ', ' .join (n .accept (self ) for n in node .items )} ]"
@@ -1010,6 +1012,37 @@ def get_base_types(self, cdef: ClassDef) -> list[str]:
1010
1012
elif isinstance (base , IndexExpr ):
1011
1013
p = AliasPrinter (self )
1012
1014
base_types .append (base .accept (p ))
1015
+ elif isinstance (base , CallExpr ):
1016
+ # namedtuple(typename, fields), NamedTuple(typename, fields) calls can
1017
+ # be used as a base class. The first argument is a string literal that
1018
+ # is usually the same as the class name.
1019
+ #
1020
+ # Note:
1021
+ # A call-based named tuple as a base class cannot be safely converted to
1022
+ # a class-based NamedTuple definition because class attributes defined
1023
+ # in the body of the class inheriting from the named tuple call are not
1024
+ # namedtuple fields at runtime.
1025
+ if self .is_namedtuple (base ):
1026
+ nt_fields = self ._get_namedtuple_fields (base )
1027
+ assert isinstance (base .args [0 ], StrExpr )
1028
+ typename = base .args [0 ].value
1029
+ if nt_fields is not None :
1030
+ # A valid namedtuple() call, use NamedTuple() instead with
1031
+ # Incomplete as field types
1032
+ fields_str = ", " .join (f"({ f !r} , { t } )" for f , t in nt_fields )
1033
+ base_types .append (f"NamedTuple({ typename !r} , [{ fields_str } ])" )
1034
+ self .add_typing_import ("NamedTuple" )
1035
+ else :
1036
+ # Invalid namedtuple() call, cannot determine fields
1037
+ base_types .append ("Incomplete" )
1038
+ elif self .is_typed_namedtuple (base ):
1039
+ p = AliasPrinter (self )
1040
+ base_types .append (base .accept (p ))
1041
+ else :
1042
+ # At this point, we don't know what the base class is, so we
1043
+ # just use Incomplete as the base class.
1044
+ base_types .append ("Incomplete" )
1045
+ self .import_tracker .require_name ("Incomplete" )
1013
1046
return base_types
1014
1047
1015
1048
def visit_block (self , o : Block ) -> None :
@@ -1022,8 +1055,11 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
1022
1055
foundl = []
1023
1056
1024
1057
for lvalue in o .lvalues :
1025
- if isinstance (lvalue , NameExpr ) and self .is_namedtuple (o .rvalue ):
1026
- assert isinstance (o .rvalue , CallExpr )
1058
+ if (
1059
+ isinstance (lvalue , NameExpr )
1060
+ and isinstance (o .rvalue , CallExpr )
1061
+ and (self .is_namedtuple (o .rvalue ) or self .is_typed_namedtuple (o .rvalue ))
1062
+ ):
1027
1063
self .process_namedtuple (lvalue , o .rvalue )
1028
1064
continue
1029
1065
if (
@@ -1069,37 +1105,79 @@ def visit_assignment_stmt(self, o: AssignmentStmt) -> None:
1069
1105
if all (foundl ):
1070
1106
self ._state = VAR
1071
1107
1072
- def is_namedtuple (self , expr : Expression ) -> bool :
1073
- if not isinstance (expr , CallExpr ):
1074
- return False
1108
+ def is_namedtuple (self , expr : CallExpr ) -> bool :
1075
1109
callee = expr .callee
1076
- return (isinstance (callee , NameExpr ) and callee .name .endswith ("namedtuple" )) or (
1077
- isinstance (callee , MemberExpr ) and callee .name == "namedtuple"
1110
+ return (
1111
+ isinstance (callee , NameExpr )
1112
+ and (self .refers_to_fullname (callee .name , "collections.namedtuple" ))
1113
+ ) or (
1114
+ isinstance (callee , MemberExpr )
1115
+ and isinstance (callee .expr , NameExpr )
1116
+ and f"{ callee .expr .name } .{ callee .name } " == "collections.namedtuple"
1078
1117
)
1079
1118
1119
+ def is_typed_namedtuple (self , expr : CallExpr ) -> bool :
1120
+ callee = expr .callee
1121
+ return (
1122
+ isinstance (callee , NameExpr )
1123
+ and self .refers_to_fullname (callee .name , TYPED_NAMEDTUPLE_NAMES )
1124
+ ) or (
1125
+ isinstance (callee , MemberExpr )
1126
+ and isinstance (callee .expr , NameExpr )
1127
+ and f"{ callee .expr .name } .{ callee .name } " in TYPED_NAMEDTUPLE_NAMES
1128
+ )
1129
+
1130
+ def _get_namedtuple_fields (self , call : CallExpr ) -> list [tuple [str , str ]] | None :
1131
+ if self .is_namedtuple (call ):
1132
+ fields_arg = call .args [1 ]
1133
+ if isinstance (fields_arg , StrExpr ):
1134
+ field_names = fields_arg .value .replace ("," , " " ).split ()
1135
+ elif isinstance (fields_arg , (ListExpr , TupleExpr )):
1136
+ field_names = []
1137
+ for field in fields_arg .items :
1138
+ if not isinstance (field , StrExpr ):
1139
+ return None
1140
+ field_names .append (field .value )
1141
+ else :
1142
+ return None # Invalid namedtuple fields type
1143
+ if field_names :
1144
+ self .import_tracker .require_name ("Incomplete" )
1145
+ return [(field_name , "Incomplete" ) for field_name in field_names ]
1146
+ elif self .is_typed_namedtuple (call ):
1147
+ fields_arg = call .args [1 ]
1148
+ if not isinstance (fields_arg , (ListExpr , TupleExpr )):
1149
+ return None
1150
+ fields : list [tuple [str , str ]] = []
1151
+ b = AliasPrinter (self )
1152
+ for field in fields_arg .items :
1153
+ if not (isinstance (field , TupleExpr ) and len (field .items ) == 2 ):
1154
+ return None
1155
+ field_name , field_type = field .items
1156
+ if not isinstance (field_name , StrExpr ):
1157
+ return None
1158
+ fields .append ((field_name .value , field_type .accept (b )))
1159
+ return fields
1160
+ else :
1161
+ return None # Not a named tuple call
1162
+
1080
1163
def process_namedtuple (self , lvalue : NameExpr , rvalue : CallExpr ) -> None :
1081
1164
if self ._state != EMPTY :
1082
1165
self .add ("\n " )
1083
- if isinstance (rvalue .args [1 ], StrExpr ):
1084
- items = rvalue .args [1 ].value .replace ("," , " " ).split ()
1085
- elif isinstance (rvalue .args [1 ], (ListExpr , TupleExpr )):
1086
- list_items = rvalue .args [1 ].items
1087
- assert is_StrExpr_list (list_items )
1088
- items = [item .value for item in list_items ]
1089
- else :
1166
+ fields = self ._get_namedtuple_fields (rvalue )
1167
+ if fields is None :
1090
1168
self .add (f"{ self ._indent } { lvalue .name } : Incomplete" )
1091
1169
self .import_tracker .require_name ("Incomplete" )
1092
1170
return
1093
1171
self .import_tracker .require_name ("NamedTuple" )
1094
1172
self .add (f"{ self ._indent } class { lvalue .name } (NamedTuple):" )
1095
- if not items :
1173
+ if len ( fields ) == 0 :
1096
1174
self .add (" ...\n " )
1175
+ self ._state = EMPTY_CLASS
1097
1176
else :
1098
- self .import_tracker .require_name ("Incomplete" )
1099
1177
self .add ("\n " )
1100
- for item in items :
1101
- self .add (f"{ self ._indent } { item } : Incomplete \n " )
1102
- self ._state = CLASS
1178
+ for f_name , f_type in fields :
1179
+ self .add (f"{ self ._indent } { f_name } : { f_type } \n " )
1180
+ self ._state = CLASS
1103
1181
1104
1182
def is_typeddict (self , expr : CallExpr ) -> bool :
1105
1183
callee = expr .callee
0 commit comments