Skip to content

Commit 47319ed

Browse files
authored
feat: caching improvements (#1350)
* feat: add support for force fetching * feat: track if a user object has been fetched * feat: add force flag to client helper methods * feat: update all cache fetch methods to have a force param
1 parent 043d79e commit 47319ed

File tree

11 files changed

+156
-81
lines changed

11 files changed

+156
-81
lines changed

interactions/client/client.py

Lines changed: 19 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1995,7 +1995,7 @@ def reload_extension(
19951995
sys.modules.pop(name, None)
19961996
raise ex from e
19971997

1998-
async def fetch_guild(self, guild_id: "Snowflake_Type") -> Optional[Guild]:
1998+
async def fetch_guild(self, guild_id: "Snowflake_Type", *, force: bool = False) -> Optional[Guild]:
19991999
"""
20002000
Fetch a guild.
20012001
@@ -2005,13 +2005,14 @@ async def fetch_guild(self, guild_id: "Snowflake_Type") -> Optional[Guild]:
20052005
20062006
Args:
20072007
guild_id: The ID of the guild to get
2008+
force: Whether to poll the API regardless of cache
20082009
20092010
Returns:
20102011
Guild Object if found, otherwise None
20112012
20122013
"""
20132014
try:
2014-
return await self.cache.fetch_guild(guild_id)
2015+
return await self.cache.fetch_guild(guild_id, force=force)
20152016
except NotFound:
20162017
return None
20172018

@@ -2060,7 +2061,7 @@ async def create_guild_from_template(
20602061
guild_data = await self.http.create_guild_from_guild_template(template_code, name, icon)
20612062
return Guild.from_dict(guild_data, self)
20622063

2063-
async def fetch_channel(self, channel_id: "Snowflake_Type") -> Optional["TYPE_ALL_CHANNEL"]:
2064+
async def fetch_channel(self, channel_id: "Snowflake_Type", *, force: bool = False) -> Optional["TYPE_ALL_CHANNEL"]:
20642065
"""
20652066
Fetch a channel.
20662067
@@ -2070,13 +2071,14 @@ async def fetch_channel(self, channel_id: "Snowflake_Type") -> Optional["TYPE_AL
20702071
20712072
Args:
20722073
channel_id: The ID of the channel to get
2074+
force: Whether to poll the API regardless of cache
20732075
20742076
Returns:
20752077
Channel Object if found, otherwise None
20762078
20772079
"""
20782080
try:
2079-
return await self.cache.fetch_channel(channel_id)
2081+
return await self.cache.fetch_channel(channel_id, force=force)
20802082
except NotFound:
20812083
return None
20822084

@@ -2096,7 +2098,7 @@ def get_channel(self, channel_id: "Snowflake_Type") -> Optional["TYPE_ALL_CHANNE
20962098
"""
20972099
return self.cache.get_channel(channel_id)
20982100

2099-
async def fetch_user(self, user_id: "Snowflake_Type") -> Optional[User]:
2101+
async def fetch_user(self, user_id: "Snowflake_Type", *, force: bool = False) -> Optional[User]:
21002102
"""
21012103
Fetch a user.
21022104
@@ -2106,13 +2108,14 @@ async def fetch_user(self, user_id: "Snowflake_Type") -> Optional[User]:
21062108
21072109
Args:
21082110
user_id: The ID of the user to get
2111+
force: Whether to poll the API regardless of cache
21092112
21102113
Returns:
21112114
User Object if found, otherwise None
21122115
21132116
"""
21142117
try:
2115-
return await self.cache.fetch_user(user_id)
2118+
return await self.cache.fetch_user(user_id, force=force)
21162119
except NotFound:
21172120
return None
21182121

@@ -2132,7 +2135,9 @@ def get_user(self, user_id: "Snowflake_Type") -> Optional[User]:
21322135
"""
21332136
return self.cache.get_user(user_id)
21342137

2135-
async def fetch_member(self, user_id: "Snowflake_Type", guild_id: "Snowflake_Type") -> Optional[Member]:
2138+
async def fetch_member(
2139+
self, user_id: "Snowflake_Type", guild_id: "Snowflake_Type", *, force: bool = False
2140+
) -> Optional[Member]:
21362141
"""
21372142
Fetch a member from a guild.
21382143
@@ -2143,13 +2148,14 @@ async def fetch_member(self, user_id: "Snowflake_Type", guild_id: "Snowflake_Typ
21432148
Args:
21442149
user_id: The ID of the member
21452150
guild_id: The ID of the guild to get the member from
2151+
force: Whether to poll the API regardless of cache
21462152
21472153
Returns:
21482154
Member object if found, otherwise None
21492155
21502156
"""
21512157
try:
2152-
return await self.cache.fetch_member(guild_id, user_id)
2158+
return await self.cache.fetch_member(guild_id, user_id, force=force)
21532159
except NotFound:
21542160
return None
21552161

@@ -2194,20 +2200,23 @@ async def fetch_scheduled_event(
21942200
except NotFound:
21952201
return None
21962202

2197-
async def fetch_custom_emoji(self, emoji_id: "Snowflake_Type", guild_id: "Snowflake_Type") -> Optional[CustomEmoji]:
2203+
async def fetch_custom_emoji(
2204+
self, emoji_id: "Snowflake_Type", guild_id: "Snowflake_Type", *, force: bool = False
2205+
) -> Optional[CustomEmoji]:
21982206
"""
21992207
Fetch a custom emoji by id.
22002208
22012209
Args:
22022210
emoji_id: The id of the custom emoji.
22032211
guild_id: The id of the guild the emoji belongs to.
2212+
force: Whether to poll the API regardless of cache.
22042213
22052214
Returns:
22062215
The custom emoji if found, otherwise None.
22072216
22082217
"""
22092218
try:
2210-
return await self.cache.fetch_emoji(guild_id, emoji_id)
2219+
return await self.cache.fetch_emoji(guild_id, emoji_id, force=force)
22112220
except NotFound:
22122221
return None
22132222

interactions/client/smart_cache.py

Lines changed: 34 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -99,12 +99,13 @@ def __attrs_post_init__(self) -> None:
9999

100100
# region User cache
101101

102-
async def fetch_user(self, user_id: "Snowflake_Type") -> User:
102+
async def fetch_user(self, user_id: "Snowflake_Type", *, force: bool = False) -> User:
103103
"""
104104
Fetch a user by their ID.
105105
106106
Args:
107107
user_id: The user's ID
108+
force: If the cache should be ignored, and the user should be fetched from the API
108109
109110
Returns:
110111
User object if found
@@ -113,9 +114,10 @@ async def fetch_user(self, user_id: "Snowflake_Type") -> User:
113114
user_id = to_snowflake(user_id)
114115

115116
user = self.user_cache.get(user_id)
116-
if user is None:
117+
if (user is None or user._fetched is False) or force:
117118
data = await self._client.http.get_user(user_id)
118119
user = self.place_user_data(data)
120+
user._fetched = True # the user object should set this to True, but we do it here just in case
119121
return user
120122

121123
def get_user(self, user_id: Optional["Snowflake_Type"]) -> Optional[User]:
@@ -164,13 +166,16 @@ def delete_user(self, user_id: "Snowflake_Type") -> None:
164166

165167
# region Member cache
166168

167-
async def fetch_member(self, guild_id: "Snowflake_Type", user_id: "Snowflake_Type") -> Member:
169+
async def fetch_member(
170+
self, guild_id: "Snowflake_Type", user_id: "Snowflake_Type", *, force: bool = False
171+
) -> Member:
168172
"""
169173
Fetch a member by their guild and user IDs.
170174
171175
Args:
172176
guild_id: The ID of the guild this user belongs to
173177
user_id: The ID of the user
178+
force: If the cache should be ignored, and the member should be fetched from the API
174179
175180
Returns:
176181
Member object if found
@@ -179,7 +184,7 @@ async def fetch_member(self, guild_id: "Snowflake_Type", user_id: "Snowflake_Typ
179184
guild_id = to_snowflake(guild_id)
180185
user_id = to_snowflake(user_id)
181186
member = self.member_cache.get((guild_id, user_id))
182-
if member is None:
187+
if member is None or force:
183188
data = await self._client.http.get_member(guild_id, user_id)
184189
member = self.place_member_data(guild_id, data)
185190
return member
@@ -323,15 +328,13 @@ async def is_user_in_guild(
323328

324329
return False
325330

326-
async def fetch_user_guild_ids(
327-
self,
328-
user_id: "Snowflake_Type",
329-
) -> List["Snowflake_Type"]:
331+
async def fetch_user_guild_ids(self, user_id: "Snowflake_Type") -> List["Snowflake_Type"]:
330332
"""
331333
Fetch a list of IDs for the guilds a user has joined.
332334
333335
Args:
334336
user_id: The ID of the user
337+
335338
Returns:
336339
A list of snowflakes for the guilds the client can see the user is within
337340
"""
@@ -361,16 +364,15 @@ def get_user_guild_ids(self, user_id: "Snowflake_Type") -> List["Snowflake_Type"
361364
# region Message cache
362365

363366
async def fetch_message(
364-
self,
365-
channel_id: "Snowflake_Type",
366-
message_id: "Snowflake_Type",
367+
self, channel_id: "Snowflake_Type", message_id: "Snowflake_Type", *, force: bool = False
367368
) -> Message:
368369
"""
369370
Fetch a message from a channel based on their IDs.
370371
371372
Args:
372373
channel_id: The ID of the channel the message is in
373374
message_id: The ID of the message
375+
force: If the cache should be ignored, and the message should be fetched from the API
374376
375377
Returns:
376378
The message if found
@@ -379,7 +381,7 @@ async def fetch_message(
379381
message_id = to_snowflake(message_id)
380382
message = self.message_cache.get((channel_id, message_id))
381383

382-
if message is None:
384+
if message is None or force:
383385
data = await self._client.http.get_message(channel_id, message_id)
384386
message = self.place_message_data(data)
385387
if message.channel is None:
@@ -437,22 +439,20 @@ def delete_message(self, channel_id: "Snowflake_Type", message_id: "Snowflake_Ty
437439
# endregion Message cache
438440

439441
# region Channel cache
440-
async def fetch_channel(
441-
self,
442-
channel_id: "Snowflake_Type",
443-
) -> "TYPE_ALL_CHANNEL":
442+
async def fetch_channel(self, channel_id: "Snowflake_Type", *, force: bool = False) -> "TYPE_ALL_CHANNEL":
444443
"""
445444
Get a channel based on its ID.
446445
447446
Args:
448447
channel_id: The ID of the channel
448+
force: If the cache should be ignored, and the channel should be fetched from the API
449449
450450
Returns:
451451
The channel if found
452452
"""
453453
channel_id = to_snowflake(channel_id)
454454
channel = self.channel_cache.get(channel_id)
455-
if channel is None:
455+
if channel is None or force:
456456
try:
457457
data = await self._client.http.get_channel(channel_id)
458458
channel = self.place_channel_data(data)
@@ -518,31 +518,33 @@ def place_dm_channel_id(self, user_id: "Snowflake_Type", channel_id: "Snowflake_
518518
"""
519519
self.dm_channels[to_snowflake(user_id)] = to_snowflake(channel_id)
520520

521-
async def fetch_dm_channel_id(self, user_id: "Snowflake_Type") -> "Snowflake_Type":
521+
async def fetch_dm_channel_id(self, user_id: "Snowflake_Type", *, force: bool = False) -> "Snowflake_Type":
522522
"""
523523
Get the DM channel ID for a user.
524524
525525
Args:
526526
user_id: The ID of the user
527+
force: If the cache should be ignored, and the channel should be fetched from the API
527528
"""
528529
user_id = to_snowflake(user_id)
529530
channel_id = self.dm_channels.get(user_id)
530-
if channel_id is None:
531+
if channel_id is None or force:
531532
data = await self._client.http.create_dm(user_id)
532533
channel = self.place_channel_data(data)
533534
channel_id = channel.id
534535
return channel_id
535536

536-
async def fetch_dm_channel(self, user_id: "Snowflake_Type") -> "DM":
537+
async def fetch_dm_channel(self, user_id: "Snowflake_Type", *, force: bool = False) -> "DM":
537538
"""
538539
Fetch the DM channel for a user.
539540
540541
Args:
541542
user_id: The ID of the user
543+
force: If the cache should be ignored, and the channel should be fetched from the API
542544
"""
543545
user_id = to_snowflake(user_id)
544-
channel_id = await self.fetch_dm_channel_id(user_id)
545-
return await self.fetch_channel(channel_id)
546+
channel_id = await self.fetch_dm_channel_id(user_id, force=force)
547+
return await self.fetch_channel(channel_id, force=force)
546548

547549
def get_dm_channel(self, user_id: Optional["Snowflake_Type"]) -> Optional["DM"]:
548550
"""
@@ -575,19 +577,20 @@ def delete_channel(self, channel_id: "Snowflake_Type") -> None:
575577

576578
# region Guild cache
577579

578-
async def fetch_guild(self, guild_id: "Snowflake_Type") -> Guild:
580+
async def fetch_guild(self, guild_id: "Snowflake_Type", *, force: bool = False) -> Guild:
579581
"""
580582
Fetch a guild based on its ID.
581583
582584
Args:
583585
guild_id: The ID of the guild
586+
force: If the cache should be ignored, and the guild should be fetched from the API
584587
585588
Returns:
586589
The guild if found
587590
"""
588591
guild_id = to_snowflake(guild_id)
589592
guild = self.guild_cache.get(guild_id)
590-
if guild is None:
593+
if guild is None or force:
591594
data = await self._client.http.get_guild(guild_id)
592595
guild = self.place_guild_data(data)
593596
return guild
@@ -648,21 +651,24 @@ async def fetch_role(
648651
self,
649652
guild_id: "Snowflake_Type",
650653
role_id: "Snowflake_Type",
654+
*,
655+
force: bool = False,
651656
) -> Role:
652657
"""
653658
Fetch a role based on the guild and its own ID.
654659
655660
Args:
656661
guild_id: The ID of the guild this role belongs to
657662
role_id: The ID of the role
663+
force: If the cache should be ignored, and the role should be fetched from the API
658664
659665
Returns:
660666
The role if found
661667
"""
662668
guild_id = to_snowflake(guild_id)
663669
role_id = to_snowflake(role_id)
664670
role = self.role_cache.get(role_id)
665-
if role is None:
671+
if role is None or force:
666672
data = await self._client.http.get_roles(guild_id)
667673
role = self.place_role_data(guild_id, data).get(role_id)
668674
return role
@@ -830,9 +836,7 @@ def delete_bot_voice_state(self, guild_id: "Snowflake_Type") -> None:
830836
# region Emoji cache
831837

832838
async def fetch_emoji(
833-
self,
834-
guild_id: "Snowflake_Type",
835-
emoji_id: "Snowflake_Type",
839+
self, guild_id: "Snowflake_Type", emoji_id: "Snowflake_Type", *, force: bool = False
836840
) -> "CustomEmoji":
837841
"""
838842
Fetch an emoji based on the guild and its own ID.
@@ -842,14 +846,15 @@ async def fetch_emoji(
842846
Args:
843847
guild_id: The ID of the guild this emoji belongs to
844848
emoji_id: The ID of the emoji
849+
force: If the cache should be ignored, and the emoji should be fetched from the API
845850
846851
Returns:
847852
The Emoji if found
848853
"""
849854
guild_id = to_snowflake(guild_id)
850855
emoji_id = to_snowflake(emoji_id)
851856
emoji = self.emoji_cache.get(emoji_id) if self.emoji_cache is not None else None
852-
if emoji is None:
857+
if emoji is None or force:
853858
data = await self._client.http.get_guild_emoji(guild_id, emoji_id)
854859
emoji = self.place_emoji_data(guild_id, data)
855860

0 commit comments

Comments
 (0)