Skip to content

Stubgen: guess return types based on returned values #18116

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 228 additions & 2 deletions mypy/stubgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import os.path
import sys
import traceback
from typing import Final, Iterable, Iterator
from typing import Final, Iterable, Iterator, Optional

import mypy.build
import mypy.mixedtraverser
Expand All @@ -73,43 +73,68 @@
ARG_STAR2,
IS_ABSTRACT,
NOT_ABSTRACT,
AssertTypeExpr,
AssignmentExpr,
AssignmentStmt,
AwaitExpr,
Block,
BytesExpr,
CallExpr,
CastExpr,
ClassDef,
ComparisonExpr,
ComplexExpr,
ConditionalExpr,
Decorator,
DictExpr,
DictionaryComprehension,
EllipsisExpr,
EnumCallExpr,
Expression,
ExpressionStmt,
FloatExpr,
FuncBase,
FuncDef,
GeneratorExpr,
IfStmt,
Import,
ImportAll,
ImportFrom,
IndexExpr,
IntExpr,
LambdaExpr,
ListComprehension,
ListExpr,
MemberExpr,
MypyFile,
NamedTupleExpr,
NameExpr,
NewTypeExpr,
OpExpr,
OverloadedFuncDef,
ParamSpecExpr,
PromoteExpr,
RevealExpr,
SetComprehension,
SetExpr,
SliceExpr,
StarExpr,
Statement,
StrExpr,
SuperExpr,
TempNode,
TupleExpr,
TypeAliasExpr,
TypeAliasStmt,
TypeApplication,
TypedDictExpr,
TypeInfo,
TypeVarExpr,
TypeVarTupleExpr,
UnaryExpr,
Var,
YieldExpr,
YieldFromExpr,
)
from mypy.options import Options as MypyOptions
from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY
Expand All @@ -132,6 +157,7 @@
walk_packages,
)
from mypy.traverser import (
all_return_statements,
all_yield_expressions,
has_return_statement,
has_yield_expression,
Expand All @@ -149,7 +175,7 @@
UnboundType,
get_proper_type,
)
from mypy.visitor import NodeVisitor
from mypy.visitor import ExpressionVisitor, NodeVisitor

# Common ways of naming package containing vendored modules.
VENDOR_PACKAGES: Final = ["packages", "vendor", "vendored", "_vendor", "_vendored_packages"]
Expand Down Expand Up @@ -455,6 +481,186 @@ def add_ref(self, fullname: str) -> None:
self.refs.add(fullname)


class ExpressionTyper(ExpressionVisitor[Optional[str]]):
containers: set[str | None]

def __init__(self) -> None:
self.containers = set()

def visit_int_expr(self, o: IntExpr) -> str:
return "int"

def visit_str_expr(self, o: StrExpr) -> str:
return "str"

def visit_bytes_expr(self, o: BytesExpr) -> str:
return "bytes"

def visit_float_expr(self, o: FloatExpr) -> str:
return "float"

def visit_complex_expr(self, o: ComplexExpr) -> str:
return "complex"

def visit_comparison_expr(self, o: ComparisonExpr) -> str:
return "bool"

def visit_name_expr(self, o: NameExpr) -> str | None:
if o.name == "True":
return "bool"
elif o.name == "False":
return "bool"
elif o.name == "None":
return "None"
return None

def visit_unary_expr(self, o: UnaryExpr) -> str | None:
if o.op == "not":
return "bool"
return None

def visit_assignment_expr(self, o: AssignmentExpr) -> str | None:
return o.value.accept(self)

def visit_list_expr(self, o: ListExpr) -> str | None:
items: list[str | None] = [item.accept(self) for item in o.items]
if not items:
return None
element_type = items[0]
if element_type is not None and all(item == element_type for item in items):
self.containers.add("List")
return f"List[{element_type}]"
return None

def visit_dict_expr(self, o: DictExpr) -> str | None:
items: list[tuple[str | None, str | None]] = [
((None, None) if key is None else (key.accept(self), value.accept(self)))
for key, value in o.items
]
if not items:
return None
key, value = items[0]
if (
key is not None
and value is not None
and all(k == key and v == value for k, v in items)
):
self.containers.add("Dict")
return f"Dict[{key}, {value}]"
return None

