diff --git a/interactions/__init__.py b/interactions/__init__.py index 0d7ee034d..ee2cbf5eb 100644 --- a/interactions/__init__.py +++ b/interactions/__init__.py @@ -37,6 +37,7 @@ smart_cache, T, T_co, + ClientT, utils, ) from .client import const @@ -420,6 +421,7 @@ "ChannelType", "check", "Client", + "ClientT", "ClientUser", "Color", "COLOR_TYPES", diff --git a/interactions/client/__init__.py b/interactions/client/__init__.py index 96df96f0a..3933e3d91 100644 --- a/interactions/client/__init__.py +++ b/interactions/client/__init__.py @@ -32,6 +32,7 @@ Absent, T, T_co, + ClientT, ) from .client import Client from .auto_shard_client import AutoShardedClient @@ -73,6 +74,7 @@ "Absent", "T", "T_co", + "ClientT", "Client", "AutoShardedClient", "smart_cache", diff --git a/interactions/client/client.py b/interactions/client/client.py index ff00b3642..4aaee8665 100644 --- a/interactions/client/client.py +++ b/interactions/client/client.py @@ -12,6 +12,7 @@ import traceback from collections.abc import Iterable from datetime import datetime +from typing_extensions import Self from typing import ( TYPE_CHECKING, Any, @@ -362,17 +363,17 @@ def __init__( """The HTTP client to use when interacting with discord endpoints""" # context factories - self.interaction_context: Type[BaseContext] = interaction_context + self.interaction_context: Type[BaseContext[Self]] = interaction_context """The object to instantiate for Interaction Context""" - self.component_context: Type[BaseContext] = component_context + self.component_context: Type[BaseContext[Self]] = component_context """The object to instantiate for Component Context""" - self.autocomplete_context: Type[BaseContext] = autocomplete_context + self.autocomplete_context: Type[BaseContext[Self]] = autocomplete_context """The object to instantiate for Autocomplete Context""" - self.modal_context: Type[BaseContext] = modal_context + self.modal_context: Type[BaseContext[Self]] = modal_context """The object to instantiate for Modal Context""" - self.slash_context: Type[BaseContext] = slash_context + self.slash_context: Type[BaseContext[Self]] = slash_context """The object to instantiate for Slash Context""" - self.context_menu_context: Type[BaseContext] = context_menu_context + self.context_menu_context: Type[BaseContext[Self]] = context_menu_context """The object to instantiate for Context Menu Context""" self.token: str | None = token @@ -1746,7 +1747,7 @@ def update_command_cache(self, scope: "Snowflake_Type", command_name: str, comma command.cmd_id[scope] = command_id self._interaction_lookup[command.resolved_name] = command - async def get_context(self, data: dict) -> InteractionContext: + async def get_context(self, data: dict) -> InteractionContext[Self]: match data["type"]: case InteractionType.MESSAGE_COMPONENT: cls = self.component_context.from_dict(self, data) diff --git a/interactions/client/const.py b/interactions/client/const.py index d1b4f04e7..9e97e53e0 100644 --- a/interactions/client/const.py +++ b/interactions/client/const.py @@ -43,7 +43,8 @@ import sys from collections import defaultdict from importlib.metadata import version as _v, PackageNotFoundError -from typing import TypeVar, Union, Callable, Coroutine, ClassVar +import typing_extensions +from typing import TypeVar, Union, Callable, Coroutine, ClassVar, TYPE_CHECKING __all__ = ( "__version__", @@ -79,6 +80,7 @@ "Absent", "T", "T_co", + "ClientT", "LIB_PATH", "RECOVERABLE_WEBSOCKET_CLOSE_CODES", "NON_RESUMABLE_WEBSOCKET_CLOSE_CODES", @@ -234,6 +236,13 @@ def has_client_feature(feature: str) -> bool: Absent = Union[T, Missing] AsyncCallable = Callable[..., Coroutine] +if TYPE_CHECKING: + from interactions import Client + + ClientT = typing_extensions.TypeVar("ClientT", bound=Client, default=Client) +else: + ClientT = TypeVar("ClientT") + LIB_PATH = os.sep.join(__file__.split(os.sep)[:-2]) """The path to the library folder.""" diff --git a/interactions/ext/hybrid_commands/context.py b/interactions/ext/hybrid_commands/context.py index 195f83bca..59684f7bb 100644 --- a/interactions/ext/hybrid_commands/context.py +++ b/interactions/ext/hybrid_commands/context.py @@ -9,7 +9,6 @@ Permissions, Message, SlashContext, - Client, Typing, Embed, BaseComponent, @@ -24,6 +23,7 @@ process_message_payload, TYPE_MESSAGEABLE_CHANNEL, ) +from interactions.client.const import ClientT from interactions.models.discord.enums import ContextType from interactions.client.mixins.send import SendMixin from interactions.client.errors import HTTPException @@ -37,7 +37,7 @@ class DeferTyping: - def __init__(self, ctx: "HybridContext", ephermal: bool) -> None: + def __init__(self, ctx: "SlashContext[ClientT]", ephermal: bool) -> None: self.ctx = ctx self.ephermal = ephermal @@ -48,7 +48,7 @@ async def __aexit__(self, *_) -> None: pass -class HybridContext(BaseContext, SendMixin): +class HybridContext(BaseContext[ClientT], SendMixin): prefix: str "The prefix used to invoke this command." @@ -76,10 +76,10 @@ class HybridContext(BaseContext, SendMixin): __attachment_index__: int - _slash_ctx: SlashContext | None - _prefixed_ctx: prefixed.PrefixedContext | None + _slash_ctx: SlashContext[ClientT] | None + _prefixed_ctx: prefixed.PrefixedContext[ClientT] | None - def __init__(self, client: Client): + def __init__(self, client: ClientT): super().__init__(client) self.prefix = "" self.app_permissions = Permissions(0) @@ -96,12 +96,12 @@ def __init__(self, client: Client): self._prefixed_ctx = None @classmethod - def from_dict(cls, client: Client, payload: dict) -> None: + def from_dict(cls, client: ClientT, payload: dict) -> None: # this doesn't mean anything, so just implement it to make abc happy raise NotImplementedError @classmethod - def from_slash_context(cls, ctx: SlashContext) -> Self: + def from_slash_context(cls, ctx: SlashContext[ClientT]) -> Self: self = cls(ctx.client) self.guild_id = ctx.guild_id self.channel_id = ctx.channel_id @@ -120,7 +120,7 @@ def from_slash_context(cls, ctx: SlashContext) -> Self: return self @classmethod - def from_prefixed_context(cls, ctx: prefixed.PrefixedContext) -> Self: + def from_prefixed_context(cls, ctx: prefixed.PrefixedContext[ClientT]) -> Self: # this is a "best guess" on what the permissions are # this may or may not be totally accurate if hasattr(ctx.channel, "permissions_for"): @@ -162,7 +162,7 @@ def from_prefixed_context(cls, ctx: prefixed.PrefixedContext) -> Self: return self @property - def inner_context(self) -> SlashContext | prefixed.PrefixedContext: + def inner_context(self) -> SlashContext[ClientT] | prefixed.PrefixedContext[ClientT]: """The inner context that this hybrid context is wrapping.""" return self._slash_ctx or self._prefixed_ctx # type: ignore diff --git a/interactions/ext/prefixed_commands/context.py b/interactions/ext/prefixed_commands/context.py index a9d995faf..b09249665 100644 --- a/interactions/ext/prefixed_commands/context.py +++ b/interactions/ext/prefixed_commands/context.py @@ -2,7 +2,7 @@ from typing_extensions import Self -from interactions.client.client import Client +from interactions.client.const import ClientT from interactions.client.mixins.send import SendMixin from interactions.models.discord.channel import TYPE_MESSAGEABLE_CHANNEL from interactions.models.discord.embed import Embed @@ -17,7 +17,7 @@ __all__ = ("PrefixedContext",) -class PrefixedContext(BaseContext, SendMixin): +class PrefixedContext(BaseContext[ClientT], SendMixin): _message: Message prefix: str @@ -33,12 +33,12 @@ class PrefixedContext(BaseContext, SendMixin): "This is always empty, and is only here for compatibility with other types of commands." @classmethod - def from_dict(cls, client: "Client", payload: dict) -> Self: + def from_dict(cls, client: "ClientT", payload: dict) -> Self: # this doesn't mean anything, so just implement it to make abc happy raise NotImplementedError @classmethod - def from_message(cls, client: "Client", message: Message) -> Self: + def from_message(cls, client: "ClientT", message: Message) -> Self: instance = cls(client=client) # hack to work around BaseContext property diff --git a/interactions/models/internal/context.py b/interactions/models/internal/context.py index 1f826627f..937ab9da4 100644 --- a/interactions/models/internal/context.py +++ b/interactions/models/internal/context.py @@ -9,7 +9,7 @@ from aiohttp import FormData from interactions.client import const -from interactions.client.const import get_logger, MISSING +from interactions.client.const import get_logger, MISSING, ClientT from interactions.models.discord.components import BaseComponent from interactions.models.discord.file import UPLOADABLE_TYPE from interactions.models.discord.sticker import Sticker @@ -148,7 +148,7 @@ def from_dict(cls, client: "interactions.Client", data: dict, guild_id: None | S return instance -class BaseContext(metaclass=abc.ABCMeta): +class BaseContext(typing.Generic[ClientT], metaclass=abc.ABCMeta): """ Base context class for all contexts. @@ -156,9 +156,6 @@ class BaseContext(metaclass=abc.ABCMeta): """ - client: "interactions.Client" - """The client that created this context.""" - command: BaseCommand """The command this context invokes.""" @@ -172,8 +169,10 @@ class BaseContext(metaclass=abc.ABCMeta): guild_id: typing.Optional[Snowflake] """The id of the guild this context was invoked in, if any.""" - def __init__(self, client: "interactions.Client") -> None: - self.client = client + def __init__(self, client: ClientT) -> None: + self.client: ClientT = client + """The client that created this context.""" + self.author_id = MISSING self.channel_id = MISSING self.message_id = MISSING @@ -217,12 +216,12 @@ def voice_state(self) -> typing.Optional["interactions.VoiceState"]: return self.client.cache.get_bot_voice_state(self.guild_id) @property - def bot(self) -> "interactions.Client": + def bot(self) -> "ClientT": return self.client @classmethod @abc.abstractmethod - def from_dict(cls, client: "interactions.Client", payload: dict) -> Self: + def from_dict(cls, client: "ClientT", payload: dict) -> Self: """ Create a context instance from a dict. @@ -237,7 +236,7 @@ def from_dict(cls, client: "interactions.Client", payload: dict) -> Self: raise NotImplementedError -class BaseInteractionContext(BaseContext): +class BaseInteractionContext(BaseContext[ClientT]): token: str """The interaction token.""" id: Snowflake @@ -280,14 +279,14 @@ class BaseInteractionContext(BaseContext): kwargs: dict[str, typing.Any] """The keyword arguments passed to the interaction.""" - def __init__(self, client: "interactions.Client") -> None: + def __init__(self, client: "ClientT") -> None: super().__init__(client) self.deferred = False self.responded = False self.ephemeral = False @classmethod - def from_dict(cls, client: "interactions.Client", payload: dict) -> Self: + def from_dict(cls, client: "ClientT", payload: dict) -> Self: instance = cls(client=client) instance.token = payload["token"] instance.id = Snowflake(payload["id"]) @@ -417,7 +416,7 @@ def gather_options(_options: list[dict[str, typing.Any]]) -> dict[str, typing.An self.args = list(self.kwargs.values()) -class InteractionContext(BaseInteractionContext, SendMixin): +class InteractionContext(BaseInteractionContext[ClientT], SendMixin): async def defer(self, *, ephemeral: bool = False, suppress_error: bool = False) -> None: """ Defer the interaction. @@ -653,13 +652,13 @@ async def edit( return self.client.cache.place_message_data(message_data) -class SlashContext(InteractionContext, ModalMixin): +class SlashContext(InteractionContext[ClientT], ModalMixin): @classmethod - def from_dict(cls, client: "interactions.Client", payload: dict) -> Self: + def from_dict(cls, client: "ClientT", payload: dict) -> Self: return super().from_dict(client, payload) -class ContextMenuContext(InteractionContext, ModalMixin): +class ContextMenuContext(InteractionContext[ClientT], ModalMixin): target_id: Snowflake """The id of the target of the context menu.""" editing_origin: bool @@ -667,12 +666,12 @@ class ContextMenuContext(InteractionContext, ModalMixin): target_type: None | CommandType """The type of the target of the context menu.""" - def __init__(self, client: "interactions.Client") -> None: + def __init__(self, client: "ClientT") -> None: super().__init__(client) self.editing_origin = False @classmethod - def from_dict(cls, client: "interactions.Client", payload: dict) -> Self: + def from_dict(cls, client: "ClientT", payload: dict) -> Self: instance = super().from_dict(client, payload) instance.target_id = Snowflake(payload["data"]["target_id"]) instance.target_type = CommandType(payload["data"]["type"]) @@ -735,7 +734,7 @@ def target(self) -> None | Message | User | Member: return self.resolved.get(self.target_id) -class ComponentContext(InteractionContext, ModalMixin): +class ComponentContext(InteractionContext[ClientT], ModalMixin): values: list[str] """The values of the SelectMenu component, if any.""" custom_id: str @@ -746,7 +745,7 @@ class ComponentContext(InteractionContext, ModalMixin): """Whether you have deferred the interaction and are editing the original response.""" @classmethod - def from_dict(cls, client: "interactions.Client", payload: dict) -> Self: + def from_dict(cls, client: "ClientT", payload: dict) -> Self: instance = super().from_dict(client, payload) instance.values = payload["data"].get("values", []) instance.custom_id = payload["data"]["custom_id"] @@ -914,7 +913,7 @@ def component(self) -> typing.Optional[BaseComponent]: return component -class ModalContext(InteractionContext): +class ModalContext(InteractionContext[ClientT]): responses: dict[str, str] """The responses of the modal. The key is the `custom_id` of the component.""" custom_id: str @@ -923,7 +922,7 @@ class ModalContext(InteractionContext): """Whether to edit the original message instead of sending a new one.""" @classmethod - def from_dict(cls, client: "interactions.Client", payload: dict) -> Self: + def from_dict(cls, client: "ClientT", payload: dict) -> Self: instance = super().from_dict(client, payload) instance.responses = { comp["components"][0]["custom_id"]: comp["components"][0]["value"] for comp in payload["data"]["components"] @@ -984,12 +983,12 @@ async def _defer(self, *, ephemeral: bool = False, edit_origin: bool = False) -> self.ephemeral = ephemeral -class AutocompleteContext(BaseInteractionContext): +class AutocompleteContext(BaseInteractionContext[ClientT]): focussed_option: SlashCommandOption # todo: option parsing """The option the user is currently filling in.""" @classmethod - def from_dict(cls, client: "interactions.Client", payload: dict) -> Self: + def from_dict(cls, client: "ClientT", payload: dict) -> Self: return super().from_dict(client, payload) @property