Skip to content

Commit ec56f95

Browse files
authored
Ensure queries and scans return model subclasses when using discriminators. (#873)
1 parent ce43a2f commit ec56f95

File tree

5 files changed

+22
-12
lines changed

5 files changed

+22
-12
lines changed

docs/release_notes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Release Notes
22
=============
33

4-
v5.0.0b2
4+
v5.0.0b3
55
-------------------
66

77
:date: 2020-xx-xx

pynamodb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
"""
88
__author__ = 'Jharrod LaFon'
99
__license__ = 'MIT'
10-
__version__ = '5.0.0b2'
10+
__version__ = '5.0.0b3'

pynamodb/attributes.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -415,6 +415,9 @@ def register_class(self, cls: type, discriminator: Any):
415415

416416
self._discriminator_map[discriminator] = cls
417417

418+
def get_registered_subclasses(self, cls: type) -> List[type]:
419+
return [k for k in self._class_map.keys() if issubclass(k, cls)]
420+
418421
def get_discriminator(self, cls: type) -> Optional[Any]:
419422
return self._class_map.get(cls)
420423

pynamodb/models.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -575,10 +575,10 @@ def count(
575575
else:
576576
hash_key = cls._serialize_keys(hash_key)[0]
577577

578-
# If this class has a discriminator value, filter the query to only return instances of this class.
578+
# If this class has a discriminator attribute, filter the query to only return instances of this class.
579579
discriminator_attr = cls._get_discriminator_attribute()
580-
if discriminator_attr and discriminator_attr.get_discriminator(cls):
581-
filter_condition &= discriminator_attr == cls
580+
if discriminator_attr:
581+
filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls))
582582

583583
query_args = (hash_key,)
584584
query_kwargs = dict(
@@ -640,10 +640,10 @@ def query(
640640
else:
641641
hash_key = cls._serialize_keys(hash_key)[0]
642642

643-
# If this class has a discriminator value, filter the query to only return instances of this class.
643+
# If this class has a discriminator attribute, filter the query to only return instances of this class.
644644
discriminator_attr = cls._get_discriminator_attribute()
645-
if discriminator_attr and discriminator_attr.get_discriminator(cls):
646-
filter_condition &= discriminator_attr == cls
645+
if discriminator_attr:
646+
filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls))
647647

648648
if page_size is None:
649649
page_size = limit
@@ -697,10 +697,10 @@ def scan(
697697
:param rate_limit: If set then consumed capacity will be limited to this amount per second
698698
:param attributes_to_get: If set, specifies the properties to include in the projection expression
699699
"""
700-
# If this class has a discriminator value, filter the scan to only return instances of this class.
700+
# If this class has a discriminator attribute, filter the scan to only return instances of this class.
701701
discriminator_attr = cls._get_discriminator_attribute()
702-
if discriminator_attr and discriminator_attr.get_discriminator(cls):
703-
filter_condition &= discriminator_attr == cls
702+
if discriminator_attr:
703+
filter_condition &= discriminator_attr.is_in(*discriminator_attr.get_registered_subclasses(cls))
704704

705705
if page_size is None:
706706
page_size = limit

tests/test_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1553,6 +1553,10 @@ class Meta:
15531553
class ChildModel(ParentModel, discriminator='Child'):
15541554
foo = UnicodeAttribute()
15551555

1556+
# register a model that subclasses Child to ensure queries return model subclasses
1557+
class GrandchildModel(ChildModel, discriminator='Grandchild'):
1558+
bar = UnicodeAttribute()
1559+
15561560
with patch(PATCH_METHOD) as req:
15571561
req.return_value = {
15581562
"Table": {
@@ -1588,7 +1592,7 @@ class ChildModel(ParentModel, discriminator='Child'):
15881592
pass
15891593
params = {
15901594
'KeyConditionExpression': '#0 = :0',
1591-
'FilterExpression': '#1 = :1',
1595+
'FilterExpression': '#1 IN (:1, :2)',
15921596
'ExpressionAttributeNames': {
15931597
'#0': 'id',
15941598
'#1': 'cls'
@@ -1599,6 +1603,9 @@ class ChildModel(ParentModel, discriminator='Child'):
15991603
},
16001604
':1': {
16011605
'S': u'Child'
1606+
},
1607+
':2': {
1608+
'S': u'Grandchild'
16021609
}
16031610
},
16041611
'ReturnConsumedCapacity': 'TOTAL',

0 commit comments

Comments
 (0)