32
32
from ..http .client import HTTPClient
33
33
from ..models .flags import Intents
34
34
from ..models .guild import Guild
35
+ from ..models .gw import GuildMember , GuildRole
35
36
from ..models .member import Member
37
+ from ..models .message import Message
36
38
from ..models .misc import Snowflake
37
39
from ..models .presence import ClientPresence
40
+ from ..models .role import Role
38
41
from .heartbeat import _Heartbeat
39
42
from .ratelimit import WSRateLimit
40
43
@@ -437,87 +440,64 @@ def _dispatch_event(self, event: str, data: dict) -> None:
437
440
elif event not in {"TYPING_START" , "VOICE_STATE_UPDATE" , "VOICE_SERVER_UPDATE" }:
438
441
name : str = event .lower ()
439
442
try :
443
+ data ["_client" ] = self ._http
440
444
441
445
_event_path : list = [section .capitalize () for section in name .split ("_" )]
442
446
_name : str = _event_path [0 ] if len (_event_path ) < 3 else "" .join (_event_path [:- 1 ])
443
447
model = getattr (__import__ (path ), _name )
444
-
445
- data ["_client" ] = self ._http
446
448
obj = model (** data )
447
449
448
- _cache : "Storage" = self ._http .cache [model ]
450
+ guild_obj = guild_model = None
451
+ if model is GuildRole :
452
+ guild_obj = Role (** role_data ) if (role_data := data .get ("role" )) else None
453
+ guild_model = Role
454
+ elif model is GuildMember :
455
+ guild_obj = Member (** data )
456
+ guild_model = Member
449
457
450
- if isinstance (obj , Member ):
451
- id = (Snowflake (data ["guild_id" ]), obj .id )
452
- else :
453
- id = getattr (obj , "id" , None )
458
+ _cache : "Storage" = self ._http .cache [model ]
459
+ _guild_cache : "Storage" = self ._http .cache [guild_model ]
454
460
461
+ ids = None
462
+ id = self .__get_object_id (data , obj , model )
455
463
if id is None :
456
- if model .__name__ == "GuildScheduledEventUser" :
457
- id = model .guild_scheduled_event_id
458
- elif model .__name__ == "Presence" :
459
- id = obj .user .id
460
- elif model .__name__ in [
461
- "Invite" ,
462
- "GuildBan" ,
463
- "ChannelPins" ,
464
- "MessageReaction" ,
465
- "MessageReactionRemove" ,
466
- "MessageDelete" ,
467
- # Extend this for everything that should not be cached
468
- ]:
469
- id = None
470
- elif model .__name__ .startswith ("Guild" ):
471
- model_name = model .__name__ [5 :]
472
- if _data := getattr (obj , model_name , None ):
473
- id = (
474
- getattr (_data , "id" )
475
- if not isinstance (_data , dict )
476
- else Snowflake (_data ["id" ])
477
- )
478
- elif hasattr (obj , f"{ model_name } _id" ):
479
- id = getattr (obj , f"{ model_name } _id" , None )
480
-
481
- def __modify_guild_cache ():
482
- if not (
483
- (guild_id := data .get ("guild_id" ))
484
- and not isinstance (obj , Guild )
485
- and "message" not in name
486
- and id is not None
487
- ):
488
- return
489
- if guild := self ._http .cache [Guild ].get (Snowflake (guild_id )):
490
- model_name : str = model .__name__
491
- if "guild" in model_name :
492
- model_name = model_name [5 :]
493
- elif model_name == "threadmembers" :
494
- return
495
- _obj = getattr (guild , f"{ model_name .lower ()} s" , None )
496
- if _obj is not None and isinstance (_obj , list ):
497
- if "_create" in name or "_add" in name :
498
- _obj .append (obj )
499
- for index , __obj in enumerate (_obj ):
500
- if __obj .id == id :
501
- if "_remove" in name or "_delete" in name :
502
- _obj .remove (__obj )
503
-
504
- elif "_update" in name and hasattr (obj , "id" ):
505
- _obj [index ] = obj
506
- break
507
- setattr (guild , f"{ model_name } s" , _obj )
508
- self ._http .cache [Guild ].add (guild )
464
+ ids = self .__get_object_ids (obj , model )
509
465
510
466
if "_create" in name or "_add" in name :
467
+ self ._dispatch .dispatch (f"on_{ name } " , obj )
468
+
511
469
if id :
512
470
_cache .merge (obj , id )
513
- self ._dispatch .dispatch (f"on_{ name } " , obj )
514
- __modify_guild_cache ()
471
+ if guild_obj :
472
+ _guild_cache .add (guild_obj , id )
473
+
474
+ self .__modify_guild_cache (
475
+ name , data , guild_model or model , guild_obj or obj , id , ids
476
+ )
515
477
516
478
elif "_update" in name :
517
479
self ._dispatch .dispatch (f"on_raw_{ name } " , obj )
518
- if not id :
480
+
481
+ if not id and ids is None :
482
+ return self ._dispatch .dispatch (f"on_{ name } " , obj )
483
+
484
+ self .__modify_guild_cache (
485
+ name , data , guild_model or model , guild_obj or obj , id , ids
486
+ )
487
+ if ids is not None :
488
+ # Not cached but it needed for guild_emojis_update and guild_stickers_update events
489
+ return self ._dispatch .dispatch (f"on_{ name } " , obj )
490
+ if id is None :
519
491
return
520
- old_obj = self ._http .cache [model ].get (id )
492
+
493
+ if guild_obj :
494
+ old_guild_obj = _guild_cache .get (id )
495
+ if old_guild_obj :
496
+ old_guild_obj .update (** guild_obj ._json )
497
+ else :
498
+ _guild_cache .add (guild_obj , id )
499
+
500
+ old_obj = _cache .get (id )
521
501
if old_obj :
522
502
before = model (** old_obj ._json )
523
503
old_obj .update (** obj ._json )
@@ -526,27 +506,163 @@ def __modify_guild_cache():
526
506
old_obj = obj
527
507
528
508
_cache .add (old_obj , id )
529
- __modify_guild_cache ()
530
-
531
509
self ._dispatch .dispatch (
532
510
f"on_{ name } " , before , old_obj
533
511
) # give previously stored and new one
534
512
535
513
elif "_remove" in name or "_delete" in name :
536
- self ._dispatch .dispatch (f"on_raw_{ name } " , obj )
537
- __modify_guild_cache ()
514
+ self ._dispatch .dispatch (
515
+ f"on_raw_{ name } " , obj
516
+ ) # Deprecated. Remove this in the future.
517
+
518
+ old_obj = None
538
519
if id :
520
+ _guild_cache .pop (id )
521
+ self .__modify_guild_cache (
522
+ name , data , guild_model or model , guild_obj or obj , id , ids
523
+ )
539
524
old_obj = _cache .pop (id )
540
- self ._dispatch .dispatch (f"on_{ name } " , old_obj )
541
- elif "_delete_bulk" in name :
542
- self ._dispatch .dispatch (f"on_{ name } " , obj )
525
+
526
+ elif ids is not None and "message" in name :
527
+ # currently only message has '_delete_bulk' event but ig better keep this condition for future.
528
+ _message_cache : "Storage" = self ._http .cache [Message ]
529
+ for message_id in ids :
530
+ _message_cache .pop (message_id )
531
+
532
+ self ._dispatch .dispatch (f"on_{ name } " , old_obj or obj )
543
533
544
534
else :
545
535
self ._dispatch .dispatch (f"on_{ name } " , obj )
546
536
547
537
except AttributeError as error :
548
538
log .warning (f"An error occurred dispatching { name } : { error } " )
549
539
540
+ def __get_object_id (
541
+ self , data : dict , obj : Any , model : Any
542
+ ) -> Optional [Union [Snowflake , Tuple [Snowflake , Snowflake ]]]:
543
+ """
544
+ Gets an ID from object.
545
+
546
+ :param data: The data for the event.
547
+ :type data: dict
548
+ :param obj: The object of the event.
549
+ :type obj: Any
550
+ :param model: The model of the event.
551
+ :type model: Any
552
+ :return: Object ID
553
+ :rtype: Optional[Union[Snowflake, Tuple[Snowflake, Snowflake]]]
554
+ """
555
+ if isinstance (obj , (Member , GuildMember )):
556
+ id = (Snowflake (data ["guild_id" ]), obj .id )
557
+ else :
558
+ id = getattr (obj , "id" , None )
559
+ if id is not None :
560
+ return id
561
+
562
+ if model .__name__ == "GuildScheduledEventUser" :
563
+ id = obj .guild_scheduled_event_id
564
+ elif model .__name__ == "Presence" :
565
+ id = obj .user .id
566
+ elif model .__name__ in [
567
+ "GuildBan" ,
568
+ # Extend this for everything that starts with 'Guild' and should not be cached
569
+ ]:
570
+ id = None
571
+ elif model .__name__ .startswith ("Guild" ):
572
+ model_name = model .__name__ [5 :].lower ()
573
+ if (_data := getattr (obj , model_name , None )) and not isinstance (_data , list ):
574
+ id = getattr (_data , "id" ) if not isinstance (_data , dict ) else Snowflake (_data ["id" ])
575
+ elif hasattr (obj , f"{ model_name } _id" ):
576
+ id = getattr (obj , f"{ model_name } _id" , None )
577
+
578
+ return id
579
+
580
+ def __get_object_ids (self , obj : Any , model : Any ) -> Optional [List [Snowflake ]]:
581
+ """
582
+ Gets a list of ids of object.
583
+
584
+ :param obj: The object of the event.
585
+ :type obj: Any
586
+ :param model: The model of the event.
587
+ :type model: Any
588
+ :return: Object IDs
589
+ :rtype: Optional[Union[Snowflake, Tuple[Snowflake, Snowflake]]]
590
+ """
591
+ ids = getattr (obj , "ids" , None )
592
+ if ids is not None :
593
+ return ids
594
+
595
+ if model .__name__ .startswith ("Guild" ):
596
+ model_name = model .__name__ [5 :].lower ()
597
+ if (_data := getattr (obj , model_name , None )) is not None :
598
+ ids = [
599
+ getattr (_obj , "id" ) if not isinstance (_obj , dict ) else Snowflake (_obj ["id" ])
600
+ for _obj in _data
601
+ ]
602
+
603
+ return ids
604
+
605
+ def __modify_guild_cache (
606
+ self ,
607
+ name : str ,
608
+ data : dict ,
609
+ model : Any ,
610
+ obj : Any ,
611
+ id : Optional [Snowflake ] = None ,
612
+ ids : Optional [List [Snowflake ]] = None ,
613
+ ):
614
+ """
615
+ Modifies guild cache.
616
+
617
+ :param event: The name of the event.
618
+ :type event: str
619
+ :param data: The data for the event.
620
+ :type data: dict
621
+ :param obj: The object of the event.
622
+ :type obj: Any
623
+ :param model: The model of the event.
624
+ :type model: Any
625
+ """
626
+ if not (
627
+ (guild_id := data .get ("guild_id" ))
628
+ and not isinstance (obj , Guild )
629
+ and "message" not in name
630
+ and (id is not None or ids is not None )
631
+ and (guild := self ._http .cache [Guild ].get (Snowflake (guild_id )))
632
+ ):
633
+ return
634
+
635
+ attr : str = model .__name__ .lower ()
636
+
637
+ if attr .startswith ("guild" ):
638
+ attr = attr [5 :]
639
+ if attr == "threadmembers" : # TODO: Figure out why this here
640
+ return
641
+ if not attr .endswith ("s" ):
642
+ attr = f"{ attr } s"
643
+ iterable = getattr (guild , attr , None )
644
+ if iterable is not None and isinstance (iterable , list ):
645
+ if "_create" in name or "_add" in name :
646
+ iterable .append (obj )
647
+ if id :
648
+ _id = id [1 ] if isinstance (id , tuple ) else id
649
+ for index , __obj in enumerate (iterable ):
650
+ if __obj .id == _id :
651
+ if "_remove" in name or "_delete" in name :
652
+ iterable .remove (__obj )
653
+
654
+ elif "_update" in name and hasattr (obj , "id" ):
655
+ iterable [index ] = obj
656
+ break
657
+ elif ids is not None and "_update" in name :
658
+ objs = getattr (obj , attr , None )
659
+ if objs is not None :
660
+ iterable .clear ()
661
+ iterable .extend (objs )
662
+ setattr (guild , attr , iterable )
663
+
664
+ self ._http .cache [Guild ].add (guild )
665
+
550
666
def __contextualize (self , data : dict ) -> "_Context" :
551
667
"""
552
668
Takes raw data given back from the Gateway
0 commit comments