@@ -698,22 +698,41 @@ def test_query_list_eq_numeric_comparison() -> None:
698
698
) # Should not be equal since values are different
699
699
700
700
701
- def test_keygetter_nested_objects () -> None :
702
- """Test keygetter function with nested objects."""
701
+ @dataclasses .dataclass
702
+ class Food (t .Mapping [str , t .Any ]):
703
+ fruit : list [str ] = dataclasses .field (default_factory = list )
704
+ breakfast : str | None = None
703
705
704
- @dataclasses .dataclass
705
- class Food :
706
- fruit : list [str ] = dataclasses .field (default_factory = list )
707
- breakfast : str | None = None
706
+ def __getitem__ (self , key : str ) -> t .Any :
707
+ return getattr (self , key )
708
708
709
- @dataclasses .dataclass
710
- class Restaurant :
711
- place : str
712
- city : str
713
- state : str
714
- food : Food = dataclasses .field (default_factory = Food )
709
+ def __iter__ (self ) -> t .Iterator [str ]:
710
+ return iter (self .__dataclass_fields__ )
711
+
712
+ def __len__ (self ) -> int :
713
+ return len (self .__dataclass_fields__ )
714
+
715
+
716
+ @dataclasses .dataclass
717
+ class Restaurant (t .Mapping [str , t .Any ]):
718
+ place : str
719
+ city : str
720
+ state : str
721
+ food : Food = dataclasses .field (default_factory = Food )
722
+
723
+ def __getitem__ (self , key : str ) -> t .Any :
724
+ return getattr (self , key )
715
725
716
- # Test with nested dataclass
726
+ def __iter__ (self ) -> t .Iterator [str ]:
727
+ return iter (self .__dataclass_fields__ )
728
+
729
+ def __len__ (self ) -> int :
730
+ return len (self .__dataclass_fields__ )
731
+
732
+
733
+ def test_keygetter_nested_objects () -> None :
734
+ """Test keygetter function with nested objects."""
735
+ # Test with nested dataclass that implements Mapping protocol
717
736
restaurant = Restaurant (
718
737
place = "Largo" ,
719
738
city = "Tampa" ,
@@ -736,7 +755,9 @@ class Restaurant:
736
755
737
756
# Test with non-mapping object (returns the object itself)
738
757
non_mapping = "not a mapping"
739
- assert keygetter (non_mapping , "any_key" ) == non_mapping # type: ignore
758
+ assert (
759
+ keygetter (t .cast (t .Mapping [str , t .Any ], non_mapping ), "any_key" ) == non_mapping
760
+ )
740
761
741
762
742
763
def test_query_list_slicing () -> None :
@@ -773,24 +794,33 @@ def test_query_list_attributes() -> None:
773
794
774
795
# Test pk_key attribute with objects
775
796
@dataclasses .dataclass
776
- class Item :
797
+ class Item ( t . Mapping [ str , t . Any ]) :
777
798
id : str
778
799
value : int
779
800
801
+ def __getitem__ (self , key : str ) -> t .Any :
802
+ return getattr (self , key )
803
+
804
+ def __iter__ (self ) -> t .Iterator [str ]:
805
+ return iter (self .__dataclass_fields__ )
806
+
807
+ def __len__ (self ) -> int :
808
+ return len (self .__dataclass_fields__ )
809
+
780
810
items = [Item ("1" , 1 ), Item ("2" , 2 )]
781
- ql = QueryList (items )
782
- ql .pk_key = "id"
783
- assert ql .items () == [("1" , items [0 ]), ("2" , items [1 ])]
811
+ ql_items : QueryList [ t . Any ] = QueryList (items )
812
+ ql_items .pk_key = "id"
813
+ assert list ( ql_items .items () ) == [("1" , items [0 ]), ("2" , items [1 ])]
784
814
785
815
# Test pk_key with non-existent attribute
786
- ql .pk_key = "nonexistent"
816
+ ql_items .pk_key = "nonexistent"
787
817
with pytest .raises (AttributeError ):
788
- ql .items ()
818
+ ql_items .items ()
789
819
790
820
# Test pk_key with None
791
- ql .pk_key = None
821
+ ql_items .pk_key = None
792
822
with pytest .raises (PKRequiredException ):
793
- ql .items ()
823
+ ql_items .items ()
794
824
795
825
796
826
def test_lookup_name_map () -> None :
0 commit comments