Skip to content

Commit 366ac5f

Browse files
committed
feat: initial work on polls
1 parent ecb9fca commit 366ac5f

File tree

6 files changed

+173
-0
lines changed

6 files changed

+173
-0
lines changed

interactions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -243,6 +243,7 @@
243243
PartialEmojiConverter,
244244
PermissionOverwrite,
245245
Permissions,
246+
PollLayoutType,
246247
PremiumTier,
247248
PremiumType,
248249
process_allowed_mentions,
@@ -594,6 +595,7 @@
594595
"PartialEmojiConverter",
595596
"PermissionOverwrite",
596597
"Permissions",
598+
"PollLayoutType",
597599
"PREMIUM_GUILD_LIMITS",
598600
"PremiumTier",
599601
"PremiumType",

interactions/client/const.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@
8484
"NON_RESUMABLE_WEBSOCKET_CLOSE_CODES",
8585
"CLIENT_FEATURE_FLAGS",
8686
"has_client_feature",
87+
"POLL_MAX_ANSWERS",
88+
"POLL_MAX_DURATION_HOURS",
8789
)
8890

8991
_ver_info = sys.version_info
@@ -130,6 +132,9 @@ def get_logger() -> logging.Logger:
130132
EMBED_TOTAL_MAX = 6000
131133
EMBED_FIELD_VALUE_LENGTH = 1024
132134

135+
POLL_MAX_ANSWERS = 10
136+
POLL_MAX_DURATION_HOURS = 168
137+
133138

134139
class Singleton(type):
135140
_instances: ClassVar[dict] = {}

interactions/models/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@
122122
PartialEmoji,
123123
PermissionOverwrite,
124124
Permissions,
125+
PollLayoutType,
125126
PremiumTier,
126127
PremiumType,
127128
process_allowed_mentions,
@@ -523,6 +524,7 @@
523524
"PartialEmojiConverter",
524525
"PermissionOverwrite",
525526
"Permissions",
527+
"PollLayoutType",
526528
"PremiumTier",
527529
"PremiumType",
528530
"process_allowed_mentions",

interactions/models/discord/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@
104104
OnboardingPromptType,
105105
OverwriteType,
106106
Permissions,
107+
PollLayoutType,
107108
PremiumTier,
108109
PremiumType,
109110
ScheduledEventPrivacyLevel,
@@ -299,6 +300,7 @@
299300
"PartialEmoji",
300301
"PermissionOverwrite",
301302
"Permissions",
303+
"PollLayoutType",
302304
"PremiumTier",
303305
"PremiumType",
304306
"process_allowed_mentions",

interactions/models/discord/enums.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
"OnboardingPromptType",
3939
"OverwriteType",
4040
"Permissions",
41+
"PollLayoutType",
4142
"PremiumTier",
4243
"PremiumType",
4344
"ScheduledEventPrivacyLevel",
@@ -1134,3 +1135,9 @@ class EntitlementType(CursedIntEnum):
11341135
"""Entitlement was claimed by user for free as a Nitro Subscriber"""
11351136
APPLICATION_SUBSCRIPTION = 8
11361137
"""Entitlement was purchased as an app subscription"""
1138+
1139+
1140+
class PollLayoutType(CursedIntEnum):
1141+
"""The layout of a poll."""
1142+
1143+
DEFAULT = 1

