diff --git a/interactions/client/client.py b/interactions/client/client.py index ff00b3642..c61efd18a 100644 --- a/interactions/client/client.py +++ b/interactions/client/client.py @@ -26,6 +26,8 @@ Union, Awaitable, Tuple, + TypeVar, + overload, ) from aiohttp import BasicAuth @@ -40,6 +42,7 @@ from interactions.client import errors from interactions.client.const import ( GLOBAL_SCOPE, + Missing, MISSING, Absent, EMBED_MAX_DESC_LENGTH, @@ -122,6 +125,8 @@ if TYPE_CHECKING: from interactions.models import Snowflake_Type, TYPE_ALL_CHANNEL +EventT = TypeVar("EventT", bound=BaseEvent) + __all__ = ("Client",) # see https://discord.com/developers/docs/topics/gateway#list-of-intents @@ -1061,12 +1066,36 @@ async def wait_until_ready(self) -> None: """Waits for the client to become ready.""" await self._ready.wait() + @overload + def wait_for( + self, + event: type[EventT], + checks: Absent[Callable[[EventT], bool] | Callable[[EventT], Awaitable[bool]]] = MISSING, + timeout: Optional[float] = None, + ) -> "Awaitable[EventT]": ... + + @overload def wait_for( self, - event: Union[str, "BaseEvent"], - checks: Absent[Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]]] = MISSING, + event: str, + checks: Callable[[EventT], bool] | Callable[[EventT], Awaitable[bool]], timeout: Optional[float] = None, - ) -> Any: + ) -> "Awaitable[EventT]": ... + + @overload + def wait_for( + self, + event: str, + checks: Missing = MISSING, + timeout: Optional[float] = None, + ) -> Awaitable[Any]: ... + + def wait_for( + self, + event: Union[str, "type[BaseEvent]"], + checks: Absent[Callable[[BaseEvent], bool] | Callable[[BaseEvent], Awaitable[bool]]] = MISSING, + timeout: Optional[float] = None, + ) -> Awaitable[Any]: """ Waits for a WebSocket event to be dispatched. @@ -1112,7 +1141,7 @@ async def wait_for_modal( """ author = to_snowflake(author) if author else None - def predicate(event) -> bool: + def predicate(event: events.ModalCompletion) -> bool: if modal.custom_id != event.ctx.custom_id: return False return author == to_snowflake(event.ctx.author) if author else True @@ -1120,9 +1149,60 @@ def predicate(event) -> bool: resp = await self.wait_for("modal_completion", predicate, timeout) return resp.ctx + @overload + async def wait_for_component( + self, + messages: Union[Message, int, list], + components: Union[ + List[List[Union["BaseComponent", dict]]], + List[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ], + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, + timeout: Optional[float] = None, + ) -> "events.Component": ... + + @overload + async def wait_for_component( + self, + *, + components: Union[ + List[List[Union["BaseComponent", dict]]], + List[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ], + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, + timeout: Optional[float] = None, + ) -> "events.Component": ... + + @overload + async def wait_for_component( + self, + messages: None, + components: Union[ + List[List[Union["BaseComponent", dict]]], + List[Union["BaseComponent", dict]], + "BaseComponent", + dict, + ], + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, + timeout: Optional[float] = None, + ) -> "events.Component": ... + + @overload + async def wait_for_component( + self, + messages: Union[Message, int, list], + components: None = None, + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, + timeout: Optional[float] = None, + ) -> "events.Component": ... + async def wait_for_component( self, - messages: Union[Message, int, list] = None, + messages: Optional[Union[Message, int, list]] = None, components: Optional[ Union[ List[List[Union["BaseComponent", dict]]], @@ -1131,7 +1211,7 @@ async def wait_for_component( dict, ] ] = None, - check: Absent[Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]]] | None = None, + check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None, timeout: Optional[float] = None, ) -> "events.Component": """