5
5
import time
6
6
import logging
7
7
import warnings
8
+ import sys
8
9
from inspect import getmembers
9
- from typing import Any , Dict , Generic , Iterable , Iterator , List , Optional , Sequence , Mapping , Type , TypeVar , Text , \
10
- Tuple , Union , cast
10
+ from typing import Any
11
+ from typing import Dict
12
+ from typing import Generic
13
+ from typing import Iterable
14
+ from typing import Iterator
15
+ from typing import List
16
+ from typing import Mapping
17
+ from typing import Optional
18
+ from typing import Sequence
19
+ from typing import Text
20
+ from typing import Tuple
21
+ from typing import Type
22
+ from typing import TypeVar
23
+ from typing import Union
24
+ from typing import cast
25
+
26
+ if sys .version_info >= (3 , 8 ):
27
+ from typing import Protocol
28
+ else :
29
+ from typing_extensions import Protocol
11
30
12
31
from pynamodb .expressions .update import Action
13
32
from pynamodb .exceptions import DoesNotExist , TableDoesNotExist , TableError , InvalidStateError , PutError
@@ -151,7 +170,7 @@ def commit(self) -> None:
151
170
unprocessed_items = data .get (UNPROCESSED_ITEMS , {}).get (self .model .Meta .table_name )
152
171
153
172
154
- class MetaModel ( AttributeContainerMeta ):
173
+ class MetaProtocol ( Protocol ):
155
174
table_name : str
156
175
read_capacity_units : Optional [int ]
157
176
write_capacity_units : Optional [int ]
@@ -169,14 +188,17 @@ class MetaModel(AttributeContainerMeta):
169
188
billing_mode : Optional [str ]
170
189
stream_view_type : Optional [str ]
171
190
191
+
192
+ class MetaModel (AttributeContainerMeta ):
172
193
"""
173
194
Model meta class
174
-
175
- This class is just here so that index queries have nice syntax.
176
- Model.index.query()
177
195
"""
178
- def __init__ (self , name : str , bases : Any , attrs : Dict [str , Any ]) -> None :
179
- super ().__init__ (name , bases , attrs )
196
+ def __new__ (cls , name , bases , namespace , discriminator = None ):
197
+ # Defined so that the discriminator can be set in the class definition.
198
+ return super ().__new__ (cls , name , bases , namespace )
199
+
200
+ def __init__ (self , name , bases , namespace , discriminator = None ) -> None :
201
+ super ().__init__ (name , bases , namespace , discriminator )
180
202
cls = cast (Type ['Model' ], self )
181
203
for attr_name , attribute in cls .get_attributes ().items ():
182
204
if attribute .is_hash_key :
@@ -200,8 +222,8 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
200
222
raise ValueError ("{} has more than one TTL attribute: {}" .format (
201
223
cls .__name__ , ", " .join (ttl_attr_names )))
202
224
203
- if isinstance (attrs , dict ):
204
- for attr_name , attr_obj in attrs .items ():
225
+ if isinstance (namespace , dict ):
226
+ for attr_name , attr_obj in namespace .items ():
205
227
if attr_name == META_CLASS_NAME :
206
228
if not hasattr (attr_obj , REGION ):
207
229
setattr (attr_obj , REGION , get_settings_value ('region' ))
@@ -234,9 +256,9 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
234
256
235
257
# create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist,
236
258
# so that "except Model.DoesNotExist:" would not catch other models' exceptions
237
- if 'DoesNotExist' not in attrs :
259
+ if 'DoesNotExist' not in namespace :
238
260
exception_attrs = {
239
- '__module__' : attrs .get ('__module__' ),
261
+ '__module__' : namespace .get ('__module__' ),
240
262
'__qualname__' : f'{ cls .__qualname__ } .{ "DoesNotExist" } ' ,
241
263
}
242
264
cls .DoesNotExist = type ('DoesNotExist' , (DoesNotExist , ), exception_attrs )
@@ -260,7 +282,7 @@ class Model(AttributeContainer, metaclass=MetaModel):
260
282
DoesNotExist : Type [DoesNotExist ] = DoesNotExist
261
283
_version_attribute_name : Optional [str ] = None
262
284
263
- Meta : MetaModel
285
+ Meta : MetaProtocol
264
286
265
287
def __init__ (
266
288
self ,
@@ -520,9 +542,7 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T:
520
542
if data is None :
521
543
raise ValueError ("Received no data to construct object" )
522
544
523
- model = cls (_user_instantiated = False )
524
- model .deserialize (data )
525
- return model
545
+ return cls ._instantiate (data )
526
546
527
547
@classmethod
528
548
def count (
@@ -556,6 +576,11 @@ def count(
556
576
else :
557
577
hash_key = cls ._serialize_keys (hash_key )[0 ]
558
578
579
+ # If this class has a discriminator value, filter the query to only return instances of this class.
580
+ discriminator_attr = cls ._get_discriminator_attribute ()
581
+ if discriminator_attr and discriminator_attr .get_discriminator (cls ):
582
+ filter_condition &= discriminator_attr == cls
583
+
559
584
query_args = (hash_key ,)
560
585
query_kwargs = dict (
561
586
range_key_condition = range_key_condition ,
@@ -616,6 +641,11 @@ def query(
616
641
else :
617
642
hash_key = cls ._serialize_keys (hash_key )[0 ]
618
643
644
+ # If this class has a discriminator value, filter the query to only return instances of this class.
645
+ discriminator_attr = cls ._get_discriminator_attribute ()
646
+ if discriminator_attr and discriminator_attr .get_discriminator (cls ):
647
+ filter_condition &= discriminator_attr == cls
648
+
619
649
if page_size is None :
620
650
page_size = limit
621
651
@@ -668,6 +698,11 @@ def scan(
668
698
:param rate_limit: If set then consumed capacity will be limited to this amount per second
669
699
:param attributes_to_get: If set, specifies the properties to include in the projection expression
670
700
"""
701
+ # If this class has a discriminator value, filter the scan to only return instances of this class.
702
+ discriminator_attr = cls ._get_discriminator_attribute ()
703
+ if discriminator_attr and discriminator_attr .get_discriminator (cls ):
704
+ filter_condition &= discriminator_attr == cls
705
+
671
706
if page_size is None :
672
707
page_size = limit
673
708
0 commit comments