|
1 | 1 | import asyncio
|
2 | 2 | import base64
|
3 | 3 | import re
|
| 4 | +from collections import namedtuple |
4 | 5 | from dataclasses import dataclass
|
5 | 6 | from typing import (
|
6 | 7 | TYPE_CHECKING,
|
|
29 | 30 | from interactions.models.discord.emoji import process_emoji_req_format
|
30 | 31 | from interactions.models.discord.file import UPLOADABLE_TYPE
|
31 | 32 | from interactions.models.discord.poll import Poll
|
| 33 | +from interactions.models.misc.iterator import AsyncIterator |
32 | 34 |
|
33 | 35 | from .base import DiscordObject
|
34 | 36 | from .enums import (
|
|
70 | 72 | )
|
71 | 73 |
|
72 | 74 | channel_mention = re.compile(r"<#(?P<id>[0-9]{17,})>")
|
| 75 | + |
| 76 | + |
| 77 | +class PollAnswerVotersIterator(AsyncIterator): |
| 78 | + def __init__( |
| 79 | + self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None |
| 80 | + ) -> None: |
| 81 | + self.message: "Message" = message |
| 82 | + self.answer_id = answer_id |
| 83 | + self.after: Snowflake_Type | None = after |
| 84 | + self._more: bool = True |
| 85 | + super().__init__(limit) |
| 86 | + |
| 87 | + async def fetch(self) -> list["models.User"]: |
| 88 | + if not self.last: |
| 89 | + self.last = namedtuple("temp", "id") |
| 90 | + self.last.id = self.after |
| 91 | + |
| 92 | + rcv = await self.message._client.http.get_answer_voters( |
| 93 | + self.message._channel_id, |
| 94 | + self.message.id, |
| 95 | + self.answer_id, |
| 96 | + limit=self.get_limit, |
| 97 | + after=to_snowflake(self.last.id) if self.last.id else None, |
| 98 | + ) |
| 99 | + if not rcv: |
| 100 | + raise asyncio.QueueEmpty |
| 101 | + |
| 102 | + users = [self.message._client.cache.place_user_data(user_data) for user_data in rcv["users"]] |
| 103 | + return users |
73 | 104 |
|
74 | 105 |
|
75 | 106 | @attrs.define(eq=False, order=False, hash=False, kw_only=True)
|
@@ -408,6 +439,8 @@ class Message(BaseMessage):
|
408 | 439 | """Data showing the source of a crosspost, channel follow add, pin, or reply message"""
|
409 | 440 | flags: MessageFlags = attrs.field(repr=False, default=MessageFlags.NONE, converter=MessageFlags)
|
410 | 441 | """Message flags combined as a bitfield"""
|
| 442 | + poll: Optional[Poll] = attrs.field(repr=False, default=None, converter=optional_c(Poll.from_dict)) |
| 443 | + """A poll.""" |
411 | 444 | interaction_metadata: Optional[MessageInteractionMetadata] = attrs.field(repr=False, default=None)
|
412 | 445 | """Sent if the message is a response to an Interaction"""
|
413 | 446 | interaction: Optional["MessageInteraction"] = attrs.field(repr=False, default=None)
|
@@ -644,6 +677,20 @@ def jump_url(self) -> str:
|
644 | 677 | def proto_url(self) -> str:
|
645 | 678 | """A URL like `jump_url` that uses protocols."""
|
646 | 679 | return f"discord://-/channels/{self._guild_id or '@me'}/{self._channel_id}/{self.id}"
|
| 680 | + |
| 681 | + def answer_voters( |
| 682 | + self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None |
| 683 | + ) -> PollAnswerVotersIterator: |
| 684 | + """ |
| 685 | + An async iterator for getting the voters for an answer in the poll this message has. |
| 686 | +
|
| 687 | + Args: |
| 688 | + answer_id: The answer to get voters for |
| 689 | + after: Get messages after this user ID |
| 690 | + limit: The max number of users to return (default 25, max 100) |
| 691 | +
|
| 692 | + """ |
| 693 | + return PollAnswerVotersIterator(self, answer_id, limit, before) |
647 | 694 |
|
648 | 695 | async def edit(
|
649 | 696 | self,
|
@@ -900,6 +947,12 @@ async def publish(self) -> None:
|
900 | 947 | """
|
901 | 948 | await self._client.http.crosspost_message(self._channel_id, self.id)
|
902 | 949 |
|
| 950 | + async def end_poll(self) -> "Message": |
| 951 | + """Ends the poll contained in this message.""" |
| 952 | + message_data = await self._client.http.end_poll(self._channel_id, self.id) |
| 953 | + if message_data: |
| 954 | + return self._client.cache.place_message_data(message_data) |
| 955 | + |
903 | 956 |
|
904 | 957 | def process_allowed_mentions(allowed_mentions: Optional[Union[AllowedMentions, dict]]) -> Optional[dict]:
|
905 | 958 | """
|
|
0 commit comments