Skip to content

feat: add polls #1691

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 16 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ search:
- [Invite](invite)
- [Message](message)
- [Modals](modals)
- [Poll](poll)
- [Reaction](reaction)
- [Role](role)
- [Scheduled event](scheduled_event)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: interactions.models.discord.poll
16 changes: 16 additions & 0 deletions interactions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
MentionPrefix,
Missing,
MISSING,
POLL_MAX_ANSWERS,
POLL_MAX_DURATION_HOURS,
PREMIUM_GUILD_LIMITS,
SELECT_MAX_NAME_LENGTH,
SELECTS_MAX_OPTIONS,
Expand Down Expand Up @@ -243,6 +245,12 @@
PartialEmojiConverter,
PermissionOverwrite,
Permissions,
Poll,
PollAnswer,
PollAnswerCount,
PollLayoutType,
PollMedia,
PollResults,
PremiumTier,
PremiumType,
process_allowed_mentions,
Expand Down Expand Up @@ -594,6 +602,14 @@
"PartialEmojiConverter",
"PermissionOverwrite",
"Permissions",
"Poll",
"PollAnswer",
"PollAnswerCount",
"PollLayoutType",
"POLL_MAX_ANSWERS",
"POLL_MAX_DURATION_HOURS",
"PollMedia",
"PollResults",
"PREMIUM_GUILD_LIMITS",
"PremiumTier",
"PremiumType",
Expand Down
4 changes: 4 additions & 0 deletions interactions/api/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
MessageCreate,
MessageDelete,
MessageDeleteBulk,
MessagePollVoteAdd,
MessagePollVoteRemove,
MessageReactionAdd,
MessageReactionRemove,
MessageReactionRemoveAll,
Expand Down Expand Up @@ -159,6 +161,8 @@
"MessageCreate",
"MessageDelete",
"MessageDeleteBulk",
"MessagePollVoteAdd",
"MessagePollVoteRemove",
"MessageReactionAdd",
"MessageReactionRemove",
"MessageReactionRemoveAll",
Expand Down
69 changes: 69 additions & 0 deletions interactions/api/events/discord.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ async def an_event_handler(event: ChannelCreate):
"MessageCreate",
"MessageDelete",
"MessageDeleteBulk",
"MessagePollVoteAdd",
"MessagePollVoteRemove",
"MessageReactionAdd",
"MessageReactionRemove",
"MessageReactionRemoveAll",
Expand Down Expand Up @@ -115,6 +117,7 @@ async def an_event_handler(event: ChannelCreate):
from interactions.models.discord.entitlement import Entitlement
from interactions.models.discord.guild import Guild, GuildIntegration
from interactions.models.discord.message import Message
from interactions.models.discord.poll import Poll
from interactions.models.discord.reaction import Reaction
from interactions.models.discord.role import Role
from interactions.models.discord.scheduled_event import ScheduledEvent
Expand Down Expand Up @@ -588,6 +591,72 @@ class MessageReactionRemoveEmoji(MessageReactionRemoveAll):
"""The emoji that was removed"""


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
class BaseMessagePollEvent(BaseEvent):
user_id: "Snowflake_Type" = attrs.field(repr=False)
"""The ID of the user that voted"""
channel_id: "Snowflake_Type" = attrs.field(repr=False)
"""The ID of the channel the poll is in"""
message_id: "Snowflake_Type" = attrs.field(repr=False)
"""The ID of the message the poll is in"""
answer_id: int = attrs.field(repr=False)
"""The ID of the answer the user voted for"""
guild_id: "Optional[Snowflake_Type]" = attrs.field(repr=False, default=None)
"""The ID of the guild the poll is in"""

def get_message(self) -> "Optional[Message]":
"""Get the message object if it is cached"""
return self.client.cache.get_message(self.channel_id, self.message_id)

def get_user(self) -> "Optional[User]":
"""Get the user object if it is cached"""
return self.client.get_user(self.user_id)

def get_channel(self) -> "Optional[TYPE_ALL_CHANNEL]":
"""Get the channel object if it is cached"""
return self.client.get_channel(self.channel_id)

def get_guild(self) -> "Optional[Guild]":
"""Get the guild object if it is cached"""
return self.client.get_guild(self.guild_id) if self.guild_id is not None else None

