Skip to content

Commit 2e2a37c

Browse files
authored
Support model polymorphism using discriminators. Fixes #247, #328 (#864)
1 parent 5c8862e commit 2e2a37c

File tree

8 files changed

+198
-39
lines changed

8 files changed

+198
-39
lines changed

docs/polymorphism.rst

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ Polymorphism
66
PynamoDB supports polymorphism through the use of discriminators.
77

88
A discriminator is a value that is written to DynamoDB that identifies the python class being stored.
9-
(Note: currently discriminators are only supported on MapAttribute subclasses; support for model subclasses coming soon.)
109

1110
Discriminator Attributes
1211
^^^^^^^^^^^^^^^^^^^^^^^^
@@ -49,3 +48,32 @@ A class can also be registered manually:
4948
Discriminator values are written to DynamoDB.
5049
Changing the value after items have been saved to the database can result in deserialization failures.
5150
In order to read items with an old discriminator value, the old value must be manually registered.
51+
52+
53+
Model Discriminators
54+
^^^^^^^^^^^^^^^^^^^^
55+
56+
Model classes also support polymorphism through the use of discriminators.
57+
(Note: currently discriminator attributes cannot be used as the hash or range key of a table.)
58+
59+
.. code-block:: python
60+
61+
class ParentModel(Model):
62+
class Meta:
63+
table_name = 'polymorphic_table'
64+
id = UnicodeAttribute(hash_key=True)
65+
cls = DiscriminatorAttribute()
66+
67+
class FooModel(ParentModel, discriminator='Foo'):
68+
foo = UnicodeAttribute()
69+
70+
class BarModel(ParentModel, discriminator='Bar'):
71+
bar = UnicodeAttribute()
72+
73+
BarModel(id='Hello', bar='World!').serialize()
74+
# {'id': {'S': 'Hello'}, 'cls': {'S': 'Bar'}, 'bar': {'S': 'World!'}}
75+
.. note::
76+
77+
Read operations that are performed on a class that has a discriminator value are slightly modified to ensure that only instances of the class are returned.
78+
Query and scan operations transparently add a filter condition to ensure that only items with a matching discriminator value are returned.
79+
Get and batch get operations will raise a ``ValueError`` if the returned item(s) are not a subclass of the model being read.

docs/release_notes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Release Notes
22
=============
33

4-
v5.0.0b1
4+
v5.0.0b2
55
-------------------
66

77
:date: 2020-xx-xx

pynamodb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
"""
88
__author__ = 'Jharrod LaFon'
99
__license__ = 'MIT'
10-
__version__ = '5.0.0b1'
10+
__version__ = '5.0.0b2'

pynamodb/attributes.py

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
_KT = TypeVar('_KT', bound=str)
4242
_VT = TypeVar('_VT')
4343
_MT = TypeVar('_MT', bound='MapAttribute')
44+
_ACT = TypeVar('_ACT', bound = 'AttributeContainer')
4445

4546
_A = TypeVar('_A', bound='Attribute')
4647

@@ -259,9 +260,6 @@ def _initialize_attributes(cls, discriminator_value):
259260
raise ValueError("{} has more than one discriminator attribute: {}".format(
260261
cls.__name__, ", ".join(discriminators)))
261262
cls._discriminator = discriminators[0] if discriminators else None
262-
# TODO(jpinner) add support for model polymorphism
263-
if cls._discriminator and not issubclass(cls, MapAttribute):
264-
raise NotImplementedError("Discriminators are not yet supported in model classes.")
265263
if discriminator_value is not None:
266264
if not cls._discriminator:
267265
raise ValueError("{} does not have a discriminator attribute".format(cls.__name__))
@@ -372,6 +370,22 @@ def deserialize(self, attribute_values: Dict[str, Dict[str, Any]]) -> None:
372370
value = attr.deserialize(attr.get_value(attribute_value))
373371
setattr(self, name, value)
374372

373+
@classmethod
374+
def _instantiate(cls: Type[_ACT], attribute_values: Dict[str, Dict[str, Any]]) -> _ACT:
375+
discriminator_attr = cls._get_discriminator_attribute()
376+
if discriminator_attr:
377+
discriminator_attribute_value = attribute_values.pop(discriminator_attr.attr_name, None)
378+
if discriminator_attribute_value:
379+
discriminator_value = discriminator_attr.get_value(discriminator_attribute_value)
380+
stored_cls = discriminator_attr.deserialize(discriminator_value)
381+
if not issubclass(stored_cls, cls):
382+
raise ValueError("Cannot instantiate a {} from the returned class: {}".format(
383+
cls.__name__, stored_cls.__name__))
384+
cls = stored_cls
385+
instance = cls(_user_instantiated=False)
386+
AttributeContainer.deserialize(instance, attribute_values)
387+
return instance
388+
375389
def __eq__(self, other: Any) -> bool:
376390
# This is required so that MapAttribute can call this method.
377391
return self is other
@@ -940,16 +954,7 @@ def deserialize(self, values):
940954
"""
941955
if not self.is_raw():
942956
# If this is a subclass of a MapAttribute (i.e typed), instantiate an instance
943-
cls = type(self)
944-
discriminator_attr = cls._get_discriminator_attribute()
945-
if discriminator_attr:
946-
discriminator_attribute_value = values.pop(discriminator_attr.attr_name, None)
947-
if discriminator_attribute_value:
948-
discriminator_value = discriminator_attr.get_value(discriminator_attribute_value)
949-
cls = discriminator_attr.deserialize(discriminator_value)
950-
instance = cls()
951-
AttributeContainer.deserialize(instance, values)
952-
return instance
957+
return self._instantiate(values)
953958