def visit_tuple_expr(self, o: TupleExpr) -> str | None:
items: list[str | None] = [item.accept(self) for item in o.items]
if items and all(item is not None for item in items):
self.containers.add("Tuple")
elements = ", ".join([item for item in items if item is not None])
return f"Tuple[{elements}]"
return None

def visit_set_expr(self, o: SetExpr) -> str | None:
items: list[str | None] = [item.accept(self) for item in o.items]
if not items:
return None
element_type = items[0]
if element_type is not None and all(item == element_type for item in items):
self.containers.add("Set")
return f"Set[{element_type}]"
return None

def visit_ellipsis(self, o: EllipsisExpr) -> None:
return None

def visit_star_expr(self, o: StarExpr) -> None:
return None

def visit_member_expr(self, o: MemberExpr) -> None:
return None

def visit_yield_from_expr(self, o: YieldFromExpr) -> None:
return None

def visit_yield_expr(self, o: YieldExpr) -> None:
return None

def visit_call_expr(self, o: CallExpr) -> None:
return None

def visit_op_expr(self, o: OpExpr) -> None:
return None

def visit_cast_expr(self, o: CastExpr) -> None:
return None

def visit_assert_type_expr(self, o: AssertTypeExpr) -> None:
return None

def visit_reveal_expr(self, o: RevealExpr) -> None:
return None

def visit_super_expr(self, o: SuperExpr) -> None:
return None

def visit_index_expr(self, o: IndexExpr) -> None:
return None

def visit_type_application(self, o: TypeApplication) -> None:
return None

def visit_lambda_expr(self, o: LambdaExpr) -> None:
return None

def visit_list_comprehension(self, o: ListComprehension) -> None:
return None

def visit_set_comprehension(self, o: SetComprehension) -> None:
return None

def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None:
return None

def visit_generator_expr(self, o: GeneratorExpr) -> None:
return None

def visit_slice_expr(self, o: SliceExpr) -> None:
return None

def visit_conditional_expr(self, o: ConditionalExpr) -> None:
return None

def visit_type_var_expr(self, o: TypeVarExpr) -> None:
return None

def visit_paramspec_expr(self, o: ParamSpecExpr) -> None:
return None

def visit_type_var_tuple_expr(self, o: TypeVarTupleExpr) -> None:
return None

def visit_type_alias_expr(self, o: TypeAliasExpr) -> None:
return None

def visit_namedtuple_expr(self, o: NamedTupleExpr) -> None:
return None

def visit_enum_call_expr(self, o: EnumCallExpr) -> None:
return None

def visit_typeddict_expr(self, o: TypedDictExpr) -> None:
return None

def visit_newtype_expr(self, o: NewTypeExpr) -> None:
return None

def visit__promote_expr(self, o: PromoteExpr) -> None:
return None

def visit_await_expr(self, o: AwaitExpr) -> None:
return None

def visit_temp_node(self, o: TempNode) -> None:
return None


class ASTStubGenerator(BaseStubGenerator, mypy.traverser.TraverserVisitor):
"""Generate stub text from a mypy AST."""

Expand Down Expand Up @@ -619,6 +825,26 @@ def _get_func_return(self, o: FuncDef, ctx: FunctionContext) -> str | None:
return f"{generator_name}[{yield_name}]"
if not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT:
return "None"
if has_return_statement(o) and o.abstract_status == NOT_ABSTRACT:
return_expressions = [ret.expr for ret in all_return_statements(o)]
return_type_visitor = ExpressionTyper()
return_types = [
ret.accept(return_type_visitor) if ret is not None else "None"
for ret in return_expressions
]
if not all(return_types):
return None
return_type_set = set(return_types)
if len(return_type_set) == 2 and "None" in return_type_set:
for name in return_type_visitor.containers:
self.add_name(f"typing.{name}")
return_type_set.remove("None")
inner_type = next(iter(return_type_set))
return f"{inner_type} | None"
elif len(return_type_set) == 1:
for name in return_type_visitor.containers:
self.add_name(f"typing.{name}")
return next(iter(return_type_set))
return None

def _get_func_docstring(self, node: FuncDef) -> str | None:
Expand Down
Loading
Loading