interactions/models/discord/poll.py

Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
1+
from typing import Optional, Union, Dict, Any
2+
from typing_extensions import Self
3+
4+
import attrs
5+
6+
from interactions.client.const import MISSING
7+
from interactions.client.utils.attr_converters import (
8+
optional,
9+
timestamp_converter,
10+
)
11+
from interactions.client.mixins.serialization import DictSerializationMixin
12+
from interactions.client.utils.serializer import no_export_meta
13+
from interactions.models.discord.emoji import PartialEmoji, process_emoji
14+
from interactions.models.discord.enums import PollLayoutType
15+
from interactions.models.discord.timestamp import Timestamp
16+
17+
__all__ = (
18+
"PollMedia",
19+
"PollAnswer",
20+
"PollAnswerCount",
21+
"PollResults",
22+
"Poll",
23+
)
24+
25+
26+
@attrs.define(eq=False, order=False, hash=False, kw_only=True)
27+
class PollMedia(DictSerializationMixin):
28+
text: Optional[str] = attrs.field(repr=False, default=None)
29+
"""The text of the field."""
30+
emoji: Optional[PartialEmoji] = attrs.field(repr=False, default=None, converter=optional(PartialEmoji.from_dict))
31+
"""The emoji of the field."""
32+
33+
@classmethod
34+
def create(cls, *, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> Self:
35+
"""
36+
Create a PollMedia object, used for questions and answers for polls.
37+
38+
Args:
39+
text: The text of the field.
40+
emoji: The emoji of the field.
41+
42+
Returns:
43+
A PollMedia object.
44+
45+
"""
46+
if not text and not emoji:
47+
raise ValueError("Either text or emoji must be provided.")
48+
49+
return cls(text=text, emoji=process_emoji(emoji))
50+
51+
52+
@attrs.define(eq=False, order=False, hash=False, kw_only=True)
53+
class PollAnswer(DictSerializationMixin):
54+
poll_media: PollMedia = attrs.field(repr=False, converter=PollMedia.from_dict)
55+
"""The data of the answer."""
56+
answer_id: Optional[int] = attrs.field(repr=False, default=None)
57+
"""The ID of the answer. This is only returned for polls that have been given by Discord's API."""
58+
59+
60+
@attrs.define(eq=False, order=False, hash=False, kw_only=True)
61+
class PollAnswerCount(DictSerializationMixin):
62+
id: int = attrs.field(repr=False)
63+
"""The answer ID of the answer."""
64+
count: int = attrs.field(repr=False, default=0)
65+
"""The number of votes for this answer."""
66+
me_voted: bool = attrs.field(repr=False, default=False)
67+
"""Whether the current user voted for this answer."""
68+
69+
70+
@attrs.define(eq=False, order=False, hash=False, kw_only=True)
71+
class PollResults(DictSerializationMixin):
72+
is_finalized: bool = attrs.field(repr=False, default=False)
73+
"""Whether the votes have been precisely counted."""
74+
answer_counts: list[PollAnswerCount] = attrs.field(repr=False, factory=list, converter=PollAnswerCount.from_list)
75+
"""The counts for each answer."""
76+
77+
78+
@attrs.define(eq=False, order=False, hash=False, kw_only=True)
79+
class Poll(DictSerializationMixin):
80+
question: PollMedia = attrs.field(repr=False)
81+
"""The question of the poll. Only text media is supported."""
82+
answers: list[PollAnswer] = attrs.field(repr=False, factory=list, converter=PollAnswer.from_list)
83+
"""Each of the answers available in the poll, up to 10."""
84+
expiry: Timestamp = attrs.field(repr=False, default=MISSING, converter=timestamp_converter)
85+
"""Number of hours the poll is open for, up to 7 days."""
86+
allow_multiselect: bool = attrs.field(repr=False, default=False, metadata=no_export_meta)
87+
"""Whether a user can select multiple answers."""
88+
layout_type: PollLayoutType = attrs.field(repr=False, default=PollLayoutType.DEFAULT, converter=PollLayoutType)
89+
"""The layout type of the poll."""
90+
results: Optional[PollResults] = attrs.field(repr=False, default=None, converter=optional(PollResults.from_dict))
91+
"""The results of the poll, if the polls is finished."""
92+
93+
_duration: int = attrs.field(repr=False, default=0)
94+
"""How long, in hours, the poll will be open for (up to 7 days). This is only used when creating polls."""
95+
96+
@classmethod
97+
def create(
98+
cls, question: str, duration: int, *, allow_multiselect: bool = False, answers: Optional[list[PollMedia]] = None
99+
) -> Self:
100+
"""
101+
Create a Poll object for sending.
102+
103+
Args:
104+
question: The question of the poll.
105+
duration: How long, in hours, the poll will be open for (up to 7 days).
106+
allow_multiselect: Whether a user can select multiple answers.
107+
answers: Each of the answers available in the poll, up to 10.
108+
109+
Returns:
110+
A Poll object.
111+
112+
"""
113+
if answers:
114+
media_to_answers = [PollAnswer(poll_media=answer) for answer in answers]
115+
else:
116+
media_to_answers = []
117+
118+
return cls(
119+
question=PollMedia(text=question),
120+
duration=duration,
121+
allow_multiselect=allow_multiselect,
122+
answers=media_to_answers,
123+
)
124+
125+
@answers.validator
126+
def _answers_validation(self, attribute: str, value: Any) -> None:
127+
if len(value) > 10:
128+
raise ValueError("A poll can have at most 10 answers.")
129+
130+
@_duration.validator
131+
def _duration_validation(self, attribute: str, value: int) -> None:
132+
if value < 0 or value > 168:
133+
raise ValueError("The duration must be between 0 and 168 hours (7 days).")
134+
135+
def add_answer(self, text: Optional[str] = None, emoji: Optional[Union[PartialEmoji, dict, str]] = None) -> None:
136+
"""
137+
Adds an answer to the poll.
138+
139+
Args:
140+
text: The text of the answer.
141+
emoji: The emoji for the answer.
142+
143+
"""
144+
if not text and not emoji:
145+
raise ValueError("Either text or emoji must be provided")
146+
147+
self.answers.append(PollAnswer(poll_media=PollMedia.create(text=text, emoji=emoji)))
148+
self._answers_validation("answers", self.answers)
149+
150+
def to_dict(self) -> Dict[str, Any]:
151+
data = super().to_dict()
152+
153+
data["duration"] = self._duration
154+
data.pop("_duration", None)
155+
return data

0 commit comments

Comments
 (0)