954959
return {
955960
k: DESERIALIZE_CLASS_MAP[attr_type].deserialize(attr_value)

pynamodb/models.py

Lines changed: 51 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,28 @@
55
import time
66
import logging
77
import warnings
8+
import sys
89
from inspect import getmembers
9-
from typing import Any, Dict, Generic, Iterable, Iterator, List, Optional, Sequence, Mapping, Type, TypeVar, Text, \
10-
Tuple, Union, cast
10+
from typing import Any
11+
from typing import Dict
12+
from typing import Generic
13+
from typing import Iterable
14+
from typing import Iterator
15+
from typing import List
16+
from typing import Mapping
17+
from typing import Optional
18+
from typing import Sequence
19+
from typing import Text
20+
from typing import Tuple
21+
from typing import Type
22+
from typing import TypeVar
23+
from typing import Union
24+
from typing import cast
25+
26+
if sys.version_info >= (3, 8):
27+
from typing import Protocol
28+
else:
29+
from typing_extensions import Protocol
1130

1231
from pynamodb.expressions.update import Action
1332
from pynamodb.exceptions import DoesNotExist, TableDoesNotExist, TableError, InvalidStateError, PutError
@@ -151,7 +170,7 @@ def commit(self) -> None:
151170
unprocessed_items = data.get(UNPROCESSED_ITEMS, {}).get(self.model.Meta.table_name)
152171

153172

154-
class MetaModel(AttributeContainerMeta):
173+
class MetaProtocol(Protocol):
155174
table_name: str
156175
read_capacity_units: Optional[int]
157176
write_capacity_units: Optional[int]
@@ -169,14 +188,17 @@ class MetaModel(AttributeContainerMeta):
169188
billing_mode: Optional[str]
170189
stream_view_type: Optional[str]
171190

191+
192+
class MetaModel(AttributeContainerMeta):
172193
"""
173194
Model meta class
174-
175-
This class is just here so that index queries have nice syntax.
176-
Model.index.query()
177195
"""
178-
def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
179-
super().__init__(name, bases, attrs)
196+
def __new__(cls, name, bases, namespace, discriminator=None):
197+
# Defined so that the discriminator can be set in the class definition.
198+
return super().__new__(cls, name, bases, namespace)
199+
200+
def __init__(self, name, bases, namespace, discriminator=None) -> None:
201+
super().__init__(name, bases, namespace, discriminator)
180202
cls = cast(Type['Model'], self)
181203
for attr_name, attribute in cls.get_attributes().items():
182204
if attribute.is_hash_key:
@@ -200,8 +222,8 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
200222
raise ValueError("{} has more than one TTL attribute: {}".format(
201223
cls.__name__, ", ".join(ttl_attr_names)))
202224

203-
if isinstance(attrs, dict):
204-
for attr_name, attr_obj in attrs.items():
225+
if isinstance(namespace, dict):
226+
for attr_name, attr_obj in namespace.items():
205227
if attr_name == META_CLASS_NAME:
206228
if not hasattr(attr_obj, REGION):
207229
setattr(attr_obj, REGION, get_settings_value('region'))
@@ -234,9 +256,9 @@ def __init__(self, name: str, bases: Any, attrs: Dict[str, Any]) -> None:
234256

235257
# create a custom Model.DoesNotExist derived from pynamodb.exceptions.DoesNotExist,
236258
# so that "except Model.DoesNotExist:" would not catch other models' exceptions
237-
if 'DoesNotExist' not in attrs:
259+
if 'DoesNotExist' not in namespace:
238260
exception_attrs = {
239-
'__module__': attrs.get('__module__'),
261+
'__module__': namespace.get('__module__'),
240262
'__qualname__': f'{cls.__qualname__}.{"DoesNotExist"}',
241263
}
242264
cls.DoesNotExist = type('DoesNotExist', (DoesNotExist, ), exception_attrs)
@@ -260,7 +282,7 @@ class Model(AttributeContainer, metaclass=MetaModel):
260282
DoesNotExist: Type[DoesNotExist] = DoesNotExist
261283
_version_attribute_name: Optional[str] = None
262284

263-
Meta: MetaModel
285+
Meta: MetaProtocol
264286

265287
def __init__(
266288
self,
@@ -520,9 +542,7 @@ def from_raw_data(cls: Type[_T], data: Dict[str, Any]) -> _T:
520542
if data is None:
521543
raise ValueError("Received no data to construct object")
522544

523-
model = cls(_user_instantiated=False)
524-
model.deserialize(data)
525-
return model
545+
return cls._instantiate(data)
526546

527547
@classmethod
528548
def count(
@@ -556,6 +576,11 @@ def count(
556576
else:
557577
hash_key = cls._serialize_keys(hash_key)[0]
558578

579+
# If this class has a discriminator value, filter the query to only return instances of this class.
580+
discriminator_attr = cls._get_discriminator_attribute()
581+
if discriminator_attr and discriminator_attr.get_discriminator(cls):
582+
filter_condition &= discriminator_attr == cls
583+
559584
query_args = (hash_key,)
560585
query_kwargs = dict(
561586
range_key_condition=range_key_condition,
@@ -616,6 +641,11 @@ def query(
616641
else:
617642
hash_key = cls._serialize_keys(hash_key)[0]
618643

644+
# If this class has a discriminator value, filter the query to only return instances of this class.
645+
discriminator_attr = cls._get_discriminator_attribute()
646+
if discriminator_attr and discriminator_attr.get_discriminator(cls):
647+
filter_condition &= discriminator_attr == cls
648+
619649
if page_size is None:
620650
page_size = limit
621651

@@ -668,6 +698,11 @@ def scan(
668698
:param rate_limit: If set then consumed capacity will be limited to this amount per second
669699
:param attributes_to_get: If set, specifies the properties to include in the projection expression
670700
"""
701+
# If this class has a discriminator value, filter the scan to only return instances of this class.
702+
discriminator_attr = cls._get_discriminator_attribute()
703+
if discriminator_attr and discriminator_attr.get_discriminator(cls):
704+
filter_condition &= discriminator_attr == cls
705+
671706
if page_size is None:
672707
page_size = limit
673708

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
install_requires = [
55
'botocore>=1.12.54',
6+
'typing-extensions>=3.7; python_version<"3.8"'
67
]
78

89
setup(

tests/test_discriminator.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,18 @@ class RenamedValue(TypedValue, discriminator='custom_name'):
2828
value = UnicodeAttribute()
2929

3030

31-
class DiscriminatorTestModel(Model):
31+
class DiscriminatorTestModel(Model, discriminator='Parent'):
3232
class Meta:
3333
host = 'http://localhost:8000'
3434
table_name = 'test'
3535
hash_key = UnicodeAttribute(hash_key=True)
3636
value = TypedValue()
3737
values = ListAttribute(of=TypedValue)
38+
type = DiscriminatorAttribute()
39+
40+
41+
class ChildModel(DiscriminatorTestModel, discriminator='Child'):
42+
value = UnicodeAttribute()
3843

3944

4045
class TestDiscriminatorAttribute:
@@ -46,6 +51,7 @@ def test_serialize(self):
4651
dtm.values = [NumberValue(name='bar', value=5), RenamedValue(name='baz', value='World')]
4752
assert dtm.serialize() == {
4853
'hash_key': {'S': 'foo'},
54+
'type': {'S': 'Parent'},
4955
'value': {'M': {'cls': {'S': 'StringValue'}, 'name': {'S': 'foo'}, 'value': {'S': 'Hello'}}},
5056
'values': {'L': [
5157
{'M': {'cls': {'S': 'NumberValue'}, 'name': {'S': 'bar'}, 'value': {'N': '5'}}},
@@ -56,6 +62,7 @@ def test_serialize(self):
5662
def test_deserialize(self):
5763
item = {
5864
'hash_key': {'S': 'foo'},
65+
'type': {'S': 'Parent'},
5966
'value': {'M': {'cls': {'S': 'StringValue'}, 'name': {'S': 'foo'}, 'value': {'S': 'Hello'}}},
6067
'values': {'L': [
6168
{'M': {'cls': {'S': 'NumberValue'}, 'name': {'S': 'bar'}, 'value': {'N': '5'}}},
@@ -96,8 +103,28 @@ def test_multiple_discriminator_classes(self):
96103
class RenamedValue2(TypedValue, discriminator='custom_name'):
97104
pass
98105

99-
def test_model(self):
100-
with pytest.raises(NotImplementedError):
101-
class DiscriminatedModel(Model):
102-
hash_key = UnicodeAttribute(hash_key=True)
103-
_cls = DiscriminatorAttribute()
106+
class TestDiscriminatorModel:
107+
108+
def test_serialize(self):
109+
cm = ChildModel()
110+
cm.hash_key = 'foo'
111+
cm.value = 'bar'
112+
cm.values = []
113+
assert cm.serialize() == {
114+
'hash_key': {'S': 'foo'},
115+
'type': {'S': 'Child'},
116+
'value': {'S': 'bar'},
117+
'values': {'L': []}
118+
}
119+
120+
def test_deserialize(self):
121+
item = {
122+
'hash_key': {'S': 'foo'},
123+
'type': {'S': 'Child'},
124+
'value': {'S': 'bar'},
125+
'values': {'L': []}
126+
}
127+
cm = DiscriminatorTestModel.from_raw_data(item)
128+
assert isinstance(cm, ChildModel)
129+
assert cm.hash_key == 'foo'
130+
assert cm.value == 'bar'

0 commit comments

Comments
 (0)