From 4e06b507a74f7b995b2a49e394d0af3335d4b976 Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Sat, 8 Jun 2024 11:39:59 -0400 Subject: [PATCH 1/2] feat/fix: improve typehinting of wait_fors --- interactions/client/client.py | 99 ++++++++++++++++++++++++++++++++--- 1 file changed, 93 insertions(+), 6 deletions(-) diff --git a/interactions/client/client.py b/interactions/client/client.py index ff00b3642..d91d02b79 100644 --- a/interactions/client/client.py +++ b/interactions/client/client.py @@ -26,6 +26,8 @@ Union, Awaitable, Tuple, + TypeVar, + overload, ) from aiohttp import BasicAuth @@ -40,6 +42,7 @@ from interactions.client import errors from interactions.client.const import ( GLOBAL_SCOPE, + Missing, MISSING, Absent, EMBED_MAX_DESC_LENGTH, @@ -122,6 +125,8 @@ if TYPE_CHECKING: from interactions.models import Snowflake_Type, TYPE_ALL_CHANNEL +EventT = TypeVar("EventT", bound=BaseEvent) + __all__ = ("Client",) # see https://discord.com/developers/docs/topics/gateway#list-of-intents @@ -1061,12 +1066,39 @@ async def wait_until_ready(self) -> None: """Waits for the client to become ready.""" await self._ready.wait() + @overload + def wait_for( + self, + event: type[EventT], + checks: Absent[Callable[[EventT], bool] | Callable[[EventT], Awaitable[bool]]] = MISSING, + timeout: Optional[float] = None, + ) -> "Awaitable[EventT]": + ... + + @overload def wait_for( self, - event: Union[str, "BaseEvent"], - checks: Absent[Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]]] = MISSING, + event: str, + checks: Callable[[EventT], bool] | Callable[[EventT], Awaitable[bool]], timeout: Optional[float] = None, - ) -> Any: + ) -> "Awaitable[EventT]": + ... + + @overload + def wait_for( + self, + event: str, + checks: Missing = MISSING, + timeout: Optional[float] = None, + ) -> Awaitable[Any]: + ... + + def wait_for( + self, + event: Union[str, "type[BaseEvent]"], + checks: Absent[Callable[[BaseEvent], bool] | Callable[[BaseEvent], Awaitable[bool]]] = MISSING, + timeout: Optional[float] = None, + ) -> Awaitable[Any]: """ Waits for a WebSocket event to be dispatched. @@ -1112,7 +1144,7 @@ async def wait_for_modal( """ author = to_snowflake(author) if author else None - def predicate(event) -> bool: + def predicate(event: events.ModalCompletion) -> bool: if modal.custom_id != event.ctx.custom_id: return False return author == to_snowflake(event.ctx.author) if author else True @@ -1120,9 +1152,64 @@ def predicate(event) -> bool: resp = await self.wait_for("modal_completion", predicate, timeout) return resp.ctx + @overload + async def wait_for_component( + self, + messages: Union[Message, int, list], + components: Union[ + List[List[Union["BaseComponent", dict]]], + List[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ], + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, + timeout: Optional[float] = None, + ) -> "events.Component": + ... + + @overload + async def wait_for_component( + self, + *, + components: Union[ + List[List[Union["BaseComponent", dict]]], + List[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ], + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, + timeout: Optional[float] = None, + ) -> "events.Component": + ... + + @overload + async def wait_for_component( + self, + messages: None, + components: Union[ + List[List[Union["BaseComponent", dict]]], + List[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ], + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, + timeout: Optional[float] = None, + ) -> "events.Component": + ... + + @overload + async def wait_for_component( + self, + messages: Union[Message, int, list], + components: None = None, + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, + timeout: Optional[float] = None, + ) -> "events.Component": + ... + async def wait_for_component( self, - messages: Union[Message, int, list] = None, + messages: Optional[Union[Message, int, list]] = None, components: Optional[ Union[ List[List[Union["BaseComponent", dict]]], @@ -1131,7 +1218,7 @@ async def wait_for_component( dict, ] ] = None, - check: Absent[Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]]] | None = None, + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, timeout: Optional[float] = None, ) -> "events.Component": """ From b87addf6dd1172b52e83da74fbe1d4627d66eb7d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 Jun 2024 15:46:45 +0000 Subject: [PATCH 2/2] ci: correct from checks. --- interactions/client/client.py | 31 ++++++++++++------------------- 1 file changed, 12 insertions(+), 19 deletions(-) diff --git a/interactions/client/client.py b/interactions/client/client.py index d91d02b79..c61efd18a 100644 --- a/interactions/client/client.py +++ b/interactions/client/client.py @@ -1072,8 +1072,7 @@ def wait_for( event: type[EventT], checks: Absent[Callable[[EventT], bool] | Callable[[EventT], Awaitable[bool]]] = MISSING, timeout: Optional[float] = None, - ) -> "Awaitable[EventT]": - ... + ) -> "Awaitable[EventT]": ... @overload def wait_for( @@ -1081,8 +1080,7 @@ def wait_for( event: str, checks: Callable[[EventT], bool] | Callable[[EventT], Awaitable[bool]], timeout: Optional[float] = None, - ) -> "Awaitable[EventT]": - ... + ) -> "Awaitable[EventT]": ... @overload def wait_for( @@ -1090,8 +1088,7 @@ def wait_for( event: str, checks: Missing = MISSING, timeout: Optional[float] = None, - ) -> Awaitable[Any]: - ... + ) -> Awaitable[Any]: ... def wait_for( self, @@ -1157,15 +1154,14 @@ async def wait_for_component( self, messages: Union[Message, int, list], components: Union[ - List[List[Union["BaseComponent", dict]]], - List[Union["BaseComponent", dict]], - "BaseComponent", - dict, - ], + List[List[Union["BaseComponent", dict]]], + List[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ], check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, timeout: Optional[float] = None, - ) -> "events.Component": - ... + ) -> "events.Component": ... @overload async def wait_for_component( @@ -1179,8 +1175,7 @@ async def wait_for_component( ], check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, timeout: Optional[float] = None, - ) -> "events.Component": - ... + ) -> "events.Component": ... @overload async def wait_for_component( @@ -1194,8 +1189,7 @@ async def wait_for_component( ], check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, timeout: Optional[float] = None, - ) -> "events.Component": - ... + ) -> "events.Component": ... @overload async def wait_for_component( @@ -1204,8 +1198,7 @@ async def wait_for_component( components: None = None, check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, timeout: Optional[float] = None, - ) -> "events.Component": - ... + ) -> "events.Component": ... async def wait_for_component( self,