def get_poll(self) -> "Optional[Poll]":
"""Get the poll object if it is cached"""
message = self.get_message()
return message.poll if message is not None else None

async def fetch_message(self) -> "Message":
"""Fetch the message the poll is in"""
return await self.client.cache.fetch_message(self.channel_id, self.message_id)

async def fetch_user(self) -> "User":
"""Fetch the user that voted"""
return await self.client.fetch_user(self.user_id)

async def fetch_channel(self) -> "TYPE_ALL_CHANNEL":
"""Fetch the channel the poll is in"""
return await self.client.fetch_channel(self.channel_id)

async def fetch_guild(self) -> "Optional[Guild]":
"""Fetch the guild the poll is in"""
return await self.client.fetch_guild(self.guild_id) if self.guild_id is not None else None

async def fetch_poll(self) -> "Poll":
"""Fetch the poll object"""
message = await self.fetch_message()
return message.poll


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
class MessagePollVoteAdd(BaseMessagePollEvent):
"""Dispatched when a user votes in a poll"""


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
class MessagePollVoteRemove(BaseMessagePollEvent):
"""Dispatched when a user remotes a votes in a poll"""


@attrs.define(eq=False, order=False, hash=False, kw_only=False)
class PresenceUpdate(BaseEvent):
"""A user's presence has changed."""
Expand Down
38 changes: 38 additions & 0 deletions interactions/api/events/processors/message_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,41 @@ async def _on_raw_message_delete_bulk(self, event: "RawGatewayEvent") -> None:
event.data.get("ids"),
)
)

@Processor.define()
async def _on_raw_message_poll_vote_add(self, event: "RawGatewayEvent") -> None:
"""
Process raw message poll vote add event and dispatch a processed poll vote add event.

Args:
event: raw poll vote add event

"""
self.dispatch(
events.MessagePollVoteAdd(
event.data.get("guild_id", None),
event.data["channel_id"],
event.data["message_id"],
event.data["user_id"],
event.data["option"],
)
)

@Processor.define()
async def _on_raw_message_poll_vote_remove(self, event: "RawGatewayEvent") -> None:
"""
Process raw message poll vote remove event and dispatch a processed poll vote remove event.

Args:
event: raw poll vote remove event

"""
self.dispatch(
events.MessagePollVoteRemove(
event.data.get("guild_id", None),
event.data["channel_id"],
event.data["message_id"],
event.data["user_id"],
event.data["option"],
)
)
63 changes: 62 additions & 1 deletion interactions/api/http/http_requests/messages.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from typing import TYPE_CHECKING, cast
from typing import TYPE_CHECKING, cast, TypedDict

import discord_typings

from interactions.models.internal.protocols import CanRequest
from interactions.client.utils.serializer import dict_filter_none
from ..route import Route

__all__ = ("MessageRequests",)
Expand All @@ -13,6 +14,10 @@
from interactions import UPLOADABLE_TYPE


class GetAnswerVotersData(TypedDict):
users: list[discord_typings.UserData]


class MessageRequests(CanRequest):
async def create_message(
self,
Expand Down Expand Up @@ -175,3 +180,59 @@ async def crosspost_message(
)
)
return cast(discord_typings.MessageData, result)

async def get_answer_voters(
self,
channel_id: "Snowflake_Type",
message_id: "Snowflake_Type",
answer_id: int,
after: "Snowflake_Type | None" = None,
limit: int = 25,
) -> GetAnswerVotersData:
"""
Get a list of users that voted for this specific answer.

Args:
channel_id: Channel the message is in
message_id: The message with the poll
answer_id: The answer to get voters for
after: Get messages after this user ID
limit: The max number of users to return (default 25, max 100)

Returns:
GetAnswerVotersData: A response that has a list of users that voted for the answer

"""
result = await self.request(
Route(
"GET",
"/channels/{channel_id}/polls/{message_id}/answers/{answer_id}",
channel_id=channel_id,
message_id=message_id,
answer_id=answer_id,
),
params=dict_filter_none({"after": after, "limit": limit}),
)
return cast(GetAnswerVotersData, result)

