Skip to content

feat: make context have generic client types #1699

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions interactions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
smart_cache,
T,
T_co,
ClientT,
utils,
)
from .client import const
Expand Down Expand Up @@ -420,6 +421,7 @@
"ChannelType",
"check",
"Client",
"ClientT",
"ClientUser",
"Color",
"COLOR_TYPES",
Expand Down
2 changes: 2 additions & 0 deletions interactions/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
Absent,
T,
T_co,
ClientT,
)
from .client import Client
from .auto_shard_client import AutoShardedClient
Expand Down Expand Up @@ -73,6 +74,7 @@
"Absent",
"T",
"T_co",
"ClientT",
"Client",
"AutoShardedClient",
"smart_cache",
Expand Down
15 changes: 8 additions & 7 deletions interactions/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import traceback
from collections.abc import Iterable
from datetime import datetime
from typing_extensions import Self
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -362,17 +363,17 @@ def __init__(
"""The HTTP client to use when interacting with discord endpoints"""

# context factories
self.interaction_context: Type[BaseContext] = interaction_context
self.interaction_context: Type[BaseContext[Self]] = interaction_context
"""The object to instantiate for Interaction Context"""
self.component_context: Type[BaseContext] = component_context
self.component_context: Type[BaseContext[Self]] = component_context
"""The object to instantiate for Component Context"""
self.autocomplete_context: Type[BaseContext] = autocomplete_context
self.autocomplete_context: Type[BaseContext[Self]] = autocomplete_context
"""The object to instantiate for Autocomplete Context"""
self.modal_context: Type[BaseContext] = modal_context
self.modal_context: Type[BaseContext[Self]] = modal_context
"""The object to instantiate for Modal Context"""
self.slash_context: Type[BaseContext] = slash_context
self.slash_context: Type[BaseContext[Self]] = slash_context
"""The object to instantiate for Slash Context"""
self.context_menu_context: Type[BaseContext] = context_menu_context
self.context_menu_context: Type[BaseContext[Self]] = context_menu_context
"""The object to instantiate for Context Menu Context"""

self.token: str | None = token
Expand Down Expand Up @@ -1746,7 +1747,7 @@ def update_command_cache(self, scope: "Snowflake_Type", command_name: str, comma
command.cmd_id[scope] = command_id
self._interaction_lookup[command.resolved_name] = command

async def get_context(self, data: dict) -> InteractionContext:
async def get_context(self, data: dict) -> InteractionContext[Self]:
match data["type"]:
case InteractionType.MESSAGE_COMPONENT:
cls = self.component_context.from_dict(self, data)
Expand Down
11 changes: 10 additions & 1 deletion interactions/client/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,8 @@
import sys
from collections import defaultdict
from importlib.metadata import version as _v, PackageNotFoundError
from typing import TypeVar, Union, Callable, Coroutine, ClassVar
import typing_extensions
from typing import TypeVar, Union, Callable, Coroutine, ClassVar, TYPE_CHECKING

__all__ = (
"__version__",
Expand Down Expand Up @@ -79,6 +80,7 @@
"Absent",
"T",
"T_co",
"ClientT",
"LIB_PATH",
"RECOVERABLE_WEBSOCKET_CLOSE_CODES",
"NON_RESUMABLE_WEBSOCKET_CLOSE_CODES",
Expand Down Expand Up @@ -234,6 +236,13 @@ def has_client_feature(feature: str) -> bool:
Absent = Union[T, Missing]
AsyncCallable = Callable[..., Coroutine]

if TYPE_CHECKING:
from interactions import Client

ClientT = typing_extensions.TypeVar("ClientT", bound=Client, default=Client)
else:
ClientT = TypeVar("ClientT")

LIB_PATH = os.sep.join(__file__.split(os.sep)[:-2])
"""The path to the library folder."""

Expand Down
20 changes: 10 additions & 10 deletions interactions/ext/hybrid_commands/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
Permissions,
Message,
SlashContext,
Client,
Typing,
Embed,
BaseComponent,
Expand All @@ -24,6 +23,7 @@
process_message_payload,
TYPE_MESSAGEABLE_CHANNEL,
)
from interactions.client.const import ClientT
from interactions.models.discord.enums import ContextType
from interactions.client.mixins.send import SendMixin
from interactions.client.errors import HTTPException
Expand All @@ -37,7 +37,7 @@


class DeferTyping:
def __init__(self, ctx: "HybridContext", ephermal: bool) -> None:
def __init__(self, ctx: "SlashContext[ClientT]", ephermal: bool) -> None:
self.ctx = ctx
self.ephermal = ephermal

Expand All @@ -48,7 +48,7 @@ async def __aexit__(self, *_) -> None:
pass


class HybridContext(BaseContext, SendMixin):
class HybridContext(BaseContext[ClientT], SendMixin):
prefix: str
"The prefix used to invoke this command."

Expand Down Expand Up @@ -76,10 +76,10 @@ class HybridContext(BaseContext, SendMixin):

__attachment_index__: int

_slash_ctx: SlashContext | None
_prefixed_ctx: prefixed.PrefixedContext | None
_slash_ctx: SlashContext[ClientT] | None
_prefixed_ctx: prefixed.PrefixedContext[ClientT] | None

def __init__(self, client: Client):
def __init__(self, client: ClientT):
super().__init__(client)
self.prefix = ""
self.app_permissions = Permissions(0)
Expand All @@ -96,12 +96,12 @@ def __init__(self, client: Client):
self._prefixed_ctx = None

@classmethod
def from_dict(cls, client: Client, payload: dict) -> None:
def from_dict(cls, client: ClientT, payload: dict) -> None:
# this doesn't mean anything, so just implement it to make abc happy
raise NotImplementedError

@classmethod
def from_slash_context(cls, ctx: SlashContext) -> Self:
def from_slash_context(cls, ctx: SlashContext[ClientT]) -> Self:
self = cls(ctx.client)
self.guild_id = ctx.guild_id
self.channel_id = ctx.channel_id
Expand All @@ -120,7 +120,7 @@ def from_slash_context(cls, ctx: SlashContext) -> Self:
return self

@classmethod
def from_prefixed_context(cls, ctx: prefixed.PrefixedContext) -> Self:
def from_prefixed_context(cls, ctx: prefixed.PrefixedContext[ClientT]) -> Self:
# this is a "best guess" on what the permissions are
# this may or may not be totally accurate
if hasattr(ctx.channel, "permissions_for"):
Expand Down Expand Up @@ -162,7 +162,7 @@ def from_prefixed_context(cls, ctx: prefixed.PrefixedContext) -> Self:
return self

@property
def inner_context(self) -> SlashContext | prefixed.PrefixedContext:
def inner_context(self) -> SlashContext[ClientT] | prefixed.PrefixedContext[ClientT]:
"""The inner context that this hybrid context is wrapping."""
return self._slash_ctx or self._prefixed_ctx # type: ignore

Expand Down
8 changes: 4 additions & 4 deletions interactions/ext/prefixed_commands/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing_extensions import Self

from interactions.client.client import Client
from interactions.client.const import ClientT
from interactions.client.mixins.send import SendMixin
from interactions.models.discord.channel import TYPE_MESSAGEABLE_CHANNEL
from interactions.models.discord.embed import Embed
Expand All @@ -17,7 +17,7 @@
__all__ = ("PrefixedContext",)


class PrefixedContext(BaseContext, SendMixin):
class PrefixedContext(BaseContext[ClientT], SendMixin):
_message: Message

prefix: str
Expand All @@ -33,12 +33,12 @@ class PrefixedContext(BaseContext, SendMixin):
"This is always empty, and is only here for compatibility with other types of commands."

@classmethod
def from_dict(cls, client: "Client", payload: dict) -> Self:
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
# this doesn't mean anything, so just implement it to make abc happy
raise NotImplementedError

@classmethod
def from_message(cls, client: "Client", message: Message) -> Self:
def from_message(cls, client: "ClientT", message: Message) -> Self:
instance = cls(client=client)

# hack to work around BaseContext property
Expand Down
47 changes: 23 additions & 24 deletions interactions/models/internal/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from aiohttp import FormData

from interactions.client import const
from interactions.client.const import get_logger, MISSING
from interactions.client.const import get_logger, MISSING, ClientT
from interactions.models.discord.components import BaseComponent
from interactions.models.discord.file import UPLOADABLE_TYPE
from interactions.models.discord.sticker import Sticker
Expand Down Expand Up @@ -148,17 +148,14 @@ def from_dict(cls, client: "interactions.Client", data: dict, guild_id: None | S
return instance


class BaseContext(metaclass=abc.ABCMeta):
class BaseContext(typing.Generic[ClientT], metaclass=abc.ABCMeta):
"""
Base context class for all contexts.

Define your own context class by inheriting from this class. For compatibility with the library, you must define a `from_dict` classmethod that takes a dict and returns an instance of your context class.

"""

client: "interactions.Client"
"""The client that created this context."""

command: BaseCommand
"""The command this context invokes."""

Expand All @@ -172,8 +169,10 @@ class BaseContext(metaclass=abc.ABCMeta):
guild_id: typing.Optional[Snowflake]
"""The id of the guild this context was invoked in, if any."""

def __init__(self, client: "interactions.Client") -> None:
self.client = client
def __init__(self, client: ClientT) -> None:
self.client: ClientT = client
"""The client that created this context."""

self.author_id = MISSING
self.channel_id = MISSING
self.message_id = MISSING
Expand Down Expand Up @@ -217,12 +216,12 @@ def voice_state(self) -> typing.Optional["interactions.VoiceState"]:
return self.client.cache.get_bot_voice_state(self.guild_id)

@property
def bot(self) -> "interactions.Client":
def bot(self) -> "ClientT":
return self.client

@classmethod
@abc.abstractmethod
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
"""
Create a context instance from a dict.

Expand All @@ -237,7 +236,7 @@ def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
raise NotImplementedError


class BaseInteractionContext(BaseContext):
class BaseInteractionContext(BaseContext[ClientT]):
token: str
"""The interaction token."""
id: Snowflake
Expand Down Expand Up @@ -280,14 +279,14 @@ class BaseInteractionContext(BaseContext):
kwargs: dict[str, typing.Any]
"""The keyword arguments passed to the interaction."""

def __init__(self, client: "interactions.Client") -> None:
def __init__(self, client: "ClientT") -> None:
super().__init__(client)
self.deferred = False
self.responded = False
self.ephemeral = False

@classmethod
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
instance = cls(client=client)
instance.token = payload["token"]
instance.id = Snowflake(payload["id"])
Expand Down Expand Up @@ -417,7 +416,7 @@ def gather_options(_options: list[dict[str, typing.Any]]) -> dict[str, typing.An
self.args = list(self.kwargs.values())


class InteractionContext(BaseInteractionContext, SendMixin):
class InteractionContext(BaseInteractionContext[ClientT], SendMixin):
async def defer(self, *, ephemeral: bool = False, suppress_error: bool = False) -> None:
"""
Defer the interaction.
Expand Down Expand Up @@ -653,26 +652,26 @@ async def edit(
return self.client.cache.place_message_data(message_data)


class SlashContext(InteractionContext, ModalMixin):
class SlashContext(InteractionContext[ClientT], ModalMixin):
@classmethod
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
return super().from_dict(client, payload)


class ContextMenuContext(InteractionContext, ModalMixin):
class ContextMenuContext(InteractionContext[ClientT], ModalMixin):
target_id: Snowflake
"""The id of the target of the context menu."""
editing_origin: bool
"""Whether you have deferred the interaction and are editing the original response."""
target_type: None | CommandType
"""The type of the target of the context menu."""

def __init__(self, client: "interactions.Client") -> None:
def __init__(self, client: "ClientT") -> None:
super().__init__(client)
self.editing_origin = False

@classmethod
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
instance = super().from_dict(client, payload)
instance.target_id = Snowflake(payload["data"]["target_id"])
instance.target_type = CommandType(payload["data"]["type"])
Expand Down Expand Up @@ -735,7 +734,7 @@ def target(self) -> None | Message | User | Member:
return self.resolved.get(self.target_id)


class ComponentContext(InteractionContext, ModalMixin):
class ComponentContext(InteractionContext[ClientT], ModalMixin):
values: list[str]
"""The values of the SelectMenu component, if any."""
custom_id: str
Expand All @@ -746,7 +745,7 @@ class ComponentContext(InteractionContext, ModalMixin):
"""Whether you have deferred the interaction and are editing the original response."""

@classmethod
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
instance = super().from_dict(client, payload)
instance.values = payload["data"].get("values", [])
instance.custom_id = payload["data"]["custom_id"]
Expand Down Expand Up @@ -914,7 +913,7 @@ def component(self) -> typing.Optional[BaseComponent]:
return component


class ModalContext(InteractionContext):
class ModalContext(InteractionContext[ClientT]):
responses: dict[str, str]
"""The responses of the modal. The key is the `custom_id` of the component."""
custom_id: str
Expand All @@ -923,7 +922,7 @@ class ModalContext(InteractionContext):
"""Whether to edit the original message instead of sending a new one."""

@classmethod
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
instance = super().from_dict(client, payload)
instance.responses = {
comp["components"][0]["custom_id"]: comp["components"][0]["value"] for comp in payload["data"]["components"]
Expand Down Expand Up @@ -984,12 +983,12 @@ async def _defer(self, *, ephemeral: bool = False, edit_origin: bool = False) ->
self.ephemeral = ephemeral


class AutocompleteContext(BaseInteractionContext):
class AutocompleteContext(BaseInteractionContext[ClientT]):
focussed_option: SlashCommandOption # todo: option parsing
"""The option the user is currently filling in."""

@classmethod
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
return super().from_dict(client, payload)

@property
Expand Down
Loading