|
1 | 1 | import asyncio
|
2 | 2 | import base64
|
3 | 3 | import re
|
| 4 | +from collections import namedtuple |
4 | 5 | from dataclasses import dataclass
|
5 | 6 | from typing import (
|
6 | 7 | TYPE_CHECKING,
|
|
29 | 30 | from interactions.models.discord.emoji import process_emoji_req_format
|
30 | 31 | from interactions.models.discord.file import UPLOADABLE_TYPE
|
31 | 32 | from interactions.models.discord.poll import Poll
|
| 33 | +from interactions.models.misc.iterator import AsyncIterator |
32 | 34 |
|
33 | 35 | from .base import DiscordObject
|
34 | 36 | from .enums import (
|
|
68 | 70 |
|
69 | 71 | channel_mention = re.compile(r"<#(?P<id>[0-9]{17,})>")
|
70 | 72 |
|
| 73 | +class PollAnswerVotersIterator(AsyncIterator): |
| 74 | + def __init__(self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None) -> None: |
| 75 | + self.message: "Message" = message |
| 76 | + self.answer_id = answer_id |
| 77 | + self.after: Snowflake_Type | None = after |
| 78 | + self._more: bool = True |
| 79 | + super().__init__(limit) |
| 80 | + |
| 81 | + async def fetch(self) -> list["models.User"]: |
| 82 | + if not self.last: |
| 83 | + self.last = namedtuple("temp", "id") |
| 84 | + self.last.id = self.after |
| 85 | + |
| 86 | + rcv = await self.message._client.http.get_answer_voters( |
| 87 | + self.message._channel_id, |
| 88 | + self.message.id, |
| 89 | + self.answer_id, |
| 90 | + limit=self.get_limit, |
| 91 | + after=to_snowflake(self.last.id) if self.last.id else None, |
| 92 | + ) |
| 93 | + if not rcv: |
| 94 | + raise asyncio.QueueEmpty |
| 95 | + |
| 96 | + users = [self.message._client.cache.place_user_data(user_data) for user_data in rcv["users"]] |
| 97 | + return users |
| 98 | + |
71 | 99 |
|
72 | 100 | @attrs.define(eq=False, order=False, hash=False, kw_only=True)
|
73 | 101 | class Attachment(DiscordObject):
|
@@ -361,6 +389,8 @@ class Message(BaseMessage):
|
361 | 389 | """Data showing the source of a crosspost, channel follow add, pin, or reply message"""
|
362 | 390 | flags: MessageFlags = attrs.field(repr=False, default=MessageFlags.NONE, converter=MessageFlags)
|
363 | 391 | """Message flags combined as a bitfield"""
|
| 392 | + poll: Optional[Poll] = attrs.field(repr=False, default=None, converter=optional_c(Poll.from_dict)) |
| 393 | + """A poll.""" |
364 | 394 | interaction: Optional["MessageInteraction"] = attrs.field(repr=False, default=None)
|
365 | 395 | """Sent if the message is a response to an Interaction"""
|
366 | 396 | components: Optional[List["models.ActionRow"]] = attrs.field(repr=False, default=None)
|
@@ -593,6 +623,18 @@ def proto_url(self) -> str:
|
593 | 623 | """A URL like `jump_url` that uses protocols."""
|
594 | 624 | return f"discord://-/channels/{self._guild_id or '@me'}/{self._channel_id}/{self.id}"
|
595 | 625 |
|
| 626 | + def answer_voters(self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None) -> PollAnswerVotersIterator: |
| 627 | + """ |
| 628 | + An async iterator for getting the voters for an answer in the poll this message has. |
| 629 | +
|
| 630 | + Args: |
| 631 | + answer_id: The answer to get voters for |
| 632 | + after: Get messages after this user ID |
| 633 | + limit: The max number of users to return (default 25, max 100) |
| 634 | +
|
| 635 | + """ |
| 636 | + return PollAnswerVotersIterator(self, answer_id, limit, before) |
| 637 | + |
596 | 638 | async def edit(
|
597 | 639 | self,
|
598 | 640 | *,
|
@@ -848,6 +890,14 @@ async def publish(self) -> None:
|
848 | 890 | """
|
849 | 891 | await self._client.http.crosspost_message(self._channel_id, self.id)
|
850 | 892 |
|
| 893 | + async def end_poll(self) -> "Message": |
| 894 | + """ |
| 895 | + Ends the poll in this message. |
| 896 | + """ |
| 897 | + message_data = await self._client.http.end_poll(self._channel_id, self.id) |
| 898 | + if message_data: |
| 899 | + return self._client.cache.place_message_data(message_data) |
| 900 | + |
851 | 901 |
|
852 | 902 | def process_allowed_mentions(allowed_mentions: Optional[Union[AllowedMentions, dict]]) -> Optional[dict]:
|
853 | 903 | """
|
|
0 commit comments