async def end_poll(self, channel_id: "Snowflake_Type", message_id: "Snowflake_Type") -> discord_typings.MessageData:
"""
Ends a poll. Only can end polls from the current bot.

Args:
channel_id: Channel the message is in
message_id: The message with the poll

Returns:
message object

"""
result = await self.request(
Route(
"POST",
"/channels/{channel_id}/polls/{message_id}/expire",
channel_id=channel_id,
message_id=message_id,
)
)
return cast(discord_typings.MessageData, result)
4 changes: 4 additions & 0 deletions interactions/client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
MISSING,
MENTION_PREFIX,
PREMIUM_GUILD_LIMITS,
POLL_MAX_ANSWERS,
POLL_MAX_DURATION_HOURS,
Absent,
T,
T_co,
Expand Down Expand Up @@ -61,6 +63,8 @@
"EMBED_MAX_FIELDS",
"EMBED_TOTAL_MAX",
"EMBED_FIELD_VALUE_LENGTH",
"POLL_MAX_ANSWERS",
"POLL_MAX_DURATION_HOURS",
"Singleton",
"Sentinel",
"GlobalScope",
Expand Down
5 changes: 5 additions & 0 deletions interactions/client/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@
"NON_RESUMABLE_WEBSOCKET_CLOSE_CODES",
"CLIENT_FEATURE_FLAGS",
"has_client_feature",
"POLL_MAX_ANSWERS",
"POLL_MAX_DURATION_HOURS",
)

_ver_info = sys.version_info
Expand Down Expand Up @@ -130,6 +132,9 @@ def get_logger() -> logging.Logger:
EMBED_TOTAL_MAX = 6000
EMBED_FIELD_VALUE_LENGTH = 1024

POLL_MAX_ANSWERS = 10
POLL_MAX_DURATION_HOURS = 168


class Singleton(type):
_instances: ClassVar[dict] = {}
Expand Down
4 changes: 4 additions & 0 deletions interactions/client/mixins/send.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from interactions.models.discord.components import BaseComponent
from interactions.models.discord.embed import Embed
from interactions.models.discord.message import AllowedMentions, Message, MessageReference
from interactions.models.discord.poll import Poll
from interactions.models.discord.sticker import Sticker
from interactions.models.discord.snowflake import Snowflake_Type

Expand Down Expand Up @@ -49,6 +50,7 @@ async def send(
delete_after: Optional[float] = None,
nonce: Optional[str | int] = None,
enforce_nonce: bool = False,
poll: "Optional[Poll | dict]" = None,
**kwargs: Any,
) -> "Message":
"""
Expand All @@ -73,6 +75,7 @@ async def send(
enforce_nonce: If enabled and nonce is present, it will be checked for uniqueness in the past few minutes. \
If another message was created by the same author with the same nonce, that message will be returned \
and no new message will be created.
poll: A poll.

Returns:
New message object that was sent.
Expand Down Expand Up @@ -115,6 +118,7 @@ async def send(
flags=flags,
nonce=nonce,
enforce_nonce=enforce_nonce,
poll=poll,
**kwargs,
)

Expand Down
4 changes: 4 additions & 0 deletions interactions/ext/hybrid_commands/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
Attachment,
process_message_payload,
TYPE_MESSAGEABLE_CHANNEL,
Poll,
)
from interactions.models.discord.enums import ContextType
from interactions.client.mixins.send import SendMixin
Expand Down Expand Up @@ -309,6 +310,7 @@ async def send(
suppress_embeds: bool = False,
silent: bool = False,
flags: Optional[Union[int, "MessageFlags"]] = None,
poll: "Optional[Poll | dict]" = None,
delete_after: Optional[float] = None,
ephemeral: bool = False,
**kwargs: Any,
Expand All @@ -330,6 +332,7 @@ async def send(
suppress_embeds: Should embeds be suppressed on this send
silent: Should this message be sent without triggering a notification.
flags: Message flags to apply.
poll: A poll.
delete_after: Delete message after this many seconds.
ephemeral: Should this message be sent as ephemeral (hidden) - only works with interactions

Expand Down Expand Up @@ -358,6 +361,7 @@ async def send(
file=file,
tts=tts,
flags=flags,
poll=poll,
delete_after=delete_after,
pass_self_into_delete=bool(self._slash_ctx),
**kwargs,
Expand Down
Loading
Loading