|
| 1 | +from typing import Optional, Union, Dict, Any |
| 2 | +from typing_extensions import Self |
| 3 | + |
| 4 | +import attrs |
| 5 | + |
| 6 | +from interactions.client.const import MISSING |
| 7 | +from interactions.client.utils.attr_converters import ( |
| 8 | + optional, |
| 9 | + timestamp_converter, |
| 10 | +) |
| 11 | +from interactions.client.mixins.serialization import DictSerializationMixin |
| 12 | +from interactions.client.utils.serializer import no_export_meta |
| 13 | +from interactions.models.discord.emoji import PartialEmoji, process_emoji |
| 14 | +from interactions.models.discord.enums import PollLayoutType |
| 15 | +from interactions.models.discord.timestamp import Timestamp |
| 16 | + |
| 17 | +__all__ = ( |
| 18 | + "PollMedia", |
| 19 | + "PollAnswer", |
| 20 | + "PollAnswerCount", |
| 21 | + "PollResults", |
| 22 | + "Poll", |
| 23 | +) |
| 24 | + |
| 25 | + |
| 26 | +@attrs.define(eq=False, order=False, hash=False, kw_only=True) |
| 27 | +class PollMedia(DictSerializationMixin): |
| 28 | + text: Optional[str] = attrs.field(repr=False, default=None) |
| 29 | + """The text of the field.""" |
| 30 | + emoji: Optional[PartialEmoji] = attrs.field(repr=False, default=None, converter=optional(PartialEmoji.from_dict)) |
| 31 | + """The emoji of the field.""" |
| 32 | + |
| 33 | + @classmethod |
| 34 | + def create(cls, *, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> Self: |
| 35 | + """ |
| 36 | + Create a PollMedia object, used for questions and answers for polls. |
| 37 | +
|
| 38 | + Args: |
| 39 | + text: The text of the field. |
| 40 | + emoji: The emoji of the field. |
| 41 | +
|
| 42 | + Returns: |
| 43 | + A PollMedia object. |
| 44 | +
|
| 45 | + """ |
| 46 | + if not text and not emoji: |
| 47 | + raise ValueError("Either text or emoji must be provided.") |
| 48 | + |
| 49 | + return cls(text=text, emoji=process_emoji(emoji)) |
| 50 | + |
| 51 | + |
| 52 | +@attrs.define(eq=False, order=False, hash=False, kw_only=True) |
| 53 | +class PollAnswer(DictSerializationMixin): |
| 54 | + poll_media: PollMedia = attrs.field(repr=False, converter=PollMedia.from_dict) |
| 55 | + """The data of the answer.""" |
| 56 | + answer_id: Optional[int] = attrs.field(repr=False, default=None) |
| 57 | + """The ID of the answer. This is only returned for polls that have been given by Discord's API.""" |
| 58 | + |
| 59 | + |
| 60 | +@attrs.define(eq=False, order=False, hash=False, kw_only=True) |
| 61 | +class PollAnswerCount(DictSerializationMixin): |
| 62 | + id: int = attrs.field(repr=False) |
| 63 | + """The answer ID of the answer.""" |
| 64 | + count: int = attrs.field(repr=False, default=0) |
| 65 | + """The number of votes for this answer.""" |
| 66 | + me_voted: bool = attrs.field(repr=False, default=False) |
| 67 | + """Whether the current user voted for this answer.""" |
| 68 | + |
| 69 | + |
| 70 | +@attrs.define(eq=False, order=False, hash=False, kw_only=True) |
| 71 | +class PollResults(DictSerializationMixin): |
| 72 | + is_finalized: bool = attrs.field(repr=False, default=False) |
| 73 | + """Whether the votes have been precisely counted.""" |
| 74 | + answer_counts: list[PollAnswerCount] = attrs.field(repr=False, factory=list, converter=PollAnswerCount.from_list) |
| 75 | + """The counts for each answer.""" |
| 76 | + |
| 77 | + |
| 78 | +@attrs.define(eq=False, order=False, hash=False, kw_only=True) |
| 79 | +class Poll(DictSerializationMixin): |
| 80 | + question: PollMedia = attrs.field(repr=False) |
| 81 | + """The question of the poll. Only text media is supported.""" |
| 82 | + answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list) |
| 83 | + """Each of the answers available in the poll, up to 10.""" |
| 84 | + expiry: Timestamp = attrs.field(repr=False, default=MISSING, converter=timestamp_converter) |
| 85 | + """Number of hours the poll is open for, up to 7 days.""" |
| 86 | + allow_multiselect: bool = attrs.field(repr=False, default=False, metadata=no_export_meta) |
| 87 | + """Whether a user can select multiple answers.""" |
| 88 | + layout_type: PollLayoutType = attrs.field(repr=False, default=PollLayoutType.DEFAULT, converter=PollLayoutType) |
| 89 | + """The layout type of the poll.""" |
| 90 | + results: Optional[PollResults] = attrs.field(repr=False, default=None, converter=optional(PollResults.from_dict)) |
| 91 | + """The results of the poll, if the polls is finished.""" |
| 92 | + |
| 93 | + _duration: int = attrs.field(repr=False, default=0) |
| 94 | + """How long, in hours, the poll will be open for (up to 7 days). This is only used when creating polls.""" |
| 95 | + |
| 96 | + @classmethod |
| 97 | + def create( |
| 98 | + cls, question: str, duration: int, *, allow_multiselect: bool = False, answers: Optional[list[PollMedia]] = None |
| 99 | + ) -> Self: |
| 100 | + """ |
| 101 | + Create a Poll object for sending. |
| 102 | +
|
| 103 | + Args: |
| 104 | + question: The question of the poll. |
| 105 | + duration: How long, in hours, the poll will be open for (up to 7 days). |
| 106 | + allow_multiselect: Whether a user can select multiple answers. |
| 107 | + answers: Each of the answers available in the poll, up to 10. |
| 108 | +
|
| 109 | + Returns: |
| 110 | + A Poll object. |
| 111 | +
|
| 112 | + """ |
| 113 | + if answers: |
| 114 | + media_to_answers = [PollAnswer(poll_media=answer) for answer in answers] |
| 115 | + else: |
| 116 | + media_to_answers = [] |
| 117 | + |
| 118 | + return cls( |
| 119 | + question=PollMedia(text=question), |
| 120 | + duration=duration, |
| 121 | + allow_multiselect=allow_multiselect, |
| 122 | + answers=media_to_answers, |
| 123 | + ) |
| 124 | + |
| 125 | + @answers.validator |
| 126 | + def _answers_validation(self, attribute: str, value: Any) -> None: |
| 127 | + if len(value) > 10: |
| 128 | + raise ValueError("A poll can have at most 10 answers.") |
| 129 | + |
| 130 | + @_duration.validator |
| 131 | + def _duration_validation(self, attribute: str, value: int) -> None: |
| 132 | + if value < 0 or value > 168: |
| 133 | + raise ValueError("The duration must be between 0 and 168 hours (7 days).") |
| 134 | + |
| 135 | + def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> None: |
| 136 | + """ |
| 137 | + Adds an answer to the poll. |
| 138 | +
|
| 139 | + Args: |
| 140 | + text: The text of the answer. |
| 141 | + emoji: The emoji for the answer. |
| 142 | +
|
| 143 | + """ |
| 144 | + if not text and not emoji: |
| 145 | + raise ValueError("Either text or emoji must be provided") |
| 146 | + |
| 147 | + self.answers.append(PollAnswer(poll_media=PollMedia.create(text=text, emoji=emoji))) |
| 148 | + self._answers_validation("answers", self.answers) |
| 149 | + |
| 150 | + def to_dict(self) -> Dict[str, Any]: |
| 151 | + data = super().to_dict() |
| 152 | + |
| 153 | + data["duration"] = self._duration |
| 154 | + data.pop("_duration", None) |
| 155 | + return data |
0 commit comments