Skip to content

Commit 7394cd6

Browse files
authored
fix: Reimplement manual sharding/presence, fix forum tag implementation (#1115)
* fix: Reimplement manual sharding/presence instantiation. (This was accidentally removed per gateway rework) * refactor: Reorganise tag creation/updating/deletion to non-deprecated endpoints and make it cache-reflective.
1 parent 527f320 commit 7394cd6

File tree

4 files changed

+73
-12
lines changed

4 files changed

+73
-12
lines changed

interactions/api/gateway/client.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def __init__(
122122
intents: Intents,
123123
session_id: Optional[str] = MISSING,
124124
sequence: Optional[int] = MISSING,
125+
shards: Optional[List[Tuple[int]]] = MISSING,
126+
presence: Optional[ClientPresence] = MISSING,
125127
) -> None:
126128
"""
127129
:param token: The token of the application for connecting to the Gateway.
@@ -132,6 +134,10 @@ def __init__(
132134
:type session_id?: Optional[str]
133135
:param sequence?: The identifier sequence if trying to reconnect. Defaults to ``None``.
134136
:type sequence?: Optional[int]
137+
:param shards?: The list of shards for the application's initial connection, if provided. Defaults to ``None``.
138+
:type shards?: Optional[List[Tuple[int]]]
139+
:param presence?: The presence shown on an application once first connected. Defaults to ``None``.
140+
:type presence?: Optional[ClientPresence]
135141
"""
136142
try:
137143
self._loop = get_event_loop() if version_info < (3, 10) else get_running_loop()
@@ -161,8 +167,8 @@ def __init__(
161167
}
162168

163169
self._intents: Intents = intents
164-
self.__shard: Optional[List[Tuple[int]]] = None
165-
self.__presence: Optional[ClientPresence] = None
170+
self.__shard: Optional[List[Tuple[int]]] = None if shards is MISSING else shards
171+
self.__presence: Optional[ClientPresence] = None if presence is MISSING else presence
166172

167173
self._task: Optional[Task] = None
168174
self.__heartbeat_event = Event(loop=self._loop) if version_info < (3, 10) else Event()

interactions/api/http/channel.py

Lines changed: 60 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from ..error import LibraryException
55
from ..models.channel import Channel
66
from ..models.message import Message
7+
from ..models.misc import Snowflake
78
from .request import _Request
89
from .route import Route
910

@@ -312,8 +313,10 @@ async def create_tag(
312313
self,
313314
channel_id: int,
314315
name: str,
316+
moderated: bool = False,
315317
emoji_id: Optional[int] = None,
316318
emoji_name: Optional[str] = None,
319+
reason: Optional[str] = None,
317320
) -> dict:
318321
"""
319322
Create a new tag.
@@ -324,25 +327,41 @@ async def create_tag(
324327
325328
:param channel_id: Channel ID snowflake.
326329
:param name: The name of the tag
330+
:param moderated: Whether the tag can only be assigned to moderators or not. Defaults to ``False``
327331
:param emoji_id: The ID of the emoji to use for the tag
328332
:param emoji_name: The name of the emoji to use for the tag
333+
:param reason: The reason for the creating the tag, if any.
334+
:return: A Forum tag.
329335
"""
330336

331-
_dct = {"name": name}
337+
# This *assumes* cache is up-to-date.
338+
339+
_channel = self.cache[Channel].get(Snowflake(channel_id))
340+
_tags = [_._json for _ in _channel.available_tags] # list of tags in dict form
341+
342+
_dct = {"name": name, "moderated": moderated}
332343
if emoji_id:
333344
_dct["emoji_id"] = emoji_id
334345
if emoji_name:
335346
_dct["emoji_name"] = emoji_name
336347

337-
return await self._req.request(Route("POST", f"/channels/{channel_id}/tags"), json=_dct)
348+
_tags.append(_dct)
349+
350+
updated_channel = await self.modify_channel(
351+
channel_id, {"available_tags": _tags}, reason=reason
352+
)
353+
_channel_obj = Channel(**updated_channel, _client=self)
354+
return _channel_obj.available_tags[-1]._json
338355

339356
async def edit_tag(
340357
self,
341358
channel_id: int,
342359
tag_id: int,
343360
name: str,
361+
moderated: Optional[bool] = None,
344362
emoji_id: Optional[int] = None,
345363
emoji_name: Optional[str] = None,
364+
reason: Optional[str] = None,
346365
) -> dict:
347366
"""
348367
Update a tag.
@@ -351,28 +370,62 @@ async def edit_tag(
351370
Can either have an emoji_id or an emoji_name, but not both.
352371
emoji_id is meant for custom emojis, emoji_name is meant for unicode emojis.
353372
373+
The object returns *will* have a different tag ID.
374+
354375
:param channel_id: Channel ID snowflake.
355376
:param tag_id: The ID of the tag to update.
377+
:param moderated: Whether the tag can only be assigned to moderators or not. Defaults to ``False``
356378
:param name: The new name of the tag
357379
:param emoji_id: The ID of the emoji to use for the tag
358380
:param emoji_name: The name of the emoji to use for the tag
381+
:param reason: The reason for deleting the tag, if any.
382+
383+
:return The updated tag object.
359384
"""
360385

361-
_dct = {"name": name}
386+
# This *assumes* cache is up-to-date.
387+
388+
_channel = self.cache[Channel].get(Snowflake(channel_id))
389+
_tags = [_._json for _ in _channel.available_tags] # list of tags in dict form
390+
391+
_old_tag = [tag for tag in _tags if tag["id"] == tag_id][0]
392+
393+
_tags.remove(_old_tag)
394+
395+
_dct = {"name": name, "tag_id": tag_id}
396+
if moderated:
397+
_dct["moderated"] = moderated
362398
if emoji_id:
363399
_dct["emoji_id"] = emoji_id
364400
if emoji_name:
365401
_dct["emoji_name"] = emoji_name
366402

367-
return await self._req.request(
368-
Route("PUT", f"/channels/{channel_id}/tags/{tag_id}"), json=_dct
403+
_tags.append(_dct)
404+
405+
updated_channel = await self.modify_channel(
406+
channel_id, {"available_tags": _tags}, reason=reason
369407
)
408+
_channel_obj = Channel(**updated_channel, _client=self)
409+
410+
self.cache[Channel].merge(_channel_obj)
411+
412+
return [tag for tag in _channel_obj.available_tags if tag.name == name][0]
370413

371-
async def delete_tag(self, channel_id: int, tag_id: int) -> None: # wha?
414+
async def delete_tag(self, channel_id: int, tag_id: int, reason: Optional[str] = None) -> None:
372415
"""
373416
Delete a forum tag.
374417
375418
:param channel_id: Channel ID snowflake.
376419
:param tag_id: The ID of the tag to delete
420+
:param reason: The reason for deleting the tag, if any.
377421
"""
378-
return await self._req.request(Route("DELETE", f"/channels/{channel_id}/tags/{tag_id}"))
422+
_channel = self.cache[Channel].get(Snowflake(channel_id))
423+
_tags = [_._json for _ in _channel.available_tags]
424+
425+
_old_tag = [tag for tag in _tags if tag["id"] == Snowflake(tag_id)][0]
426+
427+
_tags.remove(_old_tag)
428+
429+
request = await self.modify_channel(channel_id, {"available_tags": _tags}, reason=reason)
430+
431+
self.cache[Channel].merge(Channel(**request, _client=self))

interactions/api/http/thread.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ async def create_thread(
159159
reason: Optional[str] = None,
160160
) -> dict:
161161
"""
162-
From a given channel, create a Thread with an optional message to start with..
162+
From a given channel, create a Thread with an optional message to start with.
163163
164164
:param channel_id: The ID of the channel to create this thread in
165165
:param name: The name of the thread
@@ -212,7 +212,7 @@ async def create_thread_in_forum(
212212
:param name: The name of the thread
213213
:param auto_archive_duration: duration in minutes to automatically archive the thread after recent activity,
214214
can be set to: 60, 1440, 4320, 10080
215-
:param message_payload: The payload/dictionary contents of the first message in the forum thread.
215+
:param message: The payload/dictionary contents of the first message in the forum thread.
216216
:param applied_tags: List of tag ids that can be applied to the forum, if any.
217217
:param files: An optional list of files to send attached to the message.
218218
:param rate_limit_per_user: Seconds a user has to wait before sending another message (0 to 21600), if given.

interactions/client/bot.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,13 @@ def __init__(
8080
self._loop: AbstractEventLoop = get_event_loop()
8181
self._http: HTTPClient = token
8282
self._intents: Intents = kwargs.get("intents", Intents.DEFAULT)
83-
self._websocket: WSClient = WSClient(token=token, intents=self._intents)
8483
self._shards: List[Tuple[int]] = kwargs.get("shards", [])
8584
self._commands: List[Command] = []
8685
self._default_scope = kwargs.get("default_scope")
8786
self._presence = kwargs.get("presence")
87+
self._websocket: WSClient = WSClient(
88+
token=token, intents=self._intents, shards=self._shards, presence=self._presence
89+
)
8890
self._token = token
8991
self._extensions = {}
9092
self._scopes = set([])

0 commit comments

Comments
 (0)