From de4d7d483a14005386b5521e6bf0ef507e0cd374 Mon Sep 17 00:00:00 2001 From: Danny Yang Date: Wed, 6 Nov 2024 21:24:19 -0500 Subject: [PATCH] add simple return type inference --- mypy/stubgen.py | 230 +++++++++++++++++++++++++++++++++++- test-data/unit/stubgen.test | 157 +++++++++++++++++------- 2 files changed, 341 insertions(+), 46 deletions(-) diff --git a/mypy/stubgen.py b/mypy/stubgen.py index fdad5c2ddd89..717154d2c61e 100755 --- a/mypy/stubgen.py +++ b/mypy/stubgen.py @@ -47,7 +47,7 @@ import os.path import sys import traceback -from typing import Final, Iterable, Iterator +from typing import Final, Iterable, Iterator, Optional import mypy.build import mypy.mixedtraverser @@ -73,43 +73,68 @@ ARG_STAR2, IS_ABSTRACT, NOT_ABSTRACT, + AssertTypeExpr, + AssignmentExpr, AssignmentStmt, + AwaitExpr, Block, BytesExpr, CallExpr, + CastExpr, ClassDef, ComparisonExpr, ComplexExpr, + ConditionalExpr, Decorator, DictExpr, + DictionaryComprehension, EllipsisExpr, + EnumCallExpr, Expression, ExpressionStmt, FloatExpr, FuncBase, FuncDef, + GeneratorExpr, IfStmt, Import, ImportAll, ImportFrom, IndexExpr, IntExpr, + LambdaExpr, + ListComprehension, ListExpr, MemberExpr, MypyFile, + NamedTupleExpr, NameExpr, + NewTypeExpr, OpExpr, OverloadedFuncDef, + ParamSpecExpr, + PromoteExpr, + RevealExpr, + SetComprehension, SetExpr, + SliceExpr, StarExpr, Statement, StrExpr, + SuperExpr, TempNode, TupleExpr, + TypeAliasExpr, TypeAliasStmt, + TypeApplication, + TypedDictExpr, TypeInfo, + TypeVarExpr, + TypeVarTupleExpr, UnaryExpr, Var, + YieldExpr, + YieldFromExpr, ) from mypy.options import Options as MypyOptions from mypy.sharedparse import MAGIC_METHODS_POS_ARGS_ONLY @@ -132,6 +157,7 @@ walk_packages, ) from mypy.traverser import ( + all_return_statements, all_yield_expressions, has_return_statement, has_yield_expression, @@ -149,7 +175,7 @@ UnboundType, get_proper_type, ) -from mypy.visitor import NodeVisitor +from mypy.visitor import ExpressionVisitor, NodeVisitor # Common ways of naming package containing vendored modules. VENDOR_PACKAGES: Final = ["packages", "vendor", "vendored", "_vendor", "_vendored_packages"] @@ -455,6 +481,186 @@ def add_ref(self, fullname: str) -> None: self.refs.add(fullname) +class ExpressionTyper(ExpressionVisitor[Optional[str]]): + containers: set[str | None] + + def __init__(self) -> None: + self.containers = set() + + def visit_int_expr(self, o: IntExpr) -> str: + return "int" + + def visit_str_expr(self, o: StrExpr) -> str: + return "str" + + def visit_bytes_expr(self, o: BytesExpr) -> str: + return "bytes" + + def visit_float_expr(self, o: FloatExpr) -> str: + return "float" + + def visit_complex_expr(self, o: ComplexExpr) -> str: + return "complex" + + def visit_comparison_expr(self, o: ComparisonExpr) -> str: + return "bool" + + def visit_name_expr(self, o: NameExpr) -> str | None: + if o.name == "True": + return "bool" + elif o.name == "False": + return "bool" + elif o.name == "None": + return "None" + return None + + def visit_unary_expr(self, o: UnaryExpr) -> str | None: + if o.op == "not": + return "bool" + return None + + def visit_assignment_expr(self, o: AssignmentExpr) -> str | None: + return o.value.accept(self) + + def visit_list_expr(self, o: ListExpr) -> str | None: + items: list[str | None] = [item.accept(self) for item in o.items] + if not items: + return None + element_type = items[0] + if element_type is not None and all(item == element_type for item in items): + self.containers.add("List") + return f"List[{element_type}]" + return None + + def visit_dict_expr(self, o: DictExpr) -> str | None: + items: list[tuple[str | None, str | None]] = [ + ((None, None) if key is None else (key.accept(self), value.accept(self))) + for key, value in o.items + ] + if not items: + return None + key, value = items[0] + if ( + key is not None + and value is not None + and all(k == key and v == value for k, v in items) + ): + self.containers.add("Dict") + return f"Dict[{key}, {value}]" + return None + + def visit_tuple_expr(self, o: TupleExpr) -> str | None: + items: list[str | None] = [item.accept(self) for item in o.items] + if items and all(item is not None for item in items): + self.containers.add("Tuple") + elements = ", ".join([item for item in items if item is not None]) + return f"Tuple[{elements}]" + return None + + def visit_set_expr(self, o: SetExpr) -> str | None: + items: list[str | None] = [item.accept(self) for item in o.items] + if not items: + return None + element_type = items[0] + if element_type is not None and all(item == element_type for item in items): + self.containers.add("Set") + return f"Set[{element_type}]" + return None + + def visit_ellipsis(self, o: EllipsisExpr) -> None: + return None + + def visit_star_expr(self, o: StarExpr) -> None: + return None + + def visit_member_expr(self, o: MemberExpr) -> None: + return None + + def visit_yield_from_expr(self, o: YieldFromExpr) -> None: + return None + + def visit_yield_expr(self, o: YieldExpr) -> None: + return None + + def visit_call_expr(self, o: CallExpr) -> None: + return None + + def visit_op_expr(self, o: OpExpr) -> None: + return None + + def visit_cast_expr(self, o: CastExpr) -> None: + return None + + def visit_assert_type_expr(self, o: AssertTypeExpr) -> None: + return None + + def visit_reveal_expr(self, o: RevealExpr) -> None: + return None + + def visit_super_expr(self, o: SuperExpr) -> None: + return None + + def visit_index_expr(self, o: IndexExpr) -> None: + return None + + def visit_type_application(self, o: TypeApplication) -> None: + return None + + def visit_lambda_expr(self, o: LambdaExpr) -> None: + return None + + def visit_list_comprehension(self, o: ListComprehension) -> None: + return None + + def visit_set_comprehension(self, o: SetComprehension) -> None: + return None + + def visit_dictionary_comprehension(self, o: DictionaryComprehension) -> None: + return None + + def visit_generator_expr(self, o: GeneratorExpr) -> None: + return None + + def visit_slice_expr(self, o: SliceExpr) -> None: + return None + + def visit_conditional_expr(self, o: ConditionalExpr) -> None: + return None + + def visit_type_var_expr(self, o: TypeVarExpr) -> None: + return None + + def visit_paramspec_expr(self, o: ParamSpecExpr) -> None: + return None + + def visit_type_var_tuple_expr(self, o: TypeVarTupleExpr) -> None: + return None + + def visit_type_alias_expr(self, o: TypeAliasExpr) -> None: + return None + + def visit_namedtuple_expr(self, o: NamedTupleExpr) -> None: + return None + + def visit_enum_call_expr(self, o: EnumCallExpr) -> None: + return None + + def visit_typeddict_expr(self, o: TypedDictExpr) -> None: + return None + + def visit_newtype_expr(self, o: NewTypeExpr) -> None: + return None + + def visit__promote_expr(self, o: PromoteExpr) -> None: + return None + + def visit_await_expr(self, o: AwaitExpr) -> None: + return None + + def visit_temp_node(self, o: TempNode) -> None: + return None + + class ASTStubGenerator(BaseStubGenerator, mypy.traverser.TraverserVisitor): """Generate stub text from a mypy AST.""" @@ -619,6 +825,26 @@ def _get_func_return(self, o: FuncDef, ctx: FunctionContext) -> str | None: return f"{generator_name}[{yield_name}]" if not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: return "None" + if has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: + return_expressions = [ret.expr for ret in all_return_statements(o)] + return_type_visitor = ExpressionTyper() + return_types = [ + ret.accept(return_type_visitor) if ret is not None else "None" + for ret in return_expressions + ] + if not all(return_types): + return None + return_type_set = set(return_types) + if len(return_type_set) == 2 and "None" in return_type_set: + for name in return_type_visitor.containers: + self.add_name(f"typing.{name}") + return_type_set.remove("None") + inner_type = next(iter(return_type_set)) + return f"{inner_type} | None" + elif len(return_type_set) == 1: + for name in return_type_visitor.containers: + self.add_name(f"typing.{name}") + return next(iter(return_type_set)) return None def _get_func_docstring(self, node: FuncDef) -> str | None: diff --git a/test-data/unit/stubgen.test b/test-data/unit/stubgen.test index e64c9c66d65d..c48bd1f8e58c 100644 --- a/test-data/unit/stubgen.test +++ b/test-data/unit/stubgen.test @@ -385,7 +385,7 @@ class A: [out] class A: @property - def f(self): ... + def f(self) -> int: ... @f.setter def f(self, x) -> None: ... @f.deleter @@ -407,7 +407,7 @@ class A: [out] class A: @property - def f(self): ... + def f(self) -> int: ... @f.setter def f(self, x) -> None: ... @f.deleter @@ -444,7 +444,7 @@ import functools class A: @functools.cached_property - def x(self): ... + def x(self) -> str: ... [case testFunctoolsCachedPropertyAlias] import functools as ft @@ -458,7 +458,7 @@ import functools as ft class A: @ft.cached_property - def x(self): ... + def x(self) -> str: ... [case testCachedProperty] from functools import cached_property @@ -472,7 +472,7 @@ from functools import cached_property class A: @cached_property - def x(self): ... + def x(self) -> str: ... [case testCachedPropertyAlias] from functools import cached_property as cp @@ -486,7 +486,7 @@ from functools import cached_property as cp class A: @cp - def x(self): ... + def x(self) -> str: ... [case testStaticMethod] class A: @@ -1611,7 +1611,7 @@ def f(): def g(): return [out] -def f(): ... +def f() -> int: ... def g() -> None: ... [case testFunctionEllipsisInfersReturnNone] @@ -1754,9 +1754,9 @@ async def g(): return 2 [out] class F: - async def f(self): ... + async def f(self) -> int: ... -async def g(): ... +async def g() -> int: ... [case testCoroutineImportAsyncio] import asyncio @@ -1778,12 +1778,12 @@ import asyncio class F: @asyncio.coroutine - def f(self): ... + def f(self) -> int: ... @asyncio.coroutine -def g(): ... +def g() -> int: ... @asyncio.coroutine -def h(): ... +def h() -> int: ... [case testCoroutineImportAsyncioCoroutines] import asyncio.coroutines @@ -1801,10 +1801,10 @@ import asyncio.coroutines class F: @asyncio.coroutines.coroutine - def f(self): ... + def f(self) -> int: ... @asyncio.coroutines.coroutine -def g(): ... +def g() -> int: ... [case testCoroutineImportAsyncioCoroutinesSub] import asyncio @@ -1822,10 +1822,10 @@ import asyncio class F: @asyncio.coroutines.coroutine - def f(self): ... + def f(self) -> int: ... @asyncio.coroutines.coroutine -def g(): ... +def g() -> int: ... [case testCoroutineImportTypes] import types @@ -1843,10 +1843,10 @@ import types class F: @types.coroutine - def f(self): ... + def f(self) -> int: ... @types.coroutine -def g(): ... +def g() -> int: ... [case testCoroutineFromAsyncioImportCoroutine] from asyncio import coroutine @@ -1864,10 +1864,10 @@ from asyncio import coroutine class F: @coroutine - def f(self): ... + def f(self) -> int: ... @coroutine -def g(): ... +def g() -> int: ... [case testCoroutineFromAsyncioCoroutinesImportCoroutine] from asyncio.coroutines import coroutine @@ -1885,10 +1885,10 @@ from asyncio.coroutines import coroutine class F: @coroutine - def f(self): ... + def f(self) -> int: ... @coroutine -def g(): ... +def g() -> int: ... [case testCoroutineFromTypesImportCoroutine] from types import coroutine @@ -1906,10 +1906,10 @@ from types import coroutine class F: @coroutine - def f(self): ... + def f(self) -> int: ... @coroutine -def g(): ... +def g() -> int: ... [case testCoroutineFromAsyncioImportCoroutineAsC] from asyncio import coroutine as c @@ -1927,10 +1927,10 @@ from asyncio import coroutine as c class F: @c - def f(self): ... + def f(self) -> int: ... @c -def g(): ... +def g() -> int: ... [case testCoroutineFromAsyncioCoroutinesImportCoroutineAsC] from asyncio.coroutines import coroutine as c @@ -1948,10 +1948,10 @@ from asyncio.coroutines import coroutine as c class F: @c - def f(self): ... + def f(self) -> int: ... @c -def g(): ... +def g() -> int: ... [case testCoroutineFromTypesImportCoroutineAsC] from types import coroutine as c @@ -1969,10 +1969,10 @@ from types import coroutine as c class F: @c - def f(self): ... + def f(self) -> int: ... @c -def g(): ... +def g() -> int: ... [case testCoroutineImportAsyncioAsA] import asyncio as a @@ -1990,10 +1990,10 @@ import asyncio as a class F: @a.coroutine - def f(self): ... + def f(self) -> int: ... @a.coroutine -def g(): ... +def g() -> int: ... [case testCoroutineImportAsyncioCoroutinesAsC] import asyncio.coroutines as c @@ -2011,10 +2011,10 @@ import asyncio.coroutines as c class F: @c.coroutine - def f(self): ... + def f(self) -> int: ... @c.coroutine -def g(): ... +def g() -> int: ... [case testCoroutineImportAsyncioCoroutinesSubAsA] import asyncio as a @@ -2032,10 +2032,10 @@ import asyncio as a class F: @a.coroutines.coroutine - def f(self): ... + def f(self) -> int: ... @a.coroutines.coroutine -def g(): ... +def g() -> int: ... [case testCoroutineImportTypesAsT] import types as t @@ -2053,10 +2053,10 @@ import types as t class F: @t.coroutine - def f(self): ... + def f(self) -> int: ... @t.coroutine -def g(): ... +def g() -> int: ... -- Tests for stub generation from semantically analyzed trees. @@ -2302,7 +2302,7 @@ def f(x): return '' [out] -def f(x): ... +def f(x) -> str: ... [case testFunctionPartiallyAnnotated] def f(x) -> None: @@ -2415,11 +2415,11 @@ class B: class A: y: str @property - def x(self): ... + def x(self) -> str: ... class B: @property - def x(self): ... + def x(self) -> str: ... y: str @x.setter def x(self, value) -> None: ... @@ -2592,7 +2592,7 @@ x: _Incomplete class Incomplete: ... -def Optional(): ... +def Optional() -> int: ... [case testExportedNameImported] # modules: main a b @@ -3586,8 +3586,8 @@ def f2(): yield from [0] return 0 [out] -def f1(): ... -def f2(): ... +def f1() -> int: ... +def f2() -> int: ... [case testIncludeDocstrings] # flags: --include-docstrings @@ -4508,3 +4508,72 @@ class C3[T3 = int]: ... class C4[T4: int | float = int](list[T4]): ... def f5[T5 = int]() -> None: ... + +[case testReturnTypeInference] +def f1(): + return 1 +def f2(): + return True +def f3(): + return None +def f4(x): + if x > 1: + return True + else: + return False +def f4(x): + if x > 1: + return 1 + else: + return None +def f5(x): + if x > 1: + return 1 + elif x < 1: + return 0 + else: + return None +def f6(x): + if x > 1: + return 1 + elif x < 1: + return None + else: + return None +def f7(x): + return x > 1 +def f8(): + return [1, 2, 3] +def f9(): + return [1, True] +def f10(): + return [] +def f11(): + return {1, 2, 3} +def f12(): + return {1, True} +def f13(): + return {} +def f14(): + return { "x": 1, "y": 2 } +def f15(): + return { "x": 1, "y": True } + +[out] +from typing import Dict, List, Set + +def f1() -> int: ... +def f2() -> bool: ... +def f3() -> None: ... +def f4(x) -> bool: ... +def f5(x) -> int | None: ... +def f6(x) -> int | None: ... +def f7(x) -> bool: ... +def f8() -> List[int]: ... +def f9(): ... +def f10(): ... +def f11() -> Set[int]: ... +def f12(): ... +def f13(): ... +def f14() -> Dict[str, int]: ... +def f15(): ...