13
13
from dateutil .tz import tzutc
14
14
from inspect import getfullargspec
15
15
from inspect import getmembers
16
- from typing import Any , Callable , Dict , Generic , List , Mapping , Optional , TypeVar , Type , Union , Set , overload
16
+ from typing import Any , Callable , Dict , Generic , List , Mapping , Optional , TypeVar , Type , Union , Set , cast , overload
17
17
from typing import TYPE_CHECKING
18
18
19
19
from pynamodb ._compat import GenericMeta
@@ -218,12 +218,16 @@ def delete(self, *values: Any) -> 'DeleteAction':
218
218
219
219
class AttributeContainerMeta (GenericMeta ):
220
220
221
- def __init__ (self , name , bases , attrs , * args , ** kwargs ):
222
- super ().__init__ (name , bases , attrs , * args , ** kwargs ) # type: ignore
223
- AttributeContainerMeta ._initialize_attributes (self )
221
+ def __new__ (cls , name , bases , namespace , discriminator = None ):
222
+ # Defined so that the discriminator can be set in the class definition.
223
+ return super ().__new__ (cls , name , bases , namespace )
224
+
225
+ def __init__ (self , name , bases , namespace , discriminator = None ):
226
+ super ().__init__ (name , bases , namespace )
227
+ AttributeContainerMeta ._initialize_attributes (self , discriminator )
224
228
225
229
@staticmethod
226
- def _initialize_attributes (cls ):
230
+ def _initialize_attributes (cls , discriminator_value ):
227
231
"""
228
232
Initialize attributes on the class.
229
233
"""
@@ -249,6 +253,20 @@ def _initialize_attributes(cls):
249
253
# Prepend the `attr_path` lists with the dynamo attribute name.
250
254
attribute ._update_attribute_paths (attribute .attr_name )
251
255
256
+ # Register the class with the discriminator if necessary.
257
+ discriminators = [name for name , attr in cls ._attributes .items () if isinstance (attr , DiscriminatorAttribute )]
258
+ if len (discriminators ) > 1 :
259
+ raise ValueError ("{} has more than one discriminator attribute: {}" .format (
260
+ cls .__name__ , ", " .join (discriminators )))
261
+ cls ._discriminator = discriminators [0 ] if discriminators else None
262
+ # TODO(jpinner) add support for model polymorphism
263
+ if cls ._discriminator and not issubclass (cls , MapAttribute ):
264
+ raise NotImplementedError ("Discriminators are not yet supported in model classes." )
265
+ if discriminator_value is not None :
266
+ if not cls ._discriminator :
267
+ raise ValueError ("{} does not have a discriminator attribute" .format (cls .__name__ ))
268
+ cls ._attributes [cls ._discriminator ].register_class (cls , discriminator_value )
269
+
252
270
253
271
class AttributeContainer (metaclass = AttributeContainerMeta ):
254
272
@@ -259,6 +277,7 @@ def __init__(self, _user_instantiated: bool = True, **attributes: Attribute) ->
259
277
# instances do not have any Attributes defined and instead use this dictionary to store their
260
278
# collection of name-value pairs.
261
279
self .attribute_values : Dict [str , Any ] = {}
280
+ self ._set_discriminator ()
262
281
self ._set_defaults (_user_instantiated = _user_instantiated )
263
282
self ._set_attributes (** attributes )
264
283
@@ -288,6 +307,15 @@ def _dynamo_to_python_attr(cls, dynamo_key: str) -> str:
288
307
"""
289
308
return cls ._dynamo_to_python_attrs .get (dynamo_key , dynamo_key ) # type: ignore
290
309
310
+ @classmethod
311
+ def _get_discriminator_attribute (cls ) -> Optional ['DiscriminatorAttribute' ]:
312
+ return cls .get_attributes ()[cls ._discriminator ] if cls ._discriminator else None # type: ignore
313
+
314
+ def _set_discriminator (self ) -> None :
315
+ discriminator_attr = self ._get_discriminator_attribute ()
316
+ if discriminator_attr and discriminator_attr .get_discriminator (self .__class__ ) is not None :
317
+ self .attribute_values [self ._discriminator ] = self .__class__ # type: ignore
318
+
291
319
def _set_defaults (self , _user_instantiated : bool = True ) -> None :
292
320
"""
293
321
Sets and fields that provide a default value
@@ -336,6 +364,7 @@ def _deserialize(self, attribute_values: Dict[str, Dict[str, Any]]) -> None:
336
364
Sets attributes sent back from DynamoDB on this object
337
365
"""
338
366
self .attribute_values = {}
367
+ self ._set_discriminator ()
339
368
self ._set_defaults (_user_instantiated = False )
340
369
for name , attr in self .get_attributes ().items ():
341
370
attribute_value = attribute_values .get (attr .attr_name )
@@ -352,6 +381,47 @@ def __ne__(self, other: Any) -> bool:
352
381
return self is not other
353
382
354
383
384
+ class DiscriminatorAttribute (Attribute [type ]):
385
+ attr_type = STRING
386
+
387
+ def __init__ (self , attr_name : Optional [str ] = None ) -> None :
388
+ super ().__init__ (attr_name = attr_name )
389
+ self ._class_map : Dict [type , Any ] = {}
390
+ self ._discriminator_map : Dict [Any , type ] = {}
391
+
392
+ def register_class (self , cls : type , discriminator : Any ):
393
+ discriminator = discriminator (cls ) if callable (discriminator ) else discriminator
394
+ current_class = self ._discriminator_map .get (discriminator )
395
+ if current_class and current_class != cls :
396
+ raise ValueError ("The discriminator value '{}' is already assigned to a class: {}" .format (
397
+ discriminator , current_class .__name__ ))
398
+
399
+ if cls not in self ._class_map :
400
+ self ._class_map [cls ] = discriminator
401
+
402
+ self ._discriminator_map [discriminator ] = cls
403
+
404
+ def get_discriminator (self , cls : type ) -> Optional [Any ]:
405
+ return self ._class_map .get (cls )
406
+
407
+ def __set__ (self , instance : Any , value : Optional [type ]) -> None :
408
+ raise TypeError ("'{}' object does not support item assignment" .format (self .__class__ .__name__ ))
409
+
410
+ def serialize (self , value ):
411
+ """
412
+ Returns the discriminator value corresponding to the given class.
413
+ """
414
+ return self ._class_map [value ]
415
+
416
+ def deserialize (self , value ):
417
+ """
418
+ Returns the class corresponding to the given discriminator value.
419
+ """
420
+ if value not in self ._discriminator_map :
421
+ raise ValueError ("Unknown discriminator value: {}" .format (value ))
422
+ return self ._discriminator_map [value ]
423
+
424
+
355
425
class BinaryAttribute (Attribute [bytes ]):
356
426
"""
357
427
A binary attribute
@@ -861,7 +931,14 @@ def deserialize(self, values):
861
931
"""
862
932
if not self .is_raw ():
863
933
# If this is a subclass of a MapAttribute (i.e typed), instantiate an instance
864
- instance = type (self )()
934
+ cls = type (self )
935
+ discriminator_attr = cls ._get_discriminator_attribute ()
936
+ if discriminator_attr :
937
+ discriminator_attribute_value = values .pop (discriminator_attr .attr_name , None )
938
+ if discriminator_attribute_value :
939
+ discriminator_value = discriminator_attr .get_value (discriminator_attribute_value )
940
+ cls = discriminator_attr .deserialize (discriminator_value )
941
+ instance = cls ()
865
942
instance ._deserialize (values )
866
943
return instance
867
944
0 commit comments