Skip to content

Commit 09b599d

Browse files
authored
Support MapAttribute polymorphism using discriminators. (#836)
1 parent ed05984 commit 09b599d

File tree

5 files changed

+252
-9
lines changed

5 files changed

+252
-9
lines changed

docs/index.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,6 @@ Features
1111
========
1212

1313
* Python 3 support
14-
* Python 2 support
1514
* Support for Unicode, Binary, JSON, Number, Set, and UTC Datetime attributes
1615
* Support for DynamoDB Local
1716
* Support for all of the DynamoDB API
@@ -32,6 +31,7 @@ Topics
3231
batch
3332
updates
3433
conditional
34+
polymorphism
3535
attributes
3636
transaction
3737
optimistic_locking

docs/polymorphism.rst

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
Polymorphism
2+
============
3+
4+
PynamoDB supports polymorphism through the use of discriminators.
5+
6+
A discriminator is a value that is written to DynamoDB that identifies the python class being stored.
7+
(Note: currently discriminators are only supported on MapAttribute subclasses; support for model subclasses coming soon.)
8+
9+
Discriminator Attributes
10+
^^^^^^^^^^^^^^^^^^^^^^^^
11+
12+
The discriminator value is stored using a special attribute, the DiscriminatorAttribute.
13+
Only a single DiscriminatorAttribute can be defined on a class.
14+
15+
The discriminator value can be assigned to a class as part of the definition:
16+
17+
.. code-block:: python
18+
19+
class ParentClass(MapAttribute):
20+
cls = DiscriminatorAttribute()
21+
22+
class ChildClass(ParentClass, discriminator='child'):
23+
pass
24+
25+
Declaring the discriminator value as part of the class definition will automatically register the class with the discriminator attribute.
26+
A class can also be registered manually:
27+
28+
.. code-block:: python
29+
30+
class ParentClass(MapAttribute):
31+
cls = DiscriminatorAttribute()
32+
33+
class ChildClass(ParentClass):
34+
pass
35+
36+
ParentClass._cls.register_class(ChildClass, 'child')
37+
38+
.. note::
39+
40+
A class may be registered with a discriminator attribute multiple times.
41+
Only the first registered value is used during serialization;
42+
however, any registered value can be used to deserialize the class.
43+
This behavior is intended to facilitate migrations if discriminator values must be changed.
44+
45+
.. warning::
46+
47+
Discriminator values are written to DynamoDB.
48+
Changing the value after items have been saved to the database can result in deserialization failures.
49+
In order to read items with an old discriminator value, the old value must be manually registered.

docs/release_notes.rst

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,21 @@
11
Release Notes
22
=============
33

4-
v5.0 (unreleased)
5-
-----------------
4+
v5.0.0 (unreleased)
5+
-------------------
66

7+
:date: 2020-xx-xx
8+
9+
This is major release and contains breaking changes. Please read the notes below carefully.
10+
11+
**Polymorphism**
12+
13+
This release introduces polymorphism support via ``DiscriminatorAttribute``.
14+
Discriminator values are written to DynamoDB and used during deserialization to instantiate the desired class.
15+
16+
Other changes in this release:
17+
18+
* Python 2 is no longer supported. Python 3.6 or greater is now required.
719
* ``Model.query`` no longer demotes invalid range key conditions to be filter conditions to avoid surprising behaviors:
820
where what's intended to be a cheap and fast condition ends up being expensive and slow. Since filter conditions
921
cannot contain range keys, this had limited utility to begin with, and would sometimes cause confusing

pynamodb/attributes.py

Lines changed: 83 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from dateutil.tz import tzutc
1414
from inspect import getfullargspec
1515
from inspect import getmembers
16-
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, TypeVar, Type, Union, Set, overload
16+
from typing import Any, Callable, Dict, Generic, List, Mapping, Optional, TypeVar, Type, Union, Set, cast, overload
1717
from typing import TYPE_CHECKING
1818

1919
from pynamodb._compat import GenericMeta
@@ -218,12 +218,16 @@ def delete(self, *values: Any) -> 'DeleteAction':
218218

219219
class AttributeContainerMeta(GenericMeta):
220220

221-
def __init__(self, name, bases, attrs, *args, **kwargs):
222-
super().__init__(name, bases, attrs, *args, **kwargs) # type: ignore
223-
AttributeContainerMeta._initialize_attributes(self)
221+
def __new__(cls, name, bases, namespace, discriminator=None):
222+
# Defined so that the discriminator can be set in the class definition.
223+
return super().__new__(cls, name, bases, namespace)
224+
225+
def __init__(self, name, bases, namespace, discriminator=None):
226+
super().__init__(name, bases, namespace)
227+
AttributeContainerMeta._initialize_attributes(self, discriminator)
224228

225229
@staticmethod
226-
def _initialize_attributes(cls):
230+
def _initialize_attributes(cls, discriminator_value):
227231
"""
228232
Initialize attributes on the class.
229233
"""
@@ -249,6 +253,20 @@ def _initialize_attributes(cls):
249253
# Prepend the `attr_path` lists with the dynamo attribute name.
250254
attribute._update_attribute_paths(attribute.attr_name)
251255

256+
# Register the class with the discriminator if necessary.
257+
discriminators = [name for name, attr in cls._attributes.items() if isinstance(attr, DiscriminatorAttribute)]
258+
if len(discriminators) > 1:
259+
raise ValueError("{} has more than one discriminator attribute: {}".format(
260+
cls.__name__, ", ".join(discriminators)))
261+
cls._discriminator = discriminators[0] if discriminators else None
262+
# TODO(jpinner) add support for model polymorphism
263+
if cls._discriminator and not issubclass(cls, MapAttribute):
264+
raise NotImplementedError("Discriminators are not yet supported in model classes.")
265+
if discriminator_value is not None:
266+
if not cls._discriminator:
267+
raise ValueError("{} does not have a discriminator attribute".format(cls.__name__))
268+
cls._attributes[cls._discriminator].register_class(cls, discriminator_value)
269+
252270

253271
class AttributeContainer(metaclass=AttributeContainerMeta):
254272

@@ -259,6 +277,7 @@ def __init__(self, _user_instantiated: bool = True, **attributes: Attribute) ->
259277
# instances do not have any Attributes defined and instead use this dictionary to store their
260278
# collection of name-value pairs.
261279
self.attribute_values: Dict[str, Any] = {}
280+
self._set_discriminator()
262281
self._set_defaults(_user_instantiated=_user_instantiated)
263282
self._set_attributes(**attributes)
264283

@@ -288,6 +307,15 @@ def _dynamo_to_python_attr(cls, dynamo_key: str) -> str:
288307
"""
289308
return cls._dynamo_to_python_attrs.get(dynamo_key, dynamo_key) # type: ignore
290309

310+
@classmethod
311+
def _get_discriminator_attribute(cls) -> Optional['DiscriminatorAttribute']:
312+
return cls.get_attributes()[cls._discriminator] if cls._discriminator else None # type: ignore
313+
314+
def _set_discriminator(self) -> None:
315+
discriminator_attr = self._get_discriminator_attribute()
316+
if discriminator_attr and discriminator_attr.get_discriminator(self.__class__) is not None:
317+
self.attribute_values[self._discriminator] = self.__class__ # type: ignore
318+
291319
def _set_defaults(self, _user_instantiated: bool = True) -> None:
292320
"""
293321
Sets and fields that provide a default value
@@ -336,6 +364,7 @@ def _deserialize(self, attribute_values: Dict[str, Dict[str, Any]]) -> None:
336364
Sets attributes sent back from DynamoDB on this object
337365
"""
338366
self.attribute_values = {}
367+
self._set_discriminator()
339368
self._set_defaults(_user_instantiated=False)
340369
for name, attr in self.get_attributes().items():
341370
attribute_value = attribute_values.get(attr.attr_name)
@@ -352,6 +381,47 @@ def __ne__(self, other: Any) -> bool:
352381
return self is not other
353382

354383

384+
class DiscriminatorAttribute(Attribute[type]):
385+
attr_type = STRING
386+
387+
def __init__(self, attr_name: Optional[str] = None) -> None:
388+
super().__init__(attr_name=attr_name)
389+
self._class_map: Dict[type, Any] = {}
390+
self._discriminator_map: Dict[Any, type] = {}
391+
392+
def register_class(self, cls: type, discriminator: Any):
393+
discriminator = discriminator(cls) if callable(discriminator) else discriminator
394+
current_class = self._discriminator_map.get(discriminator)
395+
if current_class and current_class != cls:
396+
raise ValueError("The discriminator value '{}' is already assigned to a class: {}".format(
397+
discriminator, current_class.__name__))
398+
399+
if cls not in self._class_map:
400+
self._class_map[cls] = discriminator
401+
402+
self._discriminator_map[discriminator] = cls
403+
404+
def get_discriminator(self, cls: type) -> Optional[Any]:
405+
return self._class_map.get(cls)
406+
407+
def __set__(self, instance: Any, value: Optional[type]) -> None:
408+
raise TypeError("'{}' object does not support item assignment".format(self.__class__.__name__))
409+
410+
def serialize(self, value):
411+
"""
412+
Returns the discriminator value corresponding to the given class.
413+
"""
414+
return self._class_map[value]
415+
416+
def deserialize(self, value):
417+
"""
418+
Returns the class corresponding to the given discriminator value.
419+
"""
420+
if value not in self._discriminator_map:
421+
raise ValueError("Unknown discriminator value: {}".format(value))
422+
return self._discriminator_map[value]
423+
424+
355425
class BinaryAttribute(Attribute[bytes]):
356426
"""
357427
A binary attribute
@@ -861,7 +931,14 @@ def deserialize(self, values):
861931
"""
862932
if not self.is_raw():
863933
# If this is a subclass of a MapAttribute (i.e typed), instantiate an instance
864-
instance = type(self)()
934+
cls = type(self)
935+
discriminator_attr = cls._get_discriminator_attribute()
936+
if discriminator_attr:
937+
discriminator_attribute_value = values.pop(discriminator_attr.attr_name, None)
938+
if discriminator_attribute_value:
939+
discriminator_value = discriminator_attr.get_value(discriminator_attribute_value)
940+
cls = discriminator_attr.deserialize(discriminator_value)
941+
instance = cls()
865942
instance._deserialize(values)
866943
return instance
867944

tests/test_discriminator.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import pytest
2+
3+
from pynamodb.attributes import DiscriminatorAttribute
4+
from pynamodb.attributes import ListAttribute
5+
from pynamodb.attributes import MapAttribute
6+
from pynamodb.attributes import NumberAttribute
7+
from pynamodb.attributes import UnicodeAttribute
8+
from pynamodb.models import Model
9+
10+
11+
class_name = lambda cls: cls.__name__
12+
13+
14+
class TypedValue(MapAttribute):
15+
_cls = DiscriminatorAttribute(attr_name = 'cls')
16+
name = UnicodeAttribute()
17+
18+
19+
class NumberValue(TypedValue, discriminator=class_name):
20+
value = NumberAttribute()
21+
22+
23+
class StringValue(TypedValue, discriminator=class_name):
24+
value = UnicodeAttribute()
25+
26+
27+
class RenamedValue(TypedValue, discriminator='custom_name'):
28+
value = UnicodeAttribute()
29+
30+
31+
class DiscriminatorTestModel(Model):
32+
class Meta:
33+
host = 'http://localhost:8000'
34+
table_name = 'test'
35+
hash_key = UnicodeAttribute(hash_key=True)
36+
value = TypedValue()
37+
values = ListAttribute(of=TypedValue)
38+
39+
40+
class TestDiscriminatorAttribute:
41+
42+
def test_serialize(self):
43+
dtm = DiscriminatorTestModel()
44+
dtm.hash_key = 'foo'
45+
dtm.value = StringValue(name='foo', value='Hello')
46+
dtm.values = [NumberValue(name='bar', value=5), RenamedValue(name='baz', value='World')]
47+
assert dtm._serialize() == {
48+
'HASH': 'foo',
49+
'attributes': {
50+
'value': {'M': {'cls': {'S': 'StringValue'}, 'name': {'S': 'foo'}, 'value': {'S': 'Hello'}}},
51+
'values': {'L': [
52+
{'M': {'cls': {'S': 'NumberValue'}, 'name': {'S': 'bar'}, 'value': {'N': '5'}}},
53+
{'M': {'cls': {'S': 'custom_name'}, 'name': {'S': 'baz'}, 'value': {'S': 'World'}}}
54+
]}
55+
}
56+
}
57+
58+
def test_deserialize(self):
59+
item = {
60+
'hash_key': {'S': 'foo'},
61+
'value': {'M': {'cls': {'S': 'StringValue'}, 'name': {'S': 'foo'}, 'value': {'S': 'Hello'}}},
62+
'values': {'L': [
63+
{'M': {'cls': {'S': 'NumberValue'}, 'name': {'S': 'bar'}, 'value': {'N': '5'}}},
64+
{'M': {'cls': {'S': 'custom_name'}, 'name': {'S': 'baz'}, 'value': {'S': 'World'}}}
65+
]}
66+
}
67+
dtm = DiscriminatorTestModel.from_raw_data(item)
68+
assert dtm.hash_key == 'foo'
69+
assert dtm.value.value == 'Hello'
70+
assert dtm.values[0].value == 5
71+
assert dtm.values[1].value == 'World'
72+
73+
def test_condition_expression(self):
74+
condition = DiscriminatorTestModel.value._cls == RenamedValue
75+
placeholder_names, expression_attribute_values = {}, {}
76+
expression = condition.serialize(placeholder_names, expression_attribute_values)
77+
assert expression == "#0.#1 = :0"
78+
assert placeholder_names == {'value': '#0', 'cls': '#1'}
79+
assert expression_attribute_values == {':0': {'S': 'custom_name'}}
80+
81+
def test_multiple_discriminator_values(self):
82+
class TestAttribute(MapAttribute, discriminator='new_value'):
83+
cls = DiscriminatorAttribute()
84+
85+
TestAttribute.cls.register_class(TestAttribute, 'old_value')
86+
87+
# ensure the first registered value is used during serialization
88+
assert TestAttribute.cls.get_discriminator(TestAttribute) == 'new_value'
89+
assert TestAttribute.cls.serialize(TestAttribute) == 'new_value'
90+
91+
# ensure the second registered value can be used to deserialize
92+
assert TestAttribute.cls.deserialize('old_value') == TestAttribute
93+
assert TestAttribute.cls.deserialize('new_value') == TestAttribute
94+
95+
def test_multiple_discriminator_classes(self):
96+
with pytest.raises(ValueError):
97+
# fail when attempting to register a class with an existing discriminator value
98+
class RenamedValue2(TypedValue, discriminator='custom_name'):
99+
pass
100+
101+
def test_model(self):
102+
with pytest.raises(NotImplementedError):
103+
class DiscriminatedModel(Model):
104+
hash_key = UnicodeAttribute(hash_key=True)
105+
_cls = DiscriminatorAttribute()

0 commit comments

Comments
 (0)