Skip to content

Commit e3b9419

Browse files
authored
feat: rate limit improvements (#1321)
* fix: refactor all http routes to generate buckets properly * fix: resolve routes regex missed * feat: allow concurrent api calls from the same bucket * feat: restore bucketLock.locked property * feat: further bucketlock improvements * feat: allow bucketlock blocking to be toggled
1 parent c987e31 commit e3b9419

File tree

13 files changed

+472
-186
lines changed

13 files changed

+472
-186
lines changed

interactions/api/http/http_client.py

Lines changed: 86 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -95,61 +95,105 @@ async def wait(self) -> None:
9595

9696

9797
class BucketLock:
98-
"""Manages the ratelimit for each bucket"""
99-
100-
def __init__(self) -> None:
101-
self._lock: asyncio.Lock = asyncio.Lock()
98+
"""Manages the rate limit for each bucket."""
99+
100+
DEFAULT_LIMIT = 1
101+
DEFAULT_REMAINING = 1
102+
DEFAULT_DELTA = 0.0
103+
104+
def __init__(self, header: CIMultiDictProxy | None = None) -> None:
105+
self._semaphore: asyncio.Semaphore | None = None
106+
if header is None:
107+
self.bucket_hash: str | None = None
108+
self.limit: int = self.DEFAULT_LIMIT
109+
self.remaining: int = self.DEFAULT_REMAINING
110+
self.delta: float = self.DEFAULT_DELTA
111+
else:
112+
self.ingest_ratelimit_header(header)
102113

103-
self.unlock_on_exit: bool = True
114+
self.logger = constants.get_logger()
104115

105-
self.bucket_hash: str | None = None
106-
self.limit: int = -1
107-
self.remaining: int = -1
108-
self.delta: float = 0.0
116+
self._lock: asyncio.Lock = asyncio.Lock()
109117

110118
def __repr__(self) -> str:
111-
return f"<BucketLock: {self.bucket_hash or 'Generic'}>"
119+
return f"<BucketLock: {self.bucket_hash or 'Generic'}, limit: {self.limit}, remaining: {self.remaining}, delta: {self.delta}>"
112120

113121
@property
114122
def locked(self) -> bool:
115-
"""Return True if lock is acquired."""
116-
return self._lock.locked()
117-
118-
def unlock(self) -> None:
119-
"""Unlock this bucket."""
120-
self._lock.release()
123+
"""Returns whether the bucket is locked."""
124+
if self._lock.locked():
125+
return True
126+
return self._semaphore is not None and self._semaphore.locked()
121127

122128
def ingest_ratelimit_header(self, header: CIMultiDictProxy) -> None:
123129
"""
124-
Ingests a discord rate limit header to configure this bucket lock.
130+
Ingests the rate limit header.
125131
126132
Args:
127-
header: A header from a http response
133+
header: The header to ingest, containing rate limit information.
134+
135+
Updates the bucket_hash, limit, remaining, and delta attributes with the information from the header.
128136
"""
129137
self.bucket_hash = header.get("x-ratelimit-bucket")
130-
self.limit = int(header.get("x-ratelimit-limit") or -1)
131-
self.remaining = int(header.get("x-ratelimit-remaining") or -1)
132-
self.delta = float(header.get("x-ratelimit-reset-after", 0.0))
133-
134-
async def blind_defer_unlock(self) -> None:
135-
"""Unlocks the BucketLock but doesn't wait for completion."""
136-
self.unlock_on_exit = False
137-
loop = asyncio.get_running_loop()
138-
loop.call_later(self.delta, self.unlock)
139-
140-
async def defer_unlock(self, reset_after: float | None = None) -> None:
141-
"""Unlocks the BucketLock after a specified delay."""
142-
self.unlock_on_exit = False
143-
await asyncio.sleep(reset_after or self.delta)
144-
self.unlock()
138+
self.limit = int(header.get("x-ratelimit-limit", self.DEFAULT_LIMIT))
139+
self.remaining = int(header.get("x-ratelimit-remaining", self.DEFAULT_REMAINING))
140+
self.delta = float(header.get("x-ratelimit-reset-after", self.DEFAULT_DELTA))
141+
142+
if self._semaphore is None or self._semaphore._value != self.limit:
143+
self._semaphore = asyncio.Semaphore(self.limit)
144+
145+
async def acquire(self) -> None:
146+
"""Acquires the semaphore."""
147+
if self._semaphore is None:
148+
return
149+
150+
if self._lock.locked():
151+
self.logger.debug(f"Waiting for bucket {self.bucket_hash} to unlock.")
152+
async with self._lock:
153+
pass
154+
155+
await self._semaphore.acquire()
156+
157+
def release(self) -> None:
158+
"""
159+
Releases the semaphore.
160+
161+
Note: If the bucket has been locked with lock_for_duration, this will not release the lock.
162+
"""
163+
if self._semaphore is None:
164+
return
165+
self._semaphore.release()
166+
167+
async def lock_for_duration(self, duration: float, block: bool = False) -> None:
168+
"""
169+
Locks the bucket for a given duration.
170+
171+
Args:
172+
duration: The duration to lock the bucket for.
173+
block: Whether to block until the bucket is unlocked.
174+
175+
Raises:
176+
RuntimeError: If the bucket is already locked.
177+
"""
178+
if self._lock.locked():
179+
raise RuntimeError("Attempted to lock a bucket that is already locked.")
180+
181+
async def _release() -> None:
182+
await asyncio.sleep(duration)
183+
self._lock.release()
184+
185+
if block:
186+
await self._lock.acquire()
187+
await _release()
188+
else:
189+
await self._lock.acquire()
190+
asyncio.create_task(_release())
145191

146192
async def __aenter__(self) -> None:
147-
await self._lock.acquire()
193+
await self.acquire()
148194

149195
async def __aexit__(self, *args) -> None:
150-
if self.unlock_on_exit and self._lock.locked():
151-
self.unlock()
152-
self.unlock_on_exit = True
196+
self.release()
153197

154198

155199
class HTTPClient(
@@ -363,7 +407,7 @@ async def request(
363407
f"Reset in {result.get('retry_after')} seconds",
364408
)
365409
# lock this resource and wait for unlock
366-
await lock.defer_unlock(float(result["retry_after"]))
410+
await lock.lock_for_duration(float(result["retry_after"]), block=True)
367411
else:
368412
# endpoint ratelimit is reached
369413
# 429's are unfortunately unavoidable, but we can attempt to avoid them
@@ -372,15 +416,17 @@ async def request(
372416
self.logger.warning,
373417
f"{route.endpoint} Has exceeded it's ratelimit ({lock.limit})! Reset in {lock.delta} seconds",
374418
)
375-
await lock.defer_unlock() # lock this route and wait for unlock
419+
await lock.lock_for_duration(lock.delta, block=True)
376420
continue
377421
if lock.remaining == 0:
378422
# Last call available in the bucket, lock until reset
379423
self.log_ratelimit(
380424
self.logger.debug,
381425
f"{route.endpoint} Has exhausted its ratelimit ({lock.limit})! Locking route for {lock.delta} seconds",
382426
)
383-
await lock.blind_defer_unlock() # lock this route, but continue processing the current response
427+
await lock.lock_for_duration(
428+
lock.delta
429+
) # lock this route, but continue processing the current response
384430

385431
elif response.status in {500, 502, 504}:
386432
# Server issues, retry

interactions/api/http/http_requests/channels.py

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ async def get_channel(self, channel_id: "Snowflake_Type") -> discord_typings.Cha
3333
channel
3434
3535
"""
36-
result = await self.request(Route("GET", f"/channels/{int(channel_id)}"))
36+
result = await self.request(Route("GET", "/channels/{channel_id}", channel_id=channel_id))
3737
return cast(discord_typings.ChannelData, result)
3838

3939
@overload
@@ -109,7 +109,9 @@ async def get_channel_messages(
109109
}
110110
params = dict_filter_none(params)
111111

112-
result = await self.request(Route("GET", f"/channels/{int(channel_id)}/messages"), params=params)
112+
result = await self.request(
113+
Route("GET", "/channels/{channel_id}/messages", channel_id=channel_id), params=params
114+
)
113115
return cast(list[discord_typings.MessageData], result)
114116

115117
async def create_guild_channel(
@@ -168,7 +170,9 @@ async def create_guild_channel(
168170
)
169171
payload = dict_filter_none(payload)
170172

171-
result = await self.request(Route("POST", f"/guilds/{int(guild_id)}/channels"), payload=payload, reason=reason)
173+
result = await self.request(
174+
Route("POST", "/guilds/{guild_id}/channels", guild_id=guild_id), payload=payload, reason=reason
175+
)
172176
return cast(discord_typings.ChannelData, result)
173177

174178
async def move_channel(
@@ -200,7 +204,9 @@ async def move_channel(
200204
}
201205
payload = dict_filter_none(payload)
202206

203-
await self.request(Route("PATCH", f"/guilds/{int(guild_id)}/channels"), payload=payload, reason=reason)
207+
await self.request(
208+
Route("PATCH", "/guilds/{guild_id}/channels", guild_id=guild_id), payload=payload, reason=reason
209+
)
204210

205211
async def modify_channel(
206212
self, channel_id: "Snowflake_Type", data: dict, reason: str | None = None
@@ -217,7 +223,9 @@ async def modify_channel(
217223
Channel object on success
218224
219225
"""
220-
result = await self.request(Route("PATCH", f"/channels/{int(channel_id)}"), payload=data, reason=reason)
226+
result = await self.request(
227+
Route("PATCH", "/channels/{channel_id}", channel_id=channel_id), payload=data, reason=reason
228+
)
221229
return cast(discord_typings.ChannelData, result)
222230

223231
async def delete_channel(self, channel_id: "Snowflake_Type", reason: str | None = None) -> None:
@@ -229,7 +237,7 @@ async def delete_channel(self, channel_id: "Snowflake_Type", reason: str | None
229237
reason: An optional reason for the audit log
230238
231239
"""
232-
await self.request(Route("DELETE", f"/channels/{int(channel_id)}"), reason=reason)
240+
await self.request(Route("DELETE", "/channels/{channel_id}", channel_id=channel_id), reason=reason)
233241

234242
async def get_channel_invites(self, channel_id: "Snowflake_Type") -> list[discord_typings.InviteData]:
235243
"""
@@ -242,7 +250,7 @@ async def get_channel_invites(self, channel_id: "Snowflake_Type") -> list[discor
242250
List of invite objects
243251
244252
"""
245-
result = await self.request(Route("GET", f"/channels/{int(channel_id)}/invites"))
253+
result = await self.request(Route("GET", "/channels/{channel_id}/invites", channel_id=channel_id))
246254
return cast(list[discord_typings.InviteData], result)
247255

248256
@overload
@@ -336,7 +344,7 @@ async def create_channel_invite(
336344
payload = dict_filter_none(payload)
337345

338346
result = await self.request(
339-
Route("POST", f"/channels/{int(channel_id)}/invites"), payload=payload, reason=reason
347+
Route("POST", "/channels/{channel_id}/invites", channel_id=channel_id), payload=payload, reason=reason
340348
)
341349
return cast(discord_typings.InviteData, result)
342350

@@ -361,13 +369,14 @@ async def get_invite(
361369
362370
"""
363371
params: PAYLOAD_TYPE = {
372+
"invite_code": invite_code,
364373
"with_counts": with_counts,
365374
"with_expiration": with_expiration,
366375
"guild_scheduled_event_id": int(scheduled_event_id) if scheduled_event_id else None,
367376
}
368377
params = dict_filter_none(params)
369378

370-
result = await self.request(Route("GET", f"/invites/{invite_code}", params=params))
379+
result = await self.request(Route("GET", "/invites/{invite_code}", params=params))
371380
return cast(discord_typings.InviteData, result)
372381

373382
async def delete_invite(self, invite_code: str, reason: str | None = None) -> discord_typings.InviteData:
@@ -382,7 +391,7 @@ async def delete_invite(self, invite_code: str, reason: str | None = None) -> di
382391
The deleted invite object
383392
384393
"""
385-
result = await self.request(Route("DELETE", f"/invites/{invite_code}"), reason=reason)
394+
result = await self.request(Route("DELETE", "/invites/{invite_code}", invite_code=invite_code), reason=reason)
386395
return cast(discord_typings.InviteData, result)
387396

388397
async def edit_channel_permission(
@@ -409,7 +418,12 @@ async def edit_channel_permission(
409418
payload: PAYLOAD_TYPE = {"allow": allow, "deny": deny, "type": perm_type}
410419

411420
await self.request(
412-
Route("PUT", f"/channels/{int(channel_id)}/permissions/{int(overwrite_id)}"),
421+
Route(
422+
"PUT",
423+
"/channels/{channel_id}/permissions/{overwrite_id}",
424+
channel_id=channel_id,
425+
overwrite_id=overwrite_id,
426+
),
413427
payload=payload,
414428
reason=reason,
415429
)
@@ -429,7 +443,10 @@ async def delete_channel_permission(
429443
reason: An optional reason for the audit log
430444
431445
"""
432-
await self.request(Route("DELETE", f"/channels/{int(channel_id)}/{int(overwrite_id)}"), reason=reason)
446+
await self.request(
447+
Route("DELETE", "/channels/{channel_id}/{overwrite_id}", channel_id=channel_id, overwrite_id=overwrite_id),
448+
reason=reason,
449+
)
433450

434451
async def follow_news_channel(
435452
self, channel_id: "Snowflake_Type", webhook_channel_id: "Snowflake_Type"
@@ -447,7 +464,9 @@ async def follow_news_channel(
447464
"""
448465
payload = {"webhook_channel_id": int(webhook_channel_id)}
449466

450-
result = await self.request(Route("POST", f"/channels/{int(channel_id)}/followers"), payload=payload)
467+
result = await self.request(
468+
Route("POST", "/channels/{channel_id}/followers", channel_id=channel_id), payload=payload
469+
)
451470
return cast(discord_typings.FollowedChannelData, result)
452471

453472
async def trigger_typing_indicator(self, channel_id: "Snowflake_Type") -> None:
@@ -458,7 +477,7 @@ async def trigger_typing_indicator(self, channel_id: "Snowflake_Type") -> None:
458477
channel_id: The id of the channel to "type" in
459478
460479
"""
461-
await self.request(Route("POST", f"/channels/{int(channel_id)}/typing"))
480+
await self.request(Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id))
462481

463482
async def get_pinned_messages(self, channel_id: "Snowflake_Type") -> list[discord_typings.MessageData]:
464483
"""
@@ -471,7 +490,7 @@ async def get_pinned_messages(self, channel_id: "Snowflake_Type") -> list[discor
471490
A list of pinned message objects
472491
473492
"""
474-
result = await self.request(Route("GET", f"/channels/{int(channel_id)}/pins"))
493+
result = await self.request(Route("GET", "/channels/{channel_id}/pins", channel_id=channel_id))
475494
return cast(list[discord_typings.MessageData], result)
476495

477496
async def create_stage_instance(
@@ -514,7 +533,7 @@ async def get_stage_instance(self, channel_id: "Snowflake_Type") -> discord_typi
514533
A stage instance.
515534
516535
"""
517-
result = await self.request(Route("GET", f"/stage-instances/{int(channel_id)}"))
536+
result = await self.request(Route("GET", "/stage-instances/{channel_id}", channel_id=channel_id))
518537
return cast(discord_typings.StageInstanceData, result)
519538

520539
async def modify_stage_instance(
@@ -540,7 +559,7 @@ async def modify_stage_instance(
540559
payload: PAYLOAD_TYPE = {"topic": topic, "privacy_level": privacy_level}
541560
payload = dict_filter_none(payload)
542561
result = await self.request(
543-
Route("PATCH", f"/stage-instances/{int(channel_id)}"), payload=payload, reason=reason
562+
Route("PATCH", "/stage-instances/{channel_id}", channel_id=channel_id), payload=payload, reason=reason
544563
)
545564
return cast(discord_typings.StageInstanceData, result)
546565

@@ -553,7 +572,7 @@ async def delete_stage_instance(self, channel_id: "Snowflake_Type", reason: str
553572
reason: The reason for the deletion
554573
555574
"""
556-
await self.request(Route("DELETE", f"/stage-instances/{int(channel_id)}"), reason=reason)
575+
await self.request(Route("DELETE", "/stage-instances/{channel_id}", channel_id=channel_id), reason=reason)
557576

558577
async def create_tag(
559578
self,
@@ -582,7 +601,9 @@ async def create_tag(
582601
}
583602
payload = dict_filter_none(payload)
584603

585-
result = await self.request(Route("POST", f"/channels/{int(channel_id)}/tags"), payload=payload)
604+
result = await self.request(
605+
Route("POST", "/channels/{channel_id}/tags", channel_id=channel_id), payload=payload
606+
)
586607
return cast(discord_typings.ChannelData, result)
587608

588609
async def edit_tag(
@@ -614,7 +635,9 @@ async def edit_tag(
614635
}
615636
payload = dict_filter_none(payload)
616637

617-
result = await self.request(Route("PUT", f"/channels/{int(channel_id)}/tags/{int(tag_id)}"), payload=payload)
638+
result = await self.request(
639+
Route("PUT", "/channels/{channel_id}/tags/{tag_id}", channel_id=channel_id, tag_id=tag_id), payload=payload
640+
)
618641
return cast(discord_typings.ChannelData, result)
619642

620643
async def delete_tag(self, channel_id: "Snowflake_Type", tag_id: "Snowflake_Type") -> discord_typings.ChannelData:
@@ -625,5 +648,7 @@ async def delete_tag(self, channel_id: "Snowflake_Type", tag_id: "Snowflake_Type
625648
channel_id: The ID of the forum channel to delete tag it.
626649
tag_id: The ID of the tag to delete
627650
"""
628-
result = await self.request(Route("DELETE", f"/channels/{int(channel_id)}/tags/{int(tag_id)}"))
651+
result = await self.request(
652+
Route("DELETE", "/channels/{channel_id}/tags/{tag_id}", channel_id=channel_id, tag_id=tag_id)
653+
)
629654
return cast(discord_typings.ChannelData, result)

0 commit comments

Comments
 (0)