Skip to content

Commit 95a822c

Browse files
EepyElvyraToricane
andauthored
refactor: Modify the cache with a merge method to ensure no important data gets overwritten (#913)
* refactor: add merge method for cache * refactor: Optimize `_search_iterable`and use a list of tuples instead of a tuple of tuples * refactor: Do a guild updates in events * refactor: Add caching to several HTTP methods * Update interactions/api/http/guild.py Co-authored-by: Toricane <73972068+Toricane@users.noreply.github.com> * Update interactions/api/http/channel.py Co-authored-by: Toricane <73972068+Toricane@users.noreply.github.com> Co-authored-by: Toricane <73972068+Toricane@users.noreply.github.com>
1 parent da341e3 commit 95a822c

File tree

9 files changed

+161
-24
lines changed

9 files changed

+161
-24
lines changed

interactions/api/cache.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,39 @@ def __repr__(self) -> str:
4141
def __init__(self) -> None:
4242
self.values: Dict["Key", _T] = {}
4343

44+
def merge(self, item: _T, id: Optional["Key"] = None) -> None:
45+
"""
46+
Merges new data of an item into an already present item of the cache
47+
48+
:param item: The item to merge.
49+
:type item: Any
50+
:param id: The unique id of the item.
51+
:type id: Optional[Union[Snowflake, Tuple[Snowflake, Snowflake]]]
52+
"""
53+
if not self.values.get(id or item.id):
54+
return self.add(item, id)
55+
56+
_id = id or item.id
57+
old_item = self.values[_id]
58+
59+
for attrib in item.__slots__:
60+
if getattr(old_item, attrib) and not getattr(item, attrib):
61+
continue
62+
# we can only assume that discord did not provide it, falsely deleting is worse than not deleting
63+
if getattr(old_item, attrib) != getattr(item, attrib):
64+
if isinstance(item.attrib, list) and not isinstance(
65+
old_item.attrib, list
66+
): # could be None
67+
old_item.attrib = []
68+
if isinstance(old_item.attrib, list):
69+
for value in item.attrib:
70+
if value not in old_item.attrib:
71+
old_item.attrib.append(value)
72+
else:
73+
setattr(old_item, attrib, item.attrib)
74+
75+
self.values[_id] = old_item
76+
4477
def add(self, item: _T, id: Optional["Key"] = None) -> None:
4578
"""
4679
Adds a new item to the storage.
@@ -91,10 +124,7 @@ def pop(self, key: "Key", default: _P) -> Union[_T, _P]:
91124
...
92125

93126
def pop(self, key: "Key", default: Optional[_P] = None) -> Union[_T, _P, None]:
94-
try:
95-
return self.values.pop(key)
96-
except KeyError:
97-
return default
127+
return self.values.pop(key, default)
98128

99129
@property
100130
def view(self) -> List[dict]:

interactions/api/gateway/client.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from ..http.client import HTTPClient
2929
from ..models.attrs_utils import MISSING
3030
from ..models.flags import Intents
31+
from ..models.guild import Guild
3132
from ..models.member import Member
3233
from ..models.misc import Snowflake
3334
from ..models.presence import ClientPresence
@@ -400,7 +401,19 @@ def _dispatch_event(self, event: str, data: dict) -> None:
400401
id = getattr(obj, "id", None)
401402

402403
if "_create" in name or "_add" in name:
403-
_cache.add(obj, id)
404+
_cache.merge(obj, id)
405+
if guild_id := data.get("guild_id") and not isinstance(obj, Guild):
406+
guild = self._http.cache[Guild].get(Snowflake(guild_id))
407+
model_name = model.__name__.lower()
408+
_obj = getattr(guild, f"{model_name}s", None)
409+
if _obj is not None:
410+
if isinstance(_obj, list):
411+
_obj.append(obj)
412+
setattr(guild, f"{model_name}s", _obj)
413+
else:
414+
_obj = [obj]
415+
setattr(guild, f"{model_name}s", _obj)
416+
self._http.cache[Guild].add(guild)
404417
self._dispatch.dispatch(f"on_{name}", obj)
405418

406419
elif "_update" in name and hasattr(obj, "id"):
@@ -415,6 +428,23 @@ def _dispatch_event(self, event: str, data: dict) -> None:
415428

416429
_cache.add(old_obj, id)
417430

431+
if guild_id := data.get("guild_id") and not isinstance(obj, Guild):
432+
guild = self._http.cache[Guild].get(Snowflake(guild_id))
433+
model_name = model.__name__.lower()
434+
_obj = getattr(guild, f"{model_name}s", None)
435+
if _obj is not None:
436+
if isinstance(_obj, list):
437+
for __obj in _obj:
438+
if __obj.id == obj.id:
439+
_obj.remove(__obj)
440+
break
441+
_obj.append(obj)
442+
setattr(guild, f"{model_name}s", _obj)
443+
else:
444+
_obj = [obj]
445+
setattr(guild, f"{model_name}s", _obj)
446+
self._http.cache[Guild].add(guild)
447+
418448
self._dispatch.dispatch(
419449
f"on_{name}", before, old_obj
420450
) # give previously stored and new one
@@ -423,6 +453,19 @@ def _dispatch_event(self, event: str, data: dict) -> None:
423453
elif "_remove" in name or "_delete" in name:
424454
self._dispatch.dispatch(f"on_raw_{name}", obj)
425455

456+
if guild_id := data.get("guild_id") and not isinstance(obj, Guild):
457+
guild = self._http.cache[Guild].get(Snowflake(guild_id))
458+
model_name = model.__name__.lower()
459+
_obj = getattr(guild, f"{model_name}s", None)
460+
if _obj is not None:
461+
if isinstance(_obj, list):
462+
for __obj in _obj:
463+
if __obj.id == obj.id:
464+
_obj.remove(__obj)
465+
break
466+
setattr(guild, f"{model_name}s", _obj)
467+
self._http.cache[Guild].add(guild)
468+
426469
old_obj = _cache.pop(id)
427470
self._dispatch.dispatch(f"on_{name}", old_obj)
428471

interactions/api/http/channel.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ async def get_channel(self, channel_id: int) -> dict:
2525
:return: Dictionary of the channel object.
2626
"""
2727
request = await self._req.request(Route("GET", f"/channels/{channel_id}"))
28-
self.cache[Channel].add(Channel(**request, _client=self))
28+
self.cache[Channel].merge(Channel(**request, _client=self))
2929

3030
return request
3131

@@ -87,7 +87,7 @@ async def get_channel_messages(
8787
if isinstance(request, list):
8888
for message in request:
8989
if message.get("id"):
90-
self.cache[Message].add(Message(**message))
90+
self.cache[Message].merge(Message(**message, _client=self))
9191

9292
return request
9393

@@ -108,8 +108,6 @@ async def create_channel(
108108
request = await self._req.request(
109109
Route("POST", f"/guilds/{guild_id}/channels"), json=payload, reason=reason
110110
)
111-
if request.get("id"):
112-
self.cache[Channel].add(Channel(**request))
113111

114112
return request
115113

interactions/api/http/emoji.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from typing import List, Optional
22

33
from ...api.cache import Cache
4+
from ...api.models.guild import Emoji, Guild
5+
from ...api.models.misc import Snowflake
46
from .request import _Request
57
from .route import Route
68

@@ -22,7 +24,11 @@ async def get_all_emoji(self, guild_id: int) -> List[dict]:
2224
:param guild_id: Guild ID snowflake.
2325
:return: A list of emojis.
2426
"""
25-
return await self._req.request(Route("GET", f"/guilds/{guild_id}/emojis"))
27+
res = await self._req.request(Route("GET", f"/guilds/{guild_id}/emojis"))
28+
self.cache[Guild].get(Snowflake(guild_id)).emojis = [
29+
Emoji(**_res, _client=self) for _res in res
30+
]
31+
return res
2632

2733
async def get_guild_emoji(self, guild_id: int, emoji_id: int) -> dict:
2834
"""
@@ -32,7 +38,20 @@ async def get_guild_emoji(self, guild_id: int, emoji_id: int) -> dict:
3238
:param emoji_id: Emoji ID snowflake.
3339
:return: Emoji object
3440
"""
35-
return await self._req.request(Route("GET", f"/guilds/{guild_id}/emojis/{emoji_id}"))
41+
res = await self._req.request(Route("GET", f"/guilds/{guild_id}/emojis/{emoji_id}"))
42+
emoji = Emoji(**res, _client=self)
43+
guild = self.cache[Guild].get(Snowflake(guild_id))
44+
if guild.emojis is None:
45+
guild.emojis = [emoji]
46+
else:
47+
for index, _emoji in enumerate(guild.emojis):
48+
if _emoji.id == emoji.id:
49+
guild.emojis[index] = emoji
50+
break
51+
else:
52+
guild.emojis.append(emoji)
53+
self.cache[Guild].add(guild) # yes it should just be overwritten
54+
return res
3655

3756
async def create_guild_emoji(
3857
self, guild_id: int, payload: dict, reason: Optional[str] = None

interactions/api/http/guild.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from ...api.cache import Cache
55
from ..models.channel import Channel
66
from ..models.guild import Guild
7-
from ..models.member import Member
87
from ..models.role import Role
98
from .request import _Request
109
from .route import Route
@@ -30,7 +29,7 @@ async def get_self_guilds(self) -> List[dict]:
3029

3130
for guild in request:
3231
if guild.get("id"):
33-
self.cache[Guild].add(Guild(**guild, _client=self))
32+
self.cache[Guild].merge(Guild(**guild, _client=self))
3433

3534
return request
3635

@@ -45,7 +44,7 @@ async def get_guild(self, guild_id: int, with_counts: bool = False) -> dict:
4544
request = await self._req.request(
4645
Route("GET", f"/guilds/{guild_id}{f'?{with_counts=}' if with_counts else ''}")
4746
)
48-
self.cache[Guild].add(Guild(**request, _client=self))
47+
self.cache[Guild].merge(Guild(**request, _client=self))
4948

5049
return request
5150

@@ -369,7 +368,7 @@ async def get_all_channels(self, guild_id: int) -> List[dict]:
369368

370369
for channel in request:
371370
if channel.get("id"):
372-
self.cache[Channel].add(Channel(**channel, _client=self))
371+
self.cache[Channel].merge(Channel(**channel, _client=self))
373372

374373
return request
375374

@@ -386,7 +385,7 @@ async def get_all_roles(self, guild_id: int) -> List[dict]:
386385

387386
for role in request:
388387
if role.get("id"):
389-
self.cache[Role].add(Role(**role))
388+
self.cache[Role].merge(Role(**role))
390389

391390
return request
392391

@@ -404,8 +403,6 @@ async def create_guild_role(
404403
request = await self._req.request(
405404
Route("POST", f"/guilds/{guild_id}/roles"), json=payload, reason=reason
406405
)
407-
if request.get("id"):
408-
self.cache[Role].add(Role(**request))
409406

410407
return request
411408

@@ -588,7 +585,6 @@ async def add_guild_member(
588585
},
589586
)
590587

591-
self.cache[Member].add(Member(**request))
592588

593589
return request
594590

interactions/api/http/member.py

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
from typing import List, Optional
22

33
from ...api.cache import Cache
4+
from ...api.models.guild import Guild
5+
from ...api.models.member import Member
6+
from ...api.models.misc import Snowflake
47
from .request import _Request
58
from .route import Route
69

@@ -23,7 +26,7 @@ async def get_member(self, guild_id: int, member_id: int) -> Optional[dict]:
2326
:param member_id: Member ID snowflake.
2427
:return: A member object, if any.
2528
"""
26-
return await self._req.request(
29+
res = await self._req.request(
2730
Route(
2831
"GET",
2932
"/guilds/{guild_id}/members/{member_id}",
@@ -32,6 +35,21 @@ async def get_member(self, guild_id: int, member_id: int) -> Optional[dict]:
3235
)
3336
)
3437

38+
member = Member(**res, _client=self)
39+
guild = self.cache[Guild].get(Snowflake(guild_id))
40+
if guild.members is None:
41+
guild.members = [member]
42+
else:
43+
for index, _member in enumerate(guild.members):
44+
if _member.id == member.id:
45+
guild.members[index] = member
46+
break
47+
else:
48+
guild.members.append(member)
49+
self.cache[Guild].add(guild) # yes it should just be overwritten
50+
51+
return res
52+
3553
async def get_list_of_members(
3654
self, guild_id: int, limit: int = 1, after: Optional[int] = None
3755
) -> List[dict]:
@@ -47,7 +65,20 @@ async def get_list_of_members(
4765
if after:
4866
payload["after"] = after
4967

50-
return await self._req.request(Route("GET", f"/guilds/{guild_id}/members"), params=payload)
68+
res = await self._req.request(Route("GET", f"/guilds/{guild_id}/members"), params=payload)
69+
guild = self.cache[Guild].get(Snowflake(guild_id))
70+
if guild.members is None:
71+
guild.members = [Member(**_res, _client=self) for _res in res]
72+
else:
73+
for member in [Member(**_res, _client=self) for _res in res]:
74+
for index, _member in enumerate(guild.members):
75+
if _member.id == member.id:
76+
guild.members[index] = member
77+
break
78+
else:
79+
guild.members.append(member)
80+
self.cache[Guild].add(guild) # yes it should just be overwritten
81+
return res
5182

5283
async def search_guild_members(self, guild_id: int, query: str, limit: int = 1) -> List[dict]:
5384
"""

interactions/api/http/message.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -110,9 +110,11 @@ async def get_message(self, channel_id: int, message_id: int) -> Optional[dict]:
110110
:param message_id: the id of the message
111111
:return: message if it exists.
112112
"""
113-
return await self._req.request(
114-
Route("GET", f"/channels/{channel_id}/messages/{message_id}")
115-
)
113+
res = await self._req.request(Route("GET", f"/channels/{channel_id}/messages/{message_id}"))
114+
115+
self.cache[Message].merge(Message(**res, _client=self))
116+
117+
return res
116118

117119
async def delete_message(
118120
self, channel_id: int, message_id: int, reason: Optional[str] = None

interactions/api/models/channel.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,11 @@ class Channel(ClientSerializerMixin, IDMixin):
166166
permissions: Optional[str] = field(default=None)
167167
flags: Optional[int] = field(default=None)
168168

169+
def __attrs_post_init__(self): # sourcery skip: last-if-guard
170+
if self._client:
171+
if not self.recipients:
172+
self.recipients = self._client.cache[Channel].get(self.id).recipients
173+
169174
def __repr__(self) -> str:
170175
return self.name
171176

interactions/api/models/guild.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -345,6 +345,19 @@ def __attrs_post_init__(self): # sourcery skip: last-if-guard
345345
if self.members:
346346
self._client.cache[Member].update({(self.id, m.id): m for m in self.members})
347347

348+
if not self.channels:
349+
self.channels = self._client.cache[Guild].get(self.id).channels
350+
if not self.threads:
351+
self.threads = self._client.cache[Guild].get(self.id).threads
352+
if not self.roles:
353+
self.roles = self._client.cache[Guild].get(self.id).roles
354+
if not self.members:
355+
self.members = self._client.cache[Guild].get(self.id).members
356+
if not self.member_count:
357+
self.member_count = self._client.cache[Guild].get(self.id).member_count
358+
if not self.presences:
359+
self.presences = self._client.cache[Guild].get(self.id).presences
360+
348361
def __repr__(self) -> str:
349362
return self.name
350363

0 commit comments

Comments
 (0)