Skip to content

Commit ae2b49e

Browse files
committed
feat: add methods to interact with http methods
1 parent cd4f16f commit ae2b49e

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed

interactions/models/discord/message.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import asyncio
22
import base64
33
import re
4+
from collections import namedtuple
45
from dataclasses import dataclass
56
from typing import (
67
TYPE_CHECKING,
@@ -29,6 +30,7 @@
2930
from interactions.models.discord.emoji import process_emoji_req_format
3031
from interactions.models.discord.file import UPLOADABLE_TYPE
3132
from interactions.models.discord.poll import Poll
33+
from interactions.models.misc.iterator import AsyncIterator
3234

3335
from .base import DiscordObject
3436
from .enums import (
@@ -70,6 +72,35 @@
7072
)
7173

7274
channel_mention = re.compile(r"<#(?P<id>[0-9]{17,})>")
75+
76+
77+
class PollAnswerVotersIterator(AsyncIterator):
78+
def __init__(
79+
self, message: "Message", answer_id: int, limit: int = 25, after: Snowflake_Type | None = None
80+
) -> None:
81+
self.message: "Message" = message
82+
self.answer_id = answer_id
83+
self.after: Snowflake_Type | None = after
84+
self._more: bool = True
85+
super().__init__(limit)
86+
87+
async def fetch(self) -> list["models.User"]:
88+
if not self.last:
89+
self.last = namedtuple("temp", "id")
90+
self.last.id = self.after
91+
92+
rcv = await self.message._client.http.get_answer_voters(
93+
self.message._channel_id,
94+
self.message.id,
95+
self.answer_id,
96+
limit=self.get_limit,
97+
after=to_snowflake(self.last.id) if self.last.id else None,
98+
)
99+
if not rcv:
100+
raise asyncio.QueueEmpty
101+
102+
users = [self.message._client.cache.place_user_data(user_data) for user_data in rcv["users"]]
103+
return users
73104

74105

75106
@attrs.define(eq=False, order=False, hash=False, kw_only=True)
@@ -408,6 +439,8 @@ class Message(BaseMessage):
408439
"""Data showing the source of a crosspost, channel follow add, pin, or reply message"""
409440
flags: MessageFlags = attrs.field(repr=False, default=MessageFlags.NONE, converter=MessageFlags)
410441
"""Message flags combined as a bitfield"""
442+
poll: Optional[Poll] = attrs.field(repr=False, default=None, converter=optional_c(Poll.from_dict))
443+
"""A poll."""
411444
interaction_metadata: Optional[MessageInteractionMetadata] = attrs.field(repr=False, default=None)
412445
"""Sent if the message is a response to an Interaction"""
413446
interaction: Optional["MessageInteraction"] = attrs.field(repr=False, default=None)
@@ -644,6 +677,20 @@ def jump_url(self) -> str:
644677
def proto_url(self) -> str:
645678
"""A URL like `jump_url` that uses protocols."""
646679
return f"discord://-/channels/{self._guild_id or '@me'}/{self._channel_id}/{self.id}"
680+
681+
def answer_voters(
682+
self, answer_id: int, limit: int = 0, before: Snowflake_Type | None = None
683+
) -> PollAnswerVotersIterator:
684+
"""
685+
An async iterator for getting the voters for an answer in the poll this message has.
686+
687+
Args:
688+
answer_id: The answer to get voters for
689+
after: Get messages after this user ID
690+
limit: The max number of users to return (default 25, max 100)
691+
692+
"""
693+
return PollAnswerVotersIterator(self, answer_id, limit, before)
647694

648695
async def edit(
649696
self,
@@ -900,6 +947,12 @@ async def publish(self) -> None:
900947
"""
901948
await self._client.http.crosspost_message(self._channel_id, self.id)
902949

950+
async def end_poll(self) -> "Message":
951+
"""Ends the poll contained in this message."""
952+
message_data = await self._client.http.end_poll(self._channel_id, self.id)
953+
if message_data:
954+
return self._client.cache.place_message_data(message_data)
955+
903956

904957
def process_allowed_mentions(allowed_mentions: Optional[Union[AllowedMentions, dict]]) -> Optional[dict]:
905958
"""

0 commit comments

Comments
 (0)