From 366ac5f3bb0f4ecd878b1fa81d785cf90ff4508f Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Wed, 24 Apr 2024 19:10:50 -0400 Subject: [PATCH 01/16] feat: initial work on polls --- interactions/__init__.py | 2 + interactions/client/const.py | 5 + interactions/models/__init__.py | 2 + interactions/models/discord/__init__.py | 2 + interactions/models/discord/enums.py | 7 ++ interactions/models/discord/poll.py | 155 ++++++++++++++++++++++++ 6 files changed, 173 insertions(+) create mode 100644 interactions/models/discord/poll.py diff --git a/interactions/__init__.py b/interactions/__init__.py index 0d7ee034d..a8a6905cf 100644 --- a/interactions/__init__.py +++ b/interactions/__init__.py @@ -243,6 +243,7 @@ PartialEmojiConverter, PermissionOverwrite, Permissions, + PollLayoutType, PremiumTier, PremiumType, process_allowed_mentions, @@ -594,6 +595,7 @@ "PartialEmojiConverter", "PermissionOverwrite", "Permissions", + "PollLayoutType", "PREMIUM_GUILD_LIMITS", "PremiumTier", "PremiumType", diff --git a/interactions/client/const.py b/interactions/client/const.py index d1b4f04e7..a1e525a55 100644 --- a/interactions/client/const.py +++ b/interactions/client/const.py @@ -84,6 +84,8 @@ "NON_RESUMABLE_WEBSOCKET_CLOSE_CODES", "CLIENT_FEATURE_FLAGS", "has_client_feature", + "POLL_MAX_ANSWERS", + "POLL_MAX_DURATION_HOURS", ) _ver_info = sys.version_info @@ -130,6 +132,9 @@ def get_logger() -> logging.Logger: EMBED_TOTAL_MAX = 6000 EMBED_FIELD_VALUE_LENGTH = 1024 +POLL_MAX_ANSWERS = 10 +POLL_MAX_DURATION_HOURS = 168 + class Singleton(type): _instances: ClassVar[dict] = {} diff --git a/interactions/models/__init__.py b/interactions/models/__init__.py index 6cd71f547..c033bfc27 100644 --- a/interactions/models/__init__.py +++ b/interactions/models/__init__.py @@ -122,6 +122,7 @@ PartialEmoji, PermissionOverwrite, Permissions, + PollLayoutType, PremiumTier, PremiumType, process_allowed_mentions, @@ -523,6 +524,7 @@ "PartialEmojiConverter", "PermissionOverwrite", "Permissions", + "PollLayoutType", "PremiumTier", "PremiumType", "process_allowed_mentions", diff --git a/interactions/models/discord/__init__.py b/interactions/models/discord/__init__.py index 91dec1218..f668e75a1 100644 --- a/interactions/models/discord/__init__.py +++ b/interactions/models/discord/__init__.py @@ -104,6 +104,7 @@ OnboardingPromptType, OverwriteType, Permissions, + PollLayoutType, PremiumTier, PremiumType, ScheduledEventPrivacyLevel, @@ -299,6 +300,7 @@ "PartialEmoji", "PermissionOverwrite", "Permissions", + "PollLayoutType", "PremiumTier", "PremiumType", "process_allowed_mentions", diff --git a/interactions/models/discord/enums.py b/interactions/models/discord/enums.py index 0c25b4eeb..166f0fef7 100644 --- a/interactions/models/discord/enums.py +++ b/interactions/models/discord/enums.py @@ -38,6 +38,7 @@ "OnboardingPromptType", "OverwriteType", "Permissions", + "PollLayoutType", "PremiumTier", "PremiumType", "ScheduledEventPrivacyLevel", @@ -1134,3 +1135,9 @@ class EntitlementType(CursedIntEnum): """Entitlement was claimed by user for free as a Nitro Subscriber""" APPLICATION_SUBSCRIPTION = 8 """Entitlement was purchased as an app subscription""" + + +class PollLayoutType(CursedIntEnum): + """The layout of a poll.""" + + DEFAULT = 1 diff --git a/interactions/models/discord/poll.py b/interactions/models/discord/poll.py new file mode 100644 index 000000000..bc4dc1899 --- /dev/null +++ b/interactions/models/discord/poll.py @@ -0,0 +1,155 @@ +from typing import Optional, Union, Dict, Any +from typing_extensions import Self + +import attrs + +from interactions.client.const import MISSING +from interactions.client.utils.attr_converters import ( + optional, + timestamp_converter, +) +from interactions.client.mixins.serialization import DictSerializationMixin +from interactions.client.utils.serializer import no_export_meta +from interactions.models.discord.emoji import PartialEmoji, process_emoji +from interactions.models.discord.enums import PollLayoutType +from interactions.models.discord.timestamp import Timestamp + +__all__ = ( + "PollMedia", + "PollAnswer", + "PollAnswerCount", + "PollResults", + "Poll", +) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class PollMedia(DictSerializationMixin): + text: Optional[str] = attrs.field(repr=False, default=None) + """The text of the field.""" + emoji: Optional[PartialEmoji] = attrs.field(repr=False, default=None, converter=optional(PartialEmoji.from_dict)) + """The emoji of the field.""" + + @classmethod + def create(cls, *, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> Self: + """ + Create a PollMedia object, used for questions and answers for polls. + + Args: + text: The text of the field. + emoji: The emoji of the field. + + Returns: + A PollMedia object. + + """ + if not text and not emoji: + raise ValueError("Either text or emoji must be provided.") + + return cls(text=text, emoji=process_emoji(emoji)) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class PollAnswer(DictSerializationMixin): + poll_media: PollMedia = attrs.field(repr=False, converter=PollMedia.from_dict) + """The data of the answer.""" + answer_id: Optional[int] = attrs.field(repr=False, default=None) + """The ID of the answer. This is only returned for polls that have been given by Discord's API.""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class PollAnswerCount(DictSerializationMixin): + id: int = attrs.field(repr=False) + """The answer ID of the answer.""" + count: int = attrs.field(repr=False, default=0) + """The number of votes for this answer.""" + me_voted: bool = attrs.field(repr=False, default=False) + """Whether the current user voted for this answer.""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class PollResults(DictSerializationMixin): + is_finalized: bool = attrs.field(repr=False, default=False) + """Whether the votes have been precisely counted.""" + answer_counts: list[PollAnswerCount] = attrs.field(repr=False, factory=list, converter=PollAnswerCount.from_list) + """The counts for each answer.""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class Poll(DictSerializationMixin): + question: PollMedia = attrs.field(repr=False) + """The question of the poll. Only text media is supported.""" + answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list) + """Each of the answers available in the poll, up to 10.""" + expiry: Timestamp = attrs.field(repr=False, default=MISSING, converter=timestamp_converter) + """Number of hours the poll is open for, up to 7 days.""" + allow_multiselect: bool = attrs.field(repr=False, default=False, metadata=no_export_meta) + """Whether a user can select multiple answers.""" + layout_type: PollLayoutType = attrs.field(repr=False, default=PollLayoutType.DEFAULT, converter=PollLayoutType) + """The layout type of the poll.""" + results: Optional[PollResults] = attrs.field(repr=False, default=None, converter=optional(PollResults.from_dict)) + """The results of the poll, if the polls is finished.""" + + _duration: int = attrs.field(repr=False, default=0) + """How long, in hours, the poll will be open for (up to 7 days). This is only used when creating polls.""" + + @classmethod + def create( + cls, question: str, duration: int, *, allow_multiselect: bool = False, answers: Optional[list[PollMedia]] = None + ) -> Self: + """ + Create a Poll object for sending. + + Args: + question: The question of the poll. + duration: How long, in hours, the poll will be open for (up to 7 days). + allow_multiselect: Whether a user can select multiple answers. + answers: Each of the answers available in the poll, up to 10. + + Returns: + A Poll object. + + """ + if answers: + media_to_answers = [PollAnswer(poll_media=answer) for answer in answers] + else: + media_to_answers = [] + + return cls( + question=PollMedia(text=question), + duration=duration, + allow_multiselect=allow_multiselect, + answers=media_to_answers, + ) + + @answers.validator + def _answers_validation(self, attribute: str, value: Any) -> None: + if len(value) > 10: + raise ValueError("A poll can have at most 10 answers.") + + @_duration.validator + def _duration_validation(self, attribute: str, value: int) -> None: + if value < 0 or value > 168: + raise ValueError("The duration must be between 0 and 168 hours (7 days).") + + def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> None: + """ + Adds an answer to the poll. + + Args: + text: The text of the answer. + emoji: The emoji for the answer. + + """ + if not text and not emoji: + raise ValueError("Either text or emoji must be provided") + + self.answers.append(PollAnswer(poll_media=PollMedia.create(text=text, emoji=emoji))) + self._answers_validation("answers", self.answers) + + def to_dict(self) -> Dict[str, Any]: + data = super().to_dict() + + data["duration"] = self._duration + data.pop("_duration", None) + return data From 6f76907677dfaa82b89f2af25eda04c7d12ceb86 Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 16:27:10 -0400 Subject: [PATCH 02/16] refactor: use constants instead of hardcoded values --- interactions/models/discord/poll.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/interactions/models/discord/poll.py b/interactions/models/discord/poll.py index bc4dc1899..e3975ce13 100644 --- a/interactions/models/discord/poll.py +++ b/interactions/models/discord/poll.py @@ -3,7 +3,7 @@ import attrs -from interactions.client.const import MISSING +from interactions.client.const import MISSING, POLL_MAX_DURATION_HOURS, POLL_MAX_ANSWERS from interactions.client.utils.attr_converters import ( optional, timestamp_converter, @@ -124,13 +124,15 @@ def create( @answers.validator def _answers_validation(self, attribute: str, value: Any) -> None: - if len(value) > 10: - raise ValueError("A poll can have at most 10 answers.") + if len(value) > POLL_MAX_ANSWERS: + raise ValueError(f"A poll can have at most {POLL_MAX_ANSWERS} answers.") @_duration.validator def _duration_validation(self, attribute: str, value: int) -> None: - if value < 0 or value > 168: - raise ValueError("The duration must be between 0 and 168 hours (7 days).") + if value < 0 or value > POLL_MAX_DURATION_HOURS: + raise ValueError( + f"The duration must be between 0 and {POLL_MAX_DURATION_HOURS} hours ({POLL_MAX_DURATION_HOURS // 24} days)." + ) def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> None: """ From 5b3715b5d62c68d571dc32866795d2a0cbecea9b Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 16:33:41 -0400 Subject: [PATCH 03/16] feat: expose poll objects --- interactions/__init__.py | 14 ++++++++++++++ interactions/client/__init__.py | 4 ++++ interactions/models/__init__.py | 10 ++++++++++ interactions/models/discord/__init__.py | 6 ++++++ 4 files changed, 34 insertions(+) diff --git a/interactions/__init__.py b/interactions/__init__.py index a8a6905cf..c20006619 100644 --- a/interactions/__init__.py +++ b/interactions/__init__.py @@ -25,6 +25,8 @@ MentionPrefix, Missing, MISSING, + POLL_MAX_ANSWERS, + POLL_MAX_DURATION_HOURS, PREMIUM_GUILD_LIMITS, SELECT_MAX_NAME_LENGTH, SELECTS_MAX_OPTIONS, @@ -243,7 +245,12 @@ PartialEmojiConverter, PermissionOverwrite, Permissions, + Poll, + PollAnswer, + PollAnswerCount, PollLayoutType, + PollMedia, + PollResults, PremiumTier, PremiumType, process_allowed_mentions, @@ -595,7 +602,14 @@ "PartialEmojiConverter", "PermissionOverwrite", "Permissions", + "Poll", + "PollAnswer", + "PollAnswerCount", "PollLayoutType", + "POLL_MAX_ANSWERS", + "POLL_MAX_DURATION_HOURS", + "PollMedia", + "PollResults", "PREMIUM_GUILD_LIMITS", "PremiumTier", "PremiumType", diff --git a/interactions/client/__init__.py b/interactions/client/__init__.py index 96df96f0a..fea5b314c 100644 --- a/interactions/client/__init__.py +++ b/interactions/client/__init__.py @@ -29,6 +29,8 @@ MISSING, MENTION_PREFIX, PREMIUM_GUILD_LIMITS, + POLL_MAX_ANSWERS, + POLL_MAX_DURATION_HOURS, Absent, T, T_co, @@ -61,6 +63,8 @@ "EMBED_MAX_FIELDS", "EMBED_TOTAL_MAX", "EMBED_FIELD_VALUE_LENGTH", + "POLL_MAX_ANSWERS", + "POLL_MAX_DURATION_HOURS", "Singleton", "Sentinel", "GlobalScope", diff --git a/interactions/models/__init__.py b/interactions/models/__init__.py index c033bfc27..153ba642f 100644 --- a/interactions/models/__init__.py +++ b/interactions/models/__init__.py @@ -122,7 +122,12 @@ PartialEmoji, PermissionOverwrite, Permissions, + Poll, + PollAnswer, + PollAnswerCount, PollLayoutType, + PollMedia, + PollResults, PremiumTier, PremiumType, process_allowed_mentions, @@ -524,7 +529,12 @@ "PartialEmojiConverter", "PermissionOverwrite", "Permissions", + "Poll", + "PollAnswer", + "PollAnswerCount", "PollLayoutType", + "PollMedia", + "PollResults", "PremiumTier", "PremiumType", "process_allowed_mentions", diff --git a/interactions/models/discord/__init__.py b/interactions/models/discord/__init__.py index f668e75a1..c34b6499d 100644 --- a/interactions/models/discord/__init__.py +++ b/interactions/models/discord/__init__.py @@ -156,6 +156,7 @@ ) from .modal import InputText, Modal, ParagraphText, ShortText, TextStyles from .onboarding import Onboarding, OnboardingPrompt, OnboardingPromptOption +from .poll import PollMedia, PollAnswer, PollAnswerCount, PollResults, Poll from .reaction import Reaction, ReactionUsers from .role import Role from .scheduled_event import ScheduledEvent @@ -300,7 +301,12 @@ "PartialEmoji", "PermissionOverwrite", "Permissions", + "Poll", + "PollAnswer", + "PollAnswerCount", "PollLayoutType", + "PollMedia", + "PollResults", "PremiumTier", "PremiumType", "process_allowed_mentions", From 9561e568896b3a77ae2b9e2045da7f41b7e05a9a Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 16:49:55 -0400 Subject: [PATCH 04/16] fix: make converter optional --- interactions/models/discord/poll.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interactions/models/discord/poll.py b/interactions/models/discord/poll.py index e3975ce13..ddd5547ef 100644 --- a/interactions/models/discord/poll.py +++ b/interactions/models/discord/poll.py @@ -81,7 +81,7 @@ class Poll(DictSerializationMixin): """The question of the poll. Only text media is supported.""" answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list) """Each of the answers available in the poll, up to 10.""" - expiry: Timestamp = attrs.field(repr=False, default=MISSING, converter=timestamp_converter) + expiry: Timestamp = attrs.field(repr=False, default=MISSING, converter=optional(timestamp_converter)) """Number of hours the poll is open for, up to 7 days.""" allow_multiselect: bool = attrs.field(repr=False, default=False, metadata=no_export_meta) """Whether a user can select multiple answers.""" From 316cff0073d5787e40044c0b4bc168dd5a6f12b3 Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 16:50:51 -0400 Subject: [PATCH 05/16] feat: add polls to send functions --- interactions/client/mixins/send.py | 4 ++++ interactions/ext/hybrid_commands/context.py | 4 ++++ interactions/models/discord/message.py | 7 +++++++ interactions/models/discord/webhooks.py | 4 ++++ interactions/models/internal/context.py | 4 ++++ 5 files changed, 23 insertions(+) diff --git a/interactions/client/mixins/send.py b/interactions/client/mixins/send.py index 5776c78d5..7aeadc03a 100644 --- a/interactions/client/mixins/send.py +++ b/interactions/client/mixins/send.py @@ -10,6 +10,7 @@ from interactions.models.discord.components import BaseComponent from interactions.models.discord.embed import Embed from interactions.models.discord.message import AllowedMentions, Message, MessageReference + from interactions.models.discord.poll import Poll from interactions.models.discord.sticker import Sticker from interactions.models.discord.snowflake import Snowflake_Type @@ -49,6 +50,7 @@ async def send( delete_after: Optional[float] = None, nonce: Optional[str | int] = None, enforce_nonce: bool = False, + poll: "Optional[Poll | dict]" = None, **kwargs: Any, ) -> "Message": """ @@ -73,6 +75,7 @@ async def send( enforce_nonce: If enabled and nonce is present, it will be checked for uniqueness in the past few minutes. \ If another message was created by the same author with the same nonce, that message will be returned \ and no new message will be created. + poll: A poll. Returns: New message object that was sent. @@ -115,6 +118,7 @@ async def send( flags=flags, nonce=nonce, enforce_nonce=enforce_nonce, + poll=poll, **kwargs, ) diff --git a/interactions/ext/hybrid_commands/context.py b/interactions/ext/hybrid_commands/context.py index 195f83bca..70fdac6c7 100644 --- a/interactions/ext/hybrid_commands/context.py +++ b/interactions/ext/hybrid_commands/context.py @@ -23,6 +23,7 @@ Attachment, process_message_payload, TYPE_MESSAGEABLE_CHANNEL, + Poll, ) from interactions.models.discord.enums import ContextType from interactions.client.mixins.send import SendMixin @@ -309,6 +310,7 @@ async def send( suppress_embeds: bool = False, silent: bool = False, flags: Optional[Union[int, "MessageFlags"]] = None, + poll: "Optional[Poll | dict]" = None, delete_after: Optional[float] = None, ephemeral: bool = False, **kwargs: Any, @@ -330,6 +332,7 @@ async def send( suppress_embeds: Should embeds be suppressed on this send silent: Should this message be sent without triggering a notification. flags: Message flags to apply. + poll: A poll. delete_after: Delete message after this many seconds. ephemeral: Should this message be sent as ephemeral (hidden) - only works with interactions @@ -358,6 +361,7 @@ async def send( file=file, tts=tts, flags=flags, + poll=poll, delete_after=delete_after, pass_self_into_delete=bool(self._slash_ctx), **kwargs, diff --git a/interactions/models/discord/message.py b/interactions/models/discord/message.py index a214b8800..ffc769018 100644 --- a/interactions/models/discord/message.py +++ b/interactions/models/discord/message.py @@ -28,6 +28,7 @@ from interactions.models.discord.embed import process_embeds from interactions.models.discord.emoji import process_emoji_req_format from interactions.models.discord.file import UPLOADABLE_TYPE +from interactions.models.discord.poll import Poll from .base import DiscordObject from .enums import ( @@ -981,6 +982,7 @@ def process_message_payload( flags: Optional[Union[int, MessageFlags]] = None, nonce: Optional[str | int] = None, enforce_nonce: bool = False, + poll: Optional[Poll | dict] = None, **kwargs, ) -> dict: """ @@ -1000,6 +1002,7 @@ def process_message_payload( enforce_nonce: If enabled and nonce is present, it will be checked for uniqueness in the past few minutes. \ If another message was created by the same author with the same nonce, that message will be returned \ and no new message will be created. + poll: A poll. Returns: Dictionary @@ -1017,6 +1020,9 @@ def process_message_payload( if attachments: attachments = [attachment.to_dict() for attachment in attachments] + if isinstance(poll, Poll): + poll = poll.to_dict() + return dict_filter_none( { "content": content, @@ -1030,6 +1036,7 @@ def process_message_payload( "flags": flags, "nonce": nonce, "enforce_nonce": enforce_nonce, + "poll": poll, **kwargs, } ) diff --git a/interactions/models/discord/webhooks.py b/interactions/models/discord/webhooks.py index f474662fd..548eaec9f 100644 --- a/interactions/models/discord/webhooks.py +++ b/interactions/models/discord/webhooks.py @@ -27,6 +27,7 @@ Message, MessageReference, ) + from interactions.models.discord.poll import Poll from interactions.models.discord.sticker import Sticker __all__ = ("WebhookTypes", "Webhook") @@ -190,6 +191,7 @@ async def send( tts: bool = False, suppress_embeds: bool = False, flags: Optional[Union[int, "MessageFlags"]] = None, + poll: "Optional[Poll | dict]" = None, username: str | None = None, avatar_url: str | None = None, wait: bool = False, @@ -212,6 +214,7 @@ async def send( tts: Should this message use Text To Speech. suppress_embeds: Should embeds be suppressed on this send flags: Message flags to apply. + poll: A poll. username: The username to use avatar_url: The url of an image to use as the avatar wait: Waits for confirmation of delivery. Set this to True if you intend to edit the message @@ -241,6 +244,7 @@ async def send( reply_to=reply_to, tts=tts, flags=flags, + poll=poll, username=username, avatar_url=avatar_url, **kwargs, diff --git a/interactions/models/internal/context.py b/interactions/models/internal/context.py index 1f826627f..859161bb3 100644 --- a/interactions/models/internal/context.py +++ b/interactions/models/internal/context.py @@ -12,6 +12,7 @@ from interactions.client.const import get_logger, MISSING from interactions.models.discord.components import BaseComponent from interactions.models.discord.file import UPLOADABLE_TYPE +from interactions.models.discord.poll import Poll from interactions.models.discord.sticker import Sticker from interactions.models.discord.user import Member, User @@ -540,6 +541,7 @@ async def send( suppress_embeds: bool = False, silent: bool = False, flags: typing.Optional[typing.Union[int, "MessageFlags"]] = None, + poll: "typing.Optional[Poll | dict]" = None, delete_after: typing.Optional[float] = None, ephemeral: bool = False, **kwargs: typing.Any, @@ -561,6 +563,7 @@ async def send( suppress_embeds: Should embeds be suppressed on this send silent: Should this message be sent without triggering a notification. flags: Message flags to apply. + poll: A poll. delete_after: Delete message after this many seconds. ephemeral: Whether the response should be ephemeral @@ -589,6 +592,7 @@ async def send( file=file, tts=tts, flags=flags, + poll=poll, delete_after=delete_after, pass_self_into_delete=True, **kwargs, From 21f31fabfd3a917f568d756e3f2954d0e01cd274 Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 16:51:50 -0400 Subject: [PATCH 06/16] feat: allow chaining of add_answer --- interactions/models/discord/poll.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/interactions/models/discord/poll.py b/interactions/models/discord/poll.py index ddd5547ef..9c6e8dbd3 100644 --- a/interactions/models/discord/poll.py +++ b/interactions/models/discord/poll.py @@ -134,7 +134,7 @@ def _duration_validation(self, attribute: str, value: int) -> None: f"The duration must be between 0 and {POLL_MAX_DURATION_HOURS} hours ({POLL_MAX_DURATION_HOURS // 24} days)." ) - def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> None: + def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> Self: """ Adds an answer to the poll. @@ -148,6 +148,7 @@ def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEm self.answers.append(PollAnswer(poll_media=PollMedia.create(text=text, emoji=emoji))) self._answers_validation("answers", self.answers) + return self def to_dict(self) -> Dict[str, Any]: data = super().to_dict() From 4c6015cefdd312251245d7906b06bc5e75eec802 Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 17:00:09 -0400 Subject: [PATCH 07/16] feat: a couple of create changes --- interactions/models/discord/poll.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/interactions/models/discord/poll.py b/interactions/models/discord/poll.py index 9c6e8dbd3..e61a1f157 100644 --- a/interactions/models/discord/poll.py +++ b/interactions/models/discord/poll.py @@ -95,7 +95,12 @@ class Poll(DictSerializationMixin): @classmethod def create( - cls, question: str, duration: int, *, allow_multiselect: bool = False, answers: Optional[list[PollMedia]] = None + cls, + question: str, + *, + duration: int, + allow_multiselect: bool = False, + answers: Optional[list[PollMedia | str]] = None, ) -> Self: """ Create a Poll object for sending. @@ -111,7 +116,14 @@ def create( """ if answers: - media_to_answers = [PollAnswer(poll_media=answer) for answer in answers] + media_to_answers = [ + ( + PollAnswer(poll_media=answer) + if isinstance(answer, PollMedia) + else PollAnswer(poll_media=PollMedia.create(text=answer)) + ) + for answer in answers + ] else: media_to_answers = [] From decbecf9815a58590c09f255c63161ad7953e90a Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 17:01:35 -0400 Subject: [PATCH 08/16] docs: add poll docs --- docs/src/API Reference/API Reference/models/Discord/index.md | 1 + docs/src/API Reference/API Reference/models/Discord/poll.md | 1 + 2 files changed, 2 insertions(+) create mode 100644 docs/src/API Reference/API Reference/models/Discord/poll.md diff --git a/docs/src/API Reference/API Reference/models/Discord/index.md b/docs/src/API Reference/API Reference/models/Discord/index.md index 1b45617cd..b97fffa0e 100644 --- a/docs/src/API Reference/API Reference/models/Discord/index.md +++ b/docs/src/API Reference/API Reference/models/Discord/index.md @@ -22,6 +22,7 @@ search: - [Invite](invite) - [Message](message) - [Modals](modals) +- [Poll](poll) - [Reaction](reaction) - [Role](role) - [Scheduled event](scheduled_event) diff --git a/docs/src/API Reference/API Reference/models/Discord/poll.md b/docs/src/API Reference/API Reference/models/Discord/poll.md new file mode 100644 index 000000000..7f53c7c3d --- /dev/null +++ b/docs/src/API Reference/API Reference/models/Discord/poll.md @@ -0,0 +1 @@ +::: interactions.models.discord.poll From 0a652bac699f42654e700aff58e77fa47cc3bc0c Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 17:23:40 -0400 Subject: [PATCH 09/16] feat: add http methods for polls --- .../api/http/http_requests/messages.py | 57 +++++++++++++++++++ 1 file changed, 57 insertions(+) diff --git a/interactions/api/http/http_requests/messages.py b/interactions/api/http/http_requests/messages.py index a836e269f..012b7fc0b 100644 --- a/interactions/api/http/http_requests/messages.py +++ b/interactions/api/http/http_requests/messages.py @@ -3,6 +3,7 @@ import discord_typings from interactions.models.internal.protocols import CanRequest +from interactions.client.utils.serializer import dict_filter_none from ..route import Route __all__ = ("MessageRequests",) @@ -175,3 +176,59 @@ async def crosspost_message( ) ) return cast(discord_typings.MessageData, result) + + async def get_answer_voters( + self, + channel_id: "Snowflake_Type", + message_id: "Snowflake_Type", + answer_id: int, + after: "Snowflake_Type | None" = None, + limit: int = 25, + ) -> list[discord_typings.UserData]: + """ + Get a list of users that voted for this specific answer. + + Args: + channel_id: Channel the message is in + message_id: The message with the poll + answer_id: The answer to get voters for + after: Get messages after this user ID + limit: The max number of users to return (default 25, max 100) + + Returns: + list[discord_typings.UserData]: A list of users that voted for the answer + + """ + result = await self.request( + Route( + "GET", + "/channels/{channel_id}/messages/{message_id}/polls/{answer_id}/votes", + channel_id=channel_id, + message_id=message_id, + answer_id=answer_id, + ), + params=dict_filter_none({"after": after, "limit": limit}), + ) + return cast(list[discord_typings.UserData], result) + + async def end_poll(self, channel_id: "Snowflake_Type", message_id: "Snowflake_Type") -> discord_typings.MessageData: + """ + Ends a poll. Only can end polls from the current bot. + + Args: + channel_id: Channel the message is in + message_id: The message with the poll + + Returns: + message object + + """ + result = await self.request( + Route( + "POST", + "/channels/{channel_id}/messages/{message_id}/polls/end", + channel_id=channel_id, + message_id=message_id, + ) + ) + return cast(discord_typings.MessageData, result) From cd4f16f1d0fb71f289a30e9d6b3dc3b88b72a8fd Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 17:44:53 -0400 Subject: [PATCH 10/16] fix: oops --- interactions/api/http/http_requests/messages.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/interactions/api/http/http_requests/messages.py b/interactions/api/http/http_requests/messages.py index 012b7fc0b..8891df54b 100644 --- a/interactions/api/http/http_requests/messages.py +++ b/interactions/api/http/http_requests/messages.py @@ -1,4 +1,4 @@ -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, cast, TypedDict import discord_typings @@ -14,6 +14,10 @@ from interactions import UPLOADABLE_TYPE +class GetAnswerVotersData(TypedDict): + users: list[discord_typings.UserData] + + class MessageRequests(CanRequest): async def create_message( self, @@ -184,7 +188,7 @@ async def get_answer_voters( answer_id: int, after: "Snowflake_Type | None" = None, limit: int = 25, - ) -> list[discord_typings.UserData]: + ) -> GetAnswerVotersData: """ Get a list of users that voted for this specific answer. @@ -196,20 +200,20 @@ async def get_answer_voters( limit: The max number of users to return (default 25, max 100) Returns: - list[discord_typings.UserData]: A list of users that voted for the answer + GetAnswerVotersData: A response that has a list of users that voted for the answer """ result = await self.request( Route( "GET", - "/channels/{channel_id}/messages/{message_id}/polls/{answer_id}/votes", + "/channels/{channel_id}/polls/{message_id}/answers/{answer_id}", channel_id=channel_id, message_id=message_id, answer_id=answer_id, ), params=dict_filter_none({"after": after, "limit": limit}), ) - return cast(list[discord_typings.UserData], result) + return cast(GetAnswerVotersData, result) async def end_poll(self, channel_id: "Snowflake_Type", message_id: "Snowflake_Type") -> discord_typings.MessageData: """ @@ -226,7 +230,7 @@ async def end_poll(self, channel_id: "Snowflake_Type", message_id: "Snowflake_Ty result = await self.request( Route( "POST", - "/channels/{channel_id}/messages/{message_id}/polls/end", + "/channels/{channel_id}/polls/{message_id}/expire", channel_id=channel_id, message_id=message_id, ) From ae2b49e45088a951086385660b5b5347680515a9 Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Fri, 7 Jun 2024 17:58:40 -0400 Subject: [PATCH 11/16] feat: add methods to interact with http methods --- interactions/models/discord/message.py | 53 ++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/interactions/models/discord/message.py b/interactions/models/discord/message.py index ffc769018..67c343b4e 100644 --- a/interactions/models/discord/message.py +++ b/interactions/models/discord/message.py @@ -1,6 +1,7 @@ import asyncio import base64 import re +from collections import namedtuple from dataclasses import dataclass from typing import ( TYPE_CHECKING, @@ -29,6 +30,7 @@ from interactions.models.discord.emoji import process_emoji_req_format from interactions.models.discord.file import UPLOADABLE_TYPE from interactions.models.discord.poll import Poll +from interactions.models.misc.iterator import AsyncIterator from .base import DiscordObject from .enums import ( @@ -70,6 +72,35 @@ ) channel_mention = re.compile(r"<#(?P[0-9]{17,})>") + + +class PollAnswerVotersIterator(AsyncIterator): + def __init__( + self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None + ) -> None: + self.message: "Message" = message + self.answer_id = answer_id + self.after: Snowflake_Type | None = after + self._more: bool = True + super().__init__(limit) + + async def fetch(self) -> list["models.User"]: + if not self.last: + self.last = namedtuple("temp", "id") + self.last.id = self.after + + rcv = await self.message._client.http.get_answer_voters( + self.message._channel_id, + self.message.id, + self.answer_id, + limit=self.get_limit, + after=to_snowflake(self.last.id) if self.last.id else None, + ) + if not rcv: + raise asyncio.QueueEmpty + + users = [self.message._client.cache.place_user_data(user_data) for user_data in rcv["users"]] + return users @attrs.define(eq=False, order=False, hash=False, kw_only=True) @@ -408,6 +439,8 @@ class Message(BaseMessage): """Data showing the source of a crosspost, channel follow add, pin, or reply message""" flags: MessageFlags = attrs.field(repr=False, default=MessageFlags.NONE, converter=MessageFlags) """Message flags combined as a bitfield""" + poll: Optional[Poll] = attrs.field(repr=False, default=None, converter=optional_c(Poll.from_dict)) + """A poll.""" interaction_metadata: Optional[MessageInteractionMetadata] = attrs.field(repr=False, default=None) """Sent if the message is a response to an Interaction""" interaction: Optional["MessageInteraction"] = attrs.field(repr=False, default=None) @@ -644,6 +677,20 @@ def jump_url(self) -> str: def proto_url(self) -> str: """A URL like `jump_url` that uses protocols.""" return f"discord://-/channels/{self._guild_id or '@me'}/{self._channel_id}/{self.id}" + + def answer_voters( + self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None + ) -> PollAnswerVotersIterator: + """ + An async iterator for getting the voters for an answer in the poll this message has. + + Args: + answer_id: The answer to get voters for + after: Get messages after this user ID + limit: The max number of users to return (default 25, max 100) + + """ + return PollAnswerVotersIterator(self, answer_id, limit, before) async def edit( self, @@ -900,6 +947,12 @@ async def publish(self) -> None: """ await self._client.http.crosspost_message(self._channel_id, self.id) + async def end_poll(self) -> "Message": + """Ends the poll contained in this message.""" + message_data = await self._client.http.end_poll(self._channel_id, self.id) + if message_data: + return self._client.cache.place_message_data(message_data) + def process_allowed_mentions(allowed_mentions: Optional[Union[AllowedMentions, dict]]) -> Optional[dict]: """ From f13f1119e552ddaf6c9b0e405c1a2fd0ed4a792c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 7 Jun 2024 21:59:14 +0000 Subject: [PATCH 12/16] ci: correct from checks. --- interactions/models/discord/message.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/interactions/models/discord/message.py b/interactions/models/discord/message.py index 67c343b4e..a691de918 100644 --- a/interactions/models/discord/message.py +++ b/interactions/models/discord/message.py @@ -72,12 +72,10 @@ ) channel_mention = re.compile(r"<#(?P[0-9]{17,})>") - + class PollAnswerVotersIterator(AsyncIterator): - def __init__( - self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None - ) -> None: + def __init__(self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None) -> None: self.message: "Message" = message self.answer_id = answer_id self.after: Snowflake_Type | None = after @@ -677,10 +675,8 @@ def jump_url(self) -> str: def proto_url(self) -> str: """A URL like `jump_url` that uses protocols.""" return f"discord://-/channels/{self._guild_id or '@me'}/{self._channel_id}/{self.id}" - - def answer_voters( - self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None - ) -> PollAnswerVotersIterator: + + def answer_voters(self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None) -> PollAnswerVotersIterator: """ An async iterator for getting the voters for an answer in the poll this message has. From d742c7711a634e6f9db82239549f3cf0a9263e25 Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Sat, 8 Jun 2024 10:13:28 -0400 Subject: [PATCH 13/16] feat: add poll events --- interactions/api/events/__init__.py | 4 ++ interactions/api/events/discord.py | 69 +++++++++++++++++++ .../api/events/processors/message_events.py | 38 ++++++++++ 3 files changed, 111 insertions(+) diff --git a/interactions/api/events/__init__.py b/interactions/api/events/__init__.py index 70db34457..f208fb7c5 100644 --- a/interactions/api/events/__init__.py +++ b/interactions/api/events/__init__.py @@ -41,6 +41,8 @@ MessageCreate, MessageDelete, MessageDeleteBulk, + MessagePollVoteAdd, + MessagePollVoteRemove, MessageReactionAdd, MessageReactionRemove, MessageReactionRemoveAll, @@ -159,6 +161,8 @@ "MessageCreate", "MessageDelete", "MessageDeleteBulk", + "MessagePollVoteAdd", + "MessagePollVoteRemove", "MessageReactionAdd", "MessageReactionRemove", "MessageReactionRemoveAll", diff --git a/interactions/api/events/discord.py b/interactions/api/events/discord.py index bbd6f1f2a..181ec1eb0 100644 --- a/interactions/api/events/discord.py +++ b/interactions/api/events/discord.py @@ -72,6 +72,8 @@ async def an_event_handler(event: ChannelCreate): "MessageCreate", "MessageDelete", "MessageDeleteBulk", + "MessagePollVoteAdd", + "MessagePollVoteRemove", "MessageReactionAdd", "MessageReactionRemove", "MessageReactionRemoveAll", @@ -115,6 +117,7 @@ async def an_event_handler(event: ChannelCreate): from interactions.models.discord.entitlement import Entitlement from interactions.models.discord.guild import Guild, GuildIntegration from interactions.models.discord.message import Message + from interactions.models.discord.poll import Poll from interactions.models.discord.reaction import Reaction from interactions.models.discord.role import Role from interactions.models.discord.scheduled_event import ScheduledEvent @@ -588,6 +591,72 @@ class MessageReactionRemoveEmoji(MessageReactionRemoveAll): """The emoji that was removed""" +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class BaseMessagePollEvent(BaseEvent): + user_id: "Snowflake_Type" = attrs.field(repr=False) + """The ID of the user that voted""" + channel_id: "Snowflake_Type" = attrs.field(repr=False) + """The ID of the channel the poll is in""" + message_id: "Snowflake_Type" = attrs.field(repr=False) + """The ID of the message the poll is in""" + answer_id: int = attrs.field(repr=False) + """The ID of the answer the user voted for""" + guild_id: "Optional[Snowflake_Type]" = attrs.field(repr=False, default=None) + """The ID of the guild the poll is in""" + + def get_message(self) -> "Optional[Message]": + """Get the message object if it is cached""" + return self.client.cache.get_message(self.channel_id, self.message_id) + + def get_user(self) -> "Optional[User]": + """Get the user object if it is cached""" + return self.client.get_user(self.user_id) + + def get_channel(self) -> "Optional[TYPE_ALL_CHANNEL]": + """Get the channel object if it is cached""" + return self.client.get_channel(self.channel_id) + + def get_guild(self) -> "Optional[Guild]": + """Get the guild object if it is cached""" + return self.client.get_guild(self.guild_id) if self.guild_id is not None else None + + def get_poll(self) -> "Optional[Poll]": + """Get the poll object if it is cached""" + message = self.get_message() + return message.poll if message is not None else None + + async def fetch_message(self) -> "Message": + """Fetch the message the poll is in""" + return await self.client.cache.fetch_message(self.channel_id, self.message_id) + + async def fetch_user(self) -> "User": + """Fetch the user that voted""" + return await self.client.fetch_user(self.user_id) + + async def fetch_channel(self) -> "TYPE_ALL_CHANNEL": + """Fetch the channel the poll is in""" + return await self.client.fetch_channel(self.channel_id) + + async def fetch_guild(self) -> "Optional[Guild]": + """Fetch the guild the poll is in""" + return await self.client.fetch_guild(self.guild_id) if self.guild_id is not None else None + + async def fetch_poll(self) -> "Poll": + """Fetch the poll object""" + message = await self.fetch_message() + return message.poll + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MessagePollVoteAdd(BaseMessagePollEvent): + """Dispatched when a user votes in a poll""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MessagePollVoteRemove(BaseMessagePollEvent): + """Dispatched when a user remotes a votes in a poll""" + + @attrs.define(eq=False, order=False, hash=False, kw_only=False) class PresenceUpdate(BaseEvent): """A user's presence has changed.""" diff --git a/interactions/api/events/processors/message_events.py b/interactions/api/events/processors/message_events.py index 00314464f..74a10cbde 100644 --- a/interactions/api/events/processors/message_events.py +++ b/interactions/api/events/processors/message_events.py @@ -83,3 +83,41 @@ async def _on_raw_message_delete_bulk(self, event: "RawGatewayEvent") -> None: event.data.get("ids"), ) ) + + @Processor.define() + async def _on_raw_message_poll_vote_add(self, event: "RawGatewayEvent") -> None: + """ + Process raw message poll vote add event and dispatch a processed poll vote add event. + + Args: + event: raw poll vote add event + + """ + self.dispatch( + events.MessagePollVoteAdd( + event.data.get("guild_id", None), + event.data["channel_id"], + event.data["message_id"], + event.data["user_id"], + event.data["option"], + ) + ) + + @Processor.define() + async def _on_raw_message_poll_vote_remove(self, event: "RawGatewayEvent") -> None: + """ + Process raw message poll vote remove event and dispatch a processed poll vote remove event. + + Args: + event: raw poll vote remove event + + """ + self.dispatch( + events.MessagePollVoteRemove( + event.data.get("guild_id", None), + event.data["channel_id"], + event.data["message_id"], + event.data["user_id"], + event.data["option"], + ) + ) From 582c9bd82a4985098732fa889bb03d47d6826a5f Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Sat, 8 Jun 2024 10:14:41 -0400 Subject: [PATCH 14/16] feat: add send polls permission --- interactions/models/discord/enums.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/interactions/models/discord/enums.py b/interactions/models/discord/enums.py index 166f0fef7..ea0b66940 100644 --- a/interactions/models/discord/enums.py +++ b/interactions/models/discord/enums.py @@ -581,6 +581,8 @@ class Permissions(DiscordIntFlag): # type: ignore """Allows the usage of custom sounds from other servers""" SEND_VOICE_MESSAGES = 1 << 46 """Allows for sending audio messages""" + SEND_POLLS = 1 << 49 + """Allows sending polls""" # Shortcuts/grouping/aliases REQUIRES_MFA = ( From 868d1263228d82f680c3040e34ffd0eb47f49daf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 8 Jun 2024 14:42:31 +0000 Subject: [PATCH 15/16] ci: correct from checks. --- interactions/models/discord/message.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/interactions/models/discord/message.py b/interactions/models/discord/message.py index a691de918..8f664c3ad 100644 --- a/interactions/models/discord/message.py +++ b/interactions/models/discord/message.py @@ -75,7 +75,9 @@ class PollAnswerVotersIterator(AsyncIterator): - def __init__(self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None) -> None: + def __init__( + self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None + ) -> None: self.message: "Message" = message self.answer_id = answer_id self.after: Snowflake_Type | None = after @@ -676,7 +678,9 @@ def proto_url(self) -> str: """A URL like `jump_url` that uses protocols.""" return f"discord://-/channels/{self._guild_id or '@me'}/{self._channel_id}/{self.id}" - def answer_voters(self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None) -> PollAnswerVotersIterator: + def answer_voters( + self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None + ) -> PollAnswerVotersIterator: """ An async iterator for getting the voters for an answer in the poll this message has. From 4ac0299491b6efa1bd55c27fac29498c3d2b644a Mon Sep 17 00:00:00 2001 From: AstreaTSS <25420078+AstreaTSS@users.noreply.github.com> Date: Sun, 9 Jun 2024 18:34:17 -0400 Subject: [PATCH 16/16] docs: clarify weirdness of text property --- interactions/models/discord/poll.py | 16 +++++++++++++++- 1 file changed, 15 insertions(+), 1 deletion(-) diff --git a/interactions/models/discord/poll.py b/interactions/models/discord/poll.py index e61a1f157..1656d7e23 100644 --- a/interactions/models/discord/poll.py +++ b/interactions/models/discord/poll.py @@ -26,7 +26,13 @@ @attrs.define(eq=False, order=False, hash=False, kw_only=True) class PollMedia(DictSerializationMixin): text: Optional[str] = attrs.field(repr=False, default=None) - """The text of the field.""" + """ + The text of the field. + + !!! warning + While `text` is *marked* as optional, it is *currently required* by Discord's API to make polls. + According to Discord, this may change to be actually optional in the future. + """ emoji: Optional[PartialEmoji] = attrs.field(repr=False, default=None, converter=optional(PartialEmoji.from_dict)) """The emoji of the field.""" @@ -35,6 +41,10 @@ def create(cls, *, text: Optional[str] = None, emoji: Optional[Union[PartialEmoj """ Create a PollMedia object, used for questions and answers for polls. + !!! warning + While `text` is *marked* as optional, it is *currently required* by Discord's API to make polls. + According to Discord, this may change to be actually optional in the future. + Args: text: The text of the field. emoji: The emoji of the field. @@ -150,6 +160,10 @@ def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEm """ Adds an answer to the poll. + !!! warning + While `text` is *marked* as optional, it is *currently required* by Discord's API to make polls. + According to Discord, this may change to be actually optional in the future. + Args: text: The text of the answer. emoji: The emoji for the answer.