diff --git a/traits/stubs_tests/examples/Instance.py b/traits/stubs_tests/examples/Instance.py index 419962a72..a7e02f41d 100644 --- a/traits/stubs_tests/examples/Instance.py +++ b/traits/stubs_tests/examples/Instance.py @@ -8,14 +8,16 @@ # # Thanks for using Enthought open source! -from traits.api import HasTraits, Instance, Str +import typing + +from traits.api import HasTraits, Instance class Fruit: - info = Str("good for you") + info: str - def __init__(self, info="good for you", **traits): - super().__init__(info=info, **traits) + def __init__(self, info="good for you"): + self.info = info class Orange(Fruit): @@ -32,7 +34,9 @@ def fruit_factory(info, other_stuff): class TestClass(HasTraits): itm = Instance(Fruit) - + itm_not_none = Instance(Fruit, allow_none=False) + itm_allow_none = Instance(Fruit, allow_none=True) + itm_forward_ref = Instance("Fruit") itm_args_kw = Instance(Fruit, ('different info',), {'another_trait': 3}) itm_args = Instance(Fruit, ('different info',)) itm_kw = Instance(Fruit, {'info': 'different info'}) @@ -51,6 +55,34 @@ class TestClass(HasTraits): ) +def accepts_fruit(arg: Fruit) -> None: + pass + + +def accepts_fruit_or_none(arg: typing.Optional[Fruit]) -> None: + pass + + obj = TestClass() obj.itm = Orange() -obj.itm = Pizza() +obj.itm = None +obj.itm = Pizza() # E: assignment +obj.itm_allow_none = Orange() +obj.itm_allow_none = None +obj.itm_allow_none = Pizza() # E: assignment +obj.itm_not_none = Orange() +obj.itm_not_none = None # E: assignment +obj.itm_not_none = Pizza() # E: assignment +obj.itm_forward_ref = Orange() +obj.itm_forward_ref = None + + +obj = TestClass() +accepts_fruit(obj.itm) # E: arg-type +accepts_fruit_or_none(obj.itm) +accepts_fruit(obj.itm_allow_none) # E: arg-type +accepts_fruit_or_none(obj.itm_allow_none) +accepts_fruit(obj.itm_not_none) +accepts_fruit_or_none(obj.itm_not_none) +accepts_fruit(obj.itm_forward_ref) +accepts_fruit_or_none(obj.itm_forward_ref) diff --git a/traits/trait_types.pyi b/traits/trait_types.pyi index 6f4d2faaa..320a758f7 100644 --- a/traits/trait_types.pyi +++ b/traits/trait_types.pyi @@ -15,6 +15,7 @@ from typing import ( Callable as _CallableType, Dict as _DictType, List as _ListType, + Literal, Optional, Sequence as _Sequence, Set as _SetType, @@ -24,6 +25,7 @@ from typing import ( Type as _Type, TypeVar, Union as _Union, + overload, ) from uuid import UUID as _UUID @@ -494,24 +496,70 @@ class BaseClass(_BaseClass[_Type[_Any]]): ... -class _BaseInstance(_BaseClass[_T]): +class BaseInstance(_TraitType[_S, _S]): - # simplified signature + # simplified signatures + + @overload def __init__( - self, - klass: _T, + self: BaseInstance[Optional[_T]], + klass: _Type[_T], *args: _Any, + allow_none: Literal[True] = ..., **metadata: _Any, ) -> None: ... + @overload + def __init__( + self: BaseInstance[_T], + klass: _Type[_T], + *args: _Any, + allow_none: Literal[False] = ..., + **metadata: _Any, + ) -> None: + ... -class BaseInstance(_BaseInstance[_Any]): - ... + @overload + def __init__( + self: BaseInstance[_Any], + klass: str, + *args: _Any, + **metadata: _Any, + ) -> None: + ... -class Instance(_BaseInstance[_Any]): - ... +class Instance(BaseInstance[_S]): + + @overload + def __init__( + self: Instance[Optional[_T]], + klass: _Type[_T], + *args: _Any, + allow_none: Literal[True] = ..., + **metadata: _Any, + ) -> None: + ... + + @overload + def __init__( + self: Instance[_T], + klass: _Type[_T], + *args: _Any, + allow_none: Literal[False] = ..., + **metadata: _Any, + ) -> None: + ... + + @overload + def __init__( + self: Instance[_Any], + klass: str, + *args: _Any, + **metadata: _Any, + ) -> None: + ... class Supports(Instance):