diff --git a/docs/src/API Reference/API Reference/models/Discord/index.md b/docs/src/API Reference/API Reference/models/Discord/index.md index 1b45617cd..b97fffa0e 100644 --- a/docs/src/API Reference/API Reference/models/Discord/index.md +++ b/docs/src/API Reference/API Reference/models/Discord/index.md @@ -22,6 +22,7 @@ search: - [Invite](invite) - [Message](message) - [Modals](modals) +- [Poll](poll) - [Reaction](reaction) - [Role](role) - [Scheduled event](scheduled_event) diff --git a/docs/src/API Reference/API Reference/models/Discord/poll.md b/docs/src/API Reference/API Reference/models/Discord/poll.md new file mode 100644 index 000000000..7f53c7c3d --- /dev/null +++ b/docs/src/API Reference/API Reference/models/Discord/poll.md @@ -0,0 +1 @@ +::: interactions.models.discord.poll diff --git a/interactions/__init__.py b/interactions/__init__.py index 0d7ee034d..c20006619 100644 --- a/interactions/__init__.py +++ b/interactions/__init__.py @@ -25,6 +25,8 @@ MentionPrefix, Missing, MISSING, + POLL_MAX_ANSWERS, + POLL_MAX_DURATION_HOURS, PREMIUM_GUILD_LIMITS, SELECT_MAX_NAME_LENGTH, SELECTS_MAX_OPTIONS, @@ -243,6 +245,12 @@ PartialEmojiConverter, PermissionOverwrite, Permissions, + Poll, + PollAnswer, + PollAnswerCount, + PollLayoutType, + PollMedia, + PollResults, PremiumTier, PremiumType, process_allowed_mentions, @@ -594,6 +602,14 @@ "PartialEmojiConverter", "PermissionOverwrite", "Permissions", + "Poll", + "PollAnswer", + "PollAnswerCount", + "PollLayoutType", + "POLL_MAX_ANSWERS", + "POLL_MAX_DURATION_HOURS", + "PollMedia", + "PollResults", "PREMIUM_GUILD_LIMITS", "PremiumTier", "PremiumType", diff --git a/interactions/api/events/__init__.py b/interactions/api/events/__init__.py index 70db34457..f208fb7c5 100644 --- a/interactions/api/events/__init__.py +++ b/interactions/api/events/__init__.py @@ -41,6 +41,8 @@ MessageCreate, MessageDelete, MessageDeleteBulk, + MessagePollVoteAdd, + MessagePollVoteRemove, MessageReactionAdd, MessageReactionRemove, MessageReactionRemoveAll, @@ -159,6 +161,8 @@ "MessageCreate", "MessageDelete", "MessageDeleteBulk", + "MessagePollVoteAdd", + "MessagePollVoteRemove", "MessageReactionAdd", "MessageReactionRemove", "MessageReactionRemoveAll", diff --git a/interactions/api/events/discord.py b/interactions/api/events/discord.py index bbd6f1f2a..181ec1eb0 100644 --- a/interactions/api/events/discord.py +++ b/interactions/api/events/discord.py @@ -72,6 +72,8 @@ async def an_event_handler(event: ChannelCreate): "MessageCreate", "MessageDelete", "MessageDeleteBulk", + "MessagePollVoteAdd", + "MessagePollVoteRemove", "MessageReactionAdd", "MessageReactionRemove", "MessageReactionRemoveAll", @@ -115,6 +117,7 @@ async def an_event_handler(event: ChannelCreate): from interactions.models.discord.entitlement import Entitlement from interactions.models.discord.guild import Guild, GuildIntegration from interactions.models.discord.message import Message + from interactions.models.discord.poll import Poll from interactions.models.discord.reaction import Reaction from interactions.models.discord.role import Role from interactions.models.discord.scheduled_event import ScheduledEvent @@ -588,6 +591,72 @@ class MessageReactionRemoveEmoji(MessageReactionRemoveAll): """The emoji that was removed""" +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class BaseMessagePollEvent(BaseEvent): + user_id: "Snowflake_Type" = attrs.field(repr=False) + """The ID of the user that voted""" + channel_id: "Snowflake_Type" = attrs.field(repr=False) + """The ID of the channel the poll is in""" + message_id: "Snowflake_Type" = attrs.field(repr=False) + """The ID of the message the poll is in""" + answer_id: int = attrs.field(repr=False) + """The ID of the answer the user voted for""" + guild_id: "Optional[Snowflake_Type]" = attrs.field(repr=False, default=None) + """The ID of the guild the poll is in""" + + def get_message(self) -> "Optional[Message]": + """Get the message object if it is cached""" + return self.client.cache.get_message(self.channel_id, self.message_id) + + def get_user(self) -> "Optional[User]": + """Get the user object if it is cached""" + return self.client.get_user(self.user_id) + + def get_channel(self) -> "Optional[TYPE_ALL_CHANNEL]": + """Get the channel object if it is cached""" + return self.client.get_channel(self.channel_id) + + def get_guild(self) -> "Optional[Guild]": + """Get the guild object if it is cached""" + return self.client.get_guild(self.guild_id) if self.guild_id is not None else None + + def get_poll(self) -> "Optional[Poll]": + """Get the poll object if it is cached""" + message = self.get_message() + return message.poll if message is not None else None + + async def fetch_message(self) -> "Message": + """Fetch the message the poll is in""" + return await self.client.cache.fetch_message(self.channel_id, self.message_id) + + async def fetch_user(self) -> "User": + """Fetch the user that voted""" + return await self.client.fetch_user(self.user_id) + + async def fetch_channel(self) -> "TYPE_ALL_CHANNEL": + """Fetch the channel the poll is in""" + return await self.client.fetch_channel(self.channel_id) + + async def fetch_guild(self) -> "Optional[Guild]": + """Fetch the guild the poll is in""" + return await self.client.fetch_guild(self.guild_id) if self.guild_id is not None else None + + async def fetch_poll(self) -> "Poll": + """Fetch the poll object""" + message = await self.fetch_message() + return message.poll + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MessagePollVoteAdd(BaseMessagePollEvent): + """Dispatched when a user votes in a poll""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=False) +class MessagePollVoteRemove(BaseMessagePollEvent): + """Dispatched when a user remotes a votes in a poll""" + + @attrs.define(eq=False, order=False, hash=False, kw_only=False) class PresenceUpdate(BaseEvent): """A user's presence has changed.""" diff --git a/interactions/api/events/processors/message_events.py b/interactions/api/events/processors/message_events.py index 00314464f..74a10cbde 100644 --- a/interactions/api/events/processors/message_events.py +++ b/interactions/api/events/processors/message_events.py @@ -83,3 +83,41 @@ async def _on_raw_message_delete_bulk(self, event: "RawGatewayEvent") -> None: event.data.get("ids"), ) ) + + @Processor.define() + async def _on_raw_message_poll_vote_add(self, event: "RawGatewayEvent") -> None: + """ + Process raw message poll vote add event and dispatch a processed poll vote add event. + + Args: + event: raw poll vote add event + + """ + self.dispatch( + events.MessagePollVoteAdd( + event.data.get("guild_id", None), + event.data["channel_id"], + event.data["message_id"], + event.data["user_id"], + event.data["option"], + ) + ) + + @Processor.define() + async def _on_raw_message_poll_vote_remove(self, event: "RawGatewayEvent") -> None: + """ + Process raw message poll vote remove event and dispatch a processed poll vote remove event. + + Args: + event: raw poll vote remove event + + """ + self.dispatch( + events.MessagePollVoteRemove( + event.data.get("guild_id", None), + event.data["channel_id"], + event.data["message_id"], + event.data["user_id"], + event.data["option"], + ) + ) diff --git a/interactions/api/http/http_requests/messages.py b/interactions/api/http/http_requests/messages.py index a836e269f..8891df54b 100644 --- a/interactions/api/http/http_requests/messages.py +++ b/interactions/api/http/http_requests/messages.py @@ -1,8 +1,9 @@ -from typing import TYPE_CHECKING, cast +from typing import TYPE_CHECKING, cast, TypedDict import discord_typings from interactions.models.internal.protocols import CanRequest +from interactions.client.utils.serializer import dict_filter_none from ..route import Route __all__ = ("MessageRequests",) @@ -13,6 +14,10 @@ from interactions import UPLOADABLE_TYPE +class GetAnswerVotersData(TypedDict): + users: list[discord_typings.UserData] + + class MessageRequests(CanRequest): async def create_message( self, @@ -175,3 +180,59 @@ async def crosspost_message( ) ) return cast(discord_typings.MessageData, result) + + async def get_answer_voters( + self, + channel_id: "Snowflake_Type", + message_id: "Snowflake_Type", + answer_id: int, + after: "Snowflake_Type | None" = None, + limit: int = 25, + ) -> GetAnswerVotersData: + """ + Get a list of users that voted for this specific answer. + + Args: + channel_id: Channel the message is in + message_id: The message with the poll + answer_id: The answer to get voters for + after: Get messages after this user ID + limit: The max number of users to return (default 25, max 100) + + Returns: + GetAnswerVotersData: A response that has a list of users that voted for the answer + + """ + result = await self.request( + Route( + "GET", + "/channels/{channel_id}/polls/{message_id}/answers/{answer_id}", + channel_id=channel_id, + message_id=message_id, + answer_id=answer_id, + ), + params=dict_filter_none({"after": after, "limit": limit}), + ) + return cast(GetAnswerVotersData, result) + + async def end_poll(self, channel_id: "Snowflake_Type", message_id: "Snowflake_Type") -> discord_typings.MessageData: + """ + Ends a poll. Only can end polls from the current bot. + + Args: + channel_id: Channel the message is in + message_id: The message with the poll + + Returns: + message object + + """ + result = await self.request( + Route( + "POST", + "/channels/{channel_id}/polls/{message_id}/expire", + channel_id=channel_id, + message_id=message_id, + ) + ) + return cast(discord_typings.MessageData, result) diff --git a/interactions/client/__init__.py b/interactions/client/__init__.py index 96df96f0a..fea5b314c 100644 --- a/interactions/client/__init__.py +++ b/interactions/client/__init__.py @@ -29,6 +29,8 @@ MISSING, MENTION_PREFIX, PREMIUM_GUILD_LIMITS, + POLL_MAX_ANSWERS, + POLL_MAX_DURATION_HOURS, Absent, T, T_co, @@ -61,6 +63,8 @@ "EMBED_MAX_FIELDS", "EMBED_TOTAL_MAX", "EMBED_FIELD_VALUE_LENGTH", + "POLL_MAX_ANSWERS", + "POLL_MAX_DURATION_HOURS", "Singleton", "Sentinel", "GlobalScope", diff --git a/interactions/client/const.py b/interactions/client/const.py index d1b4f04e7..a1e525a55 100644 --- a/interactions/client/const.py +++ b/interactions/client/const.py @@ -84,6 +84,8 @@ "NON_RESUMABLE_WEBSOCKET_CLOSE_CODES", "CLIENT_FEATURE_FLAGS", "has_client_feature", + "POLL_MAX_ANSWERS", + "POLL_MAX_DURATION_HOURS", ) _ver_info = sys.version_info @@ -130,6 +132,9 @@ def get_logger() -> logging.Logger: EMBED_TOTAL_MAX = 6000 EMBED_FIELD_VALUE_LENGTH = 1024 +POLL_MAX_ANSWERS = 10 +POLL_MAX_DURATION_HOURS = 168 + class Singleton(type): _instances: ClassVar[dict] = {} diff --git a/interactions/client/mixins/send.py b/interactions/client/mixins/send.py index 5776c78d5..7aeadc03a 100644 --- a/interactions/client/mixins/send.py +++ b/interactions/client/mixins/send.py @@ -10,6 +10,7 @@ from interactions.models.discord.components import BaseComponent from interactions.models.discord.embed import Embed from interactions.models.discord.message import AllowedMentions, Message, MessageReference + from interactions.models.discord.poll import Poll from interactions.models.discord.sticker import Sticker from interactions.models.discord.snowflake import Snowflake_Type @@ -49,6 +50,7 @@ async def send( delete_after: Optional[float] = None, nonce: Optional[str | int] = None, enforce_nonce: bool = False, + poll: "Optional[Poll | dict]" = None, **kwargs: Any, ) -> "Message": """ @@ -73,6 +75,7 @@ async def send( enforce_nonce: If enabled and nonce is present, it will be checked for uniqueness in the past few minutes. \ If another message was created by the same author with the same nonce, that message will be returned \ and no new message will be created. + poll: A poll. Returns: New message object that was sent. @@ -115,6 +118,7 @@ async def send( flags=flags, nonce=nonce, enforce_nonce=enforce_nonce, + poll=poll, **kwargs, ) diff --git a/interactions/ext/hybrid_commands/context.py b/interactions/ext/hybrid_commands/context.py index 195f83bca..70fdac6c7 100644 --- a/interactions/ext/hybrid_commands/context.py +++ b/interactions/ext/hybrid_commands/context.py @@ -23,6 +23,7 @@ Attachment, process_message_payload, TYPE_MESSAGEABLE_CHANNEL, + Poll, ) from interactions.models.discord.enums import ContextType from interactions.client.mixins.send import SendMixin @@ -309,6 +310,7 @@ async def send( suppress_embeds: bool = False, silent: bool = False, flags: Optional[Union[int, "MessageFlags"]] = None, + poll: "Optional[Poll | dict]" = None, delete_after: Optional[float] = None, ephemeral: bool = False, **kwargs: Any, @@ -330,6 +332,7 @@ async def send( suppress_embeds: Should embeds be suppressed on this send silent: Should this message be sent without triggering a notification. flags: Message flags to apply. + poll: A poll. delete_after: Delete message after this many seconds. ephemeral: Should this message be sent as ephemeral (hidden) - only works with interactions @@ -358,6 +361,7 @@ async def send( file=file, tts=tts, flags=flags, + poll=poll, delete_after=delete_after, pass_self_into_delete=bool(self._slash_ctx), **kwargs, diff --git a/interactions/models/__init__.py b/interactions/models/__init__.py index 6cd71f547..153ba642f 100644 --- a/interactions/models/__init__.py +++ b/interactions/models/__init__.py @@ -122,6 +122,12 @@ PartialEmoji, PermissionOverwrite, Permissions, + Poll, + PollAnswer, + PollAnswerCount, + PollLayoutType, + PollMedia, + PollResults, PremiumTier, PremiumType, process_allowed_mentions, @@ -523,6 +529,12 @@ "PartialEmojiConverter", "PermissionOverwrite", "Permissions", + "Poll", + "PollAnswer", + "PollAnswerCount", + "PollLayoutType", + "PollMedia", + "PollResults", "PremiumTier", "PremiumType", "process_allowed_mentions", diff --git a/interactions/models/discord/__init__.py b/interactions/models/discord/__init__.py index 91dec1218..c34b6499d 100644 --- a/interactions/models/discord/__init__.py +++ b/interactions/models/discord/__init__.py @@ -104,6 +104,7 @@ OnboardingPromptType, OverwriteType, Permissions, + PollLayoutType, PremiumTier, PremiumType, ScheduledEventPrivacyLevel, @@ -155,6 +156,7 @@ ) from .modal import InputText, Modal, ParagraphText, ShortText, TextStyles from .onboarding import Onboarding, OnboardingPrompt, OnboardingPromptOption +from .poll import PollMedia, PollAnswer, PollAnswerCount, PollResults, Poll from .reaction import Reaction, ReactionUsers from .role import Role from .scheduled_event import ScheduledEvent @@ -299,6 +301,12 @@ "PartialEmoji", "PermissionOverwrite", "Permissions", + "Poll", + "PollAnswer", + "PollAnswerCount", + "PollLayoutType", + "PollMedia", + "PollResults", "PremiumTier", "PremiumType", "process_allowed_mentions", diff --git a/interactions/models/discord/enums.py b/interactions/models/discord/enums.py index 0c25b4eeb..ea0b66940 100644 --- a/interactions/models/discord/enums.py +++ b/interactions/models/discord/enums.py @@ -38,6 +38,7 @@ "OnboardingPromptType", "OverwriteType", "Permissions", + "PollLayoutType", "PremiumTier", "PremiumType", "ScheduledEventPrivacyLevel", @@ -580,6 +581,8 @@ class Permissions(DiscordIntFlag): # type: ignore """Allows the usage of custom sounds from other servers""" SEND_VOICE_MESSAGES = 1 << 46 """Allows for sending audio messages""" + SEND_POLLS = 1 << 49 + """Allows sending polls""" # Shortcuts/grouping/aliases REQUIRES_MFA = ( @@ -1134,3 +1137,9 @@ class EntitlementType(CursedIntEnum): """Entitlement was claimed by user for free as a Nitro Subscriber""" APPLICATION_SUBSCRIPTION = 8 """Entitlement was purchased as an app subscription""" + + +class PollLayoutType(CursedIntEnum): + """The layout of a poll.""" + + DEFAULT = 1 diff --git a/interactions/models/discord/message.py b/interactions/models/discord/message.py index a214b8800..8f664c3ad 100644 --- a/interactions/models/discord/message.py +++ b/interactions/models/discord/message.py @@ -1,6 +1,7 @@ import asyncio import base64 import re +from collections import namedtuple from dataclasses import dataclass from typing import ( TYPE_CHECKING, @@ -28,6 +29,8 @@ from interactions.models.discord.embed import process_embeds from interactions.models.discord.emoji import process_emoji_req_format from interactions.models.discord.file import UPLOADABLE_TYPE +from interactions.models.discord.poll import Poll +from interactions.models.misc.iterator import AsyncIterator from .base import DiscordObject from .enums import ( @@ -71,6 +74,35 @@ channel_mention = re.compile(r"<#(?P[0-9]{17,})>") +class PollAnswerVotersIterator(AsyncIterator): + def __init__( + self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None + ) -> None: + self.message: "Message" = message + self.answer_id = answer_id + self.after: Snowflake_Type | None = after + self._more: bool = True + super().__init__(limit) + + async def fetch(self) -> list["models.User"]: + if not self.last: + self.last = namedtuple("temp", "id") + self.last.id = self.after + + rcv = await self.message._client.http.get_answer_voters( + self.message._channel_id, + self.message.id, + self.answer_id, + limit=self.get_limit, + after=to_snowflake(self.last.id) if self.last.id else None, + ) + if not rcv: + raise asyncio.QueueEmpty + + users = [self.message._client.cache.place_user_data(user_data) for user_data in rcv["users"]] + return users + + @attrs.define(eq=False, order=False, hash=False, kw_only=True) class Attachment(DiscordObject): filename: str = attrs.field( @@ -407,6 +439,8 @@ class Message(BaseMessage): """Data showing the source of a crosspost, channel follow add, pin, or reply message""" flags: MessageFlags = attrs.field(repr=False, default=MessageFlags.NONE, converter=MessageFlags) """Message flags combined as a bitfield""" + poll: Optional[Poll] = attrs.field(repr=False, default=None, converter=optional_c(Poll.from_dict)) + """A poll.""" interaction_metadata: Optional[MessageInteractionMetadata] = attrs.field(repr=False, default=None) """Sent if the message is a response to an Interaction""" interaction: Optional["MessageInteraction"] = attrs.field(repr=False, default=None) @@ -644,6 +678,20 @@ def proto_url(self) -> str: """A URL like `jump_url` that uses protocols.""" return f"discord://-/channels/{self._guild_id or '@me'}/{self._channel_id}/{self.id}" + def answer_voters( + self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None + ) -> PollAnswerVotersIterator: + """ + An async iterator for getting the voters for an answer in the poll this message has. + + Args: + answer_id: The answer to get voters for + after: Get messages after this user ID + limit: The max number of users to return (default 25, max 100) + + """ + return PollAnswerVotersIterator(self, answer_id, limit, before) + async def edit( self, *, @@ -899,6 +947,12 @@ async def publish(self) -> None: """ await self._client.http.crosspost_message(self._channel_id, self.id) + async def end_poll(self) -> "Message": + """Ends the poll contained in this message.""" + message_data = await self._client.http.end_poll(self._channel_id, self.id) + if message_data: + return self._client.cache.place_message_data(message_data) + def process_allowed_mentions(allowed_mentions: Optional[Union[AllowedMentions, dict]]) -> Optional[dict]: """ @@ -981,6 +1035,7 @@ def process_message_payload( flags: Optional[Union[int, MessageFlags]] = None, nonce: Optional[str | int] = None, enforce_nonce: bool = False, + poll: Optional[Poll | dict] = None, **kwargs, ) -> dict: """ @@ -1000,6 +1055,7 @@ def process_message_payload( enforce_nonce: If enabled and nonce is present, it will be checked for uniqueness in the past few minutes. \ If another message was created by the same author with the same nonce, that message will be returned \ and no new message will be created. + poll: A poll. Returns: Dictionary @@ -1017,6 +1073,9 @@ def process_message_payload( if attachments: attachments = [attachment.to_dict() for attachment in attachments] + if isinstance(poll, Poll): + poll = poll.to_dict() + return dict_filter_none( { "content": content, @@ -1030,6 +1089,7 @@ def process_message_payload( "flags": flags, "nonce": nonce, "enforce_nonce": enforce_nonce, + "poll": poll, **kwargs, } ) diff --git a/interactions/models/discord/poll.py b/interactions/models/discord/poll.py new file mode 100644 index 000000000..1656d7e23 --- /dev/null +++ b/interactions/models/discord/poll.py @@ -0,0 +1,184 @@ +from typing import Optional, Union, Dict, Any +from typing_extensions import Self + +import attrs + +from interactions.client.const import MISSING, POLL_MAX_DURATION_HOURS, POLL_MAX_ANSWERS +from interactions.client.utils.attr_converters import ( + optional, + timestamp_converter, +) +from interactions.client.mixins.serialization import DictSerializationMixin +from interactions.client.utils.serializer import no_export_meta +from interactions.models.discord.emoji import PartialEmoji, process_emoji +from interactions.models.discord.enums import PollLayoutType +from interactions.models.discord.timestamp import Timestamp + +__all__ = ( + "PollMedia", + "PollAnswer", + "PollAnswerCount", + "PollResults", + "Poll", +) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class PollMedia(DictSerializationMixin): + text: Optional[str] = attrs.field(repr=False, default=None) + """ + The text of the field. + + !!! warning + While `text` is *marked* as optional, it is *currently required* by Discord's API to make polls. + According to Discord, this may change to be actually optional in the future. + """ + emoji: Optional[PartialEmoji] = attrs.field(repr=False, default=None, converter=optional(PartialEmoji.from_dict)) + """The emoji of the field.""" + + @classmethod + def create(cls, *, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> Self: + """ + Create a PollMedia object, used for questions and answers for polls. + + !!! warning + While `text` is *marked* as optional, it is *currently required* by Discord's API to make polls. + According to Discord, this may change to be actually optional in the future. + + Args: + text: The text of the field. + emoji: The emoji of the field. + + Returns: + A PollMedia object. + + """ + if not text and not emoji: + raise ValueError("Either text or emoji must be provided.") + + return cls(text=text, emoji=process_emoji(emoji)) + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class PollAnswer(DictSerializationMixin): + poll_media: PollMedia = attrs.field(repr=False, converter=PollMedia.from_dict) + """The data of the answer.""" + answer_id: Optional[int] = attrs.field(repr=False, default=None) + """The ID of the answer. This is only returned for polls that have been given by Discord's API.""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class PollAnswerCount(DictSerializationMixin): + id: int = attrs.field(repr=False) + """The answer ID of the answer.""" + count: int = attrs.field(repr=False, default=0) + """The number of votes for this answer.""" + me_voted: bool = attrs.field(repr=False, default=False) + """Whether the current user voted for this answer.""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class PollResults(DictSerializationMixin): + is_finalized: bool = attrs.field(repr=False, default=False) + """Whether the votes have been precisely counted.""" + answer_counts: list[PollAnswerCount] = attrs.field(repr=False, factory=list, converter=PollAnswerCount.from_list) + """The counts for each answer.""" + + +@attrs.define(eq=False, order=False, hash=False, kw_only=True) +class Poll(DictSerializationMixin): + question: PollMedia = attrs.field(repr=False) + """The question of the poll. Only text media is supported.""" + answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list) + """Each of the answers available in the poll, up to 10.""" + expiry: Timestamp = attrs.field(repr=False, default=MISSING, converter=optional(timestamp_converter)) + """Number of hours the poll is open for, up to 7 days.""" + allow_multiselect: bool = attrs.field(repr=False, default=False, metadata=no_export_meta) + """Whether a user can select multiple answers.""" + layout_type: PollLayoutType = attrs.field(repr=False, default=PollLayoutType.DEFAULT, converter=PollLayoutType) + """The layout type of the poll.""" + results: Optional[PollResults] = attrs.field(repr=False, default=None, converter=optional(PollResults.from_dict)) + """The results of the poll, if the polls is finished.""" + + _duration: int = attrs.field(repr=False, default=0) + """How long, in hours, the poll will be open for (up to 7 days). This is only used when creating polls.""" + + @classmethod + def create( + cls, + question: str, + *, + duration: int, + allow_multiselect: bool = False, + answers: Optional[list[PollMedia | str]] = None, + ) -> Self: + """ + Create a Poll object for sending. + + Args: + question: The question of the poll. + duration: How long, in hours, the poll will be open for (up to 7 days). + allow_multiselect: Whether a user can select multiple answers. + answers: Each of the answers available in the poll, up to 10. + + Returns: + A Poll object. + + """ + if answers: + media_to_answers = [ + ( + PollAnswer(poll_media=answer) + if isinstance(answer, PollMedia) + else PollAnswer(poll_media=PollMedia.create(text=answer)) + ) + for answer in answers + ] + else: + media_to_answers = [] + + return cls( + question=PollMedia(text=question), + duration=duration, + allow_multiselect=allow_multiselect, + answers=media_to_answers, + ) + + @answers.validator + def _answers_validation(self, attribute: str, value: Any) -> None: + if len(value) > POLL_MAX_ANSWERS: + raise ValueError(f"A poll can have at most {POLL_MAX_ANSWERS} answers.") + + @_duration.validator + def _duration_validation(self, attribute: str, value: int) -> None: + if value < 0 or value > POLL_MAX_DURATION_HOURS: + raise ValueError( + f"The duration must be between 0 and {POLL_MAX_DURATION_HOURS} hours ({POLL_MAX_DURATION_HOURS // 24} days)." + ) + + def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> Self: + """ + Adds an answer to the poll. + + !!! warning + While `text` is *marked* as optional, it is *currently required* by Discord's API to make polls. + According to Discord, this may change to be actually optional in the future. + + Args: + text: The text of the answer. + emoji: The emoji for the answer. + + """ + if not text and not emoji: + raise ValueError("Either text or emoji must be provided") + + self.answers.append(PollAnswer(poll_media=PollMedia.create(text=text, emoji=emoji))) + self._answers_validation("answers", self.answers) + return self + + def to_dict(self) -> Dict[str, Any]: + data = super().to_dict() + + data["duration"] = self._duration + data.pop("_duration", None) + return data diff --git a/interactions/models/discord/webhooks.py b/interactions/models/discord/webhooks.py index f474662fd..548eaec9f 100644 --- a/interactions/models/discord/webhooks.py +++ b/interactions/models/discord/webhooks.py @@ -27,6 +27,7 @@ Message, MessageReference, ) + from interactions.models.discord.poll import Poll from interactions.models.discord.sticker import Sticker __all__ = ("WebhookTypes", "Webhook") @@ -190,6 +191,7 @@ async def send( tts: bool = False, suppress_embeds: bool = False, flags: Optional[Union[int, "MessageFlags"]] = None, + poll: "Optional[Poll | dict]" = None, username: str | None = None, avatar_url: str | None = None, wait: bool = False, @@ -212,6 +214,7 @@ async def send( tts: Should this message use Text To Speech. suppress_embeds: Should embeds be suppressed on this send flags: Message flags to apply. + poll: A poll. username: The username to use avatar_url: The url of an image to use as the avatar wait: Waits for confirmation of delivery. Set this to True if you intend to edit the message @@ -241,6 +244,7 @@ async def send( reply_to=reply_to, tts=tts, flags=flags, + poll=poll, username=username, avatar_url=avatar_url, **kwargs, diff --git a/interactions/models/internal/context.py b/interactions/models/internal/context.py index 1f826627f..859161bb3 100644 --- a/interactions/models/internal/context.py +++ b/interactions/models/internal/context.py @@ -12,6 +12,7 @@ from interactions.client.const import get_logger, MISSING from interactions.models.discord.components import BaseComponent from interactions.models.discord.file import UPLOADABLE_TYPE +from interactions.models.discord.poll import Poll from interactions.models.discord.sticker import Sticker from interactions.models.discord.user import Member, User @@ -540,6 +541,7 @@ async def send( suppress_embeds: bool = False, silent: bool = False, flags: typing.Optional[typing.Union[int, "MessageFlags"]] = None, + poll: "typing.Optional[Poll | dict]" = None, delete_after: typing.Optional[float] = None, ephemeral: bool = False, **kwargs: typing.Any, @@ -561,6 +563,7 @@ async def send( suppress_embeds: Should embeds be suppressed on this send silent: Should this message be sent without triggering a notification. flags: Message flags to apply. + poll: A poll. delete_after: Delete message after this many seconds. ephemeral: Whether the response should be ephemeral @@ -589,6 +592,7 @@ async def send( file=file, tts=tts, flags=flags, + poll=poll, delete_after=delete_after, pass_self_into_delete=True, **kwargs,