Skip to content

Commit a365bc8

Browse files
authored
Fix item refresh when using model discriminators. Fixes #879 (#880)
1 parent f680242 commit a365bc8

File tree

4 files changed

+26
-13
lines changed

4 files changed

+26
-13
lines changed

docs/release_notes.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
Release Notes
22
=============
33

4-
v5.0.0b3
4+
v5.0.0b4
55
-------------------
66

77
:date: 2020-xx-xx

pynamodb/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,4 @@
77
"""
88
__author__ = 'Jharrod LaFon'
99
__license__ = 'MIT'
10-
__version__ = '5.0.0b3'
10+
__version__ = '5.0.0b4'

pynamodb/attributes.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -312,7 +312,7 @@ def _get_discriminator_attribute(cls) -> Optional['DiscriminatorAttribute']:
312312
def _set_discriminator(self) -> None:
313313
discriminator_attr = self._get_discriminator_attribute()
314314
if discriminator_attr and discriminator_attr.get_discriminator(self.__class__) is not None:
315-
self.attribute_values[self._discriminator] = self.__class__ # type: ignore
315+
setattr(self, self._discriminator, self.__class__) # type: ignore
316316

317317
def _set_defaults(self, _user_instantiated: bool = True) -> None:
318318
"""
@@ -371,18 +371,22 @@ def deserialize(self, attribute_values: Dict[str, Dict[str, Any]]) -> None:
371371
setattr(self, name, value)
372372

373373
@classmethod
374-
def _instantiate(cls: Type[_ACT], attribute_values: Dict[str, Dict[str, Any]]) -> _ACT:
374+
def _get_discriminator_class(cls, attribute_values: Dict[str, Dict[str, Any]]) -> Optional[Type]:
375375
discriminator_attr = cls._get_discriminator_attribute()
376376
if discriminator_attr:
377-
discriminator_attribute_value = attribute_values.pop(discriminator_attr.attr_name, None)
377+
discriminator_attribute_value = attribute_values.get(discriminator_attr.attr_name, None)
378378
if discriminator_attribute_value:
379379
discriminator_value = discriminator_attr.get_value(discriminator_attribute_value)
380-
stored_cls = discriminator_attr.deserialize(discriminator_value)
381-
if not issubclass(stored_cls, cls):
382-
raise ValueError("Cannot instantiate a {} from the returned class: {}".format(
383-
cls.__name__, stored_cls.__name__))
384-
cls = stored_cls
385-
instance = cls(_user_instantiated=False)
380+
return discriminator_attr.deserialize(discriminator_value)
381+
return None
382+
383+
@classmethod
384+
def _instantiate(cls: Type[_ACT], attribute_values: Dict[str, Dict[str, Any]]) -> _ACT:
385+
stored_cls = cls._get_discriminator_class(attribute_values)
386+
if stored_cls and not issubclass(stored_cls, cls):
387+
raise ValueError("Cannot instantiate a {} from the returned class: {}".format(
388+
cls.__name__, stored_cls.__name__))
389+
instance = (stored_cls or cls)(_user_instantiated=False)
386390
AttributeContainer.deserialize(instance, attribute_values)
387391
return instance
388392

@@ -422,7 +426,9 @@ def get_discriminator(self, cls: type) -> Optional[Any]:
422426
return self._class_map.get(cls)
423427

424428
def __set__(self, instance: Any, value: Optional[type]) -> None:
425-
raise TypeError("'{}' object does not support item assignment".format(self.__class__.__name__))
429+
if type(instance) != value:
430+
raise ValueError("The discriminator attribute must be set to the instance type: {}".format(type(instance)))
431+
super().__set__(instance, value)
426432

427433
def serialize(self, value):
428434
"""

pynamodb/models.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,11 @@ def update(self, actions: Sequence[Action], condition: Optional[Condition] = Non
424424
kwargs.update(actions=actions)
425425

426426
data = self._get_connection().update_item(*args, **kwargs)
427-
self.deserialize(data[ATTRIBUTES])
427+
item_data = data[ATTRIBUTES]
428+
stored_cls = self._get_discriminator_class(item_data)
429+
if stored_cls and stored_cls != type(self):
430+
raise ValueError("Cannot update this item from the returned class: {}".format(stored_cls.__name__))
431+
self.deserialize(item_data)
428432
return data
429433

430434
def save(self, condition: Optional[Condition] = None) -> Dict[str, Any]:
@@ -453,6 +457,9 @@ def refresh(self, consistent_read: bool = False) -> None:
453457
item_data = attrs.get(ITEM, None)
454458
if item_data is None:
455459
raise self.DoesNotExist("This item does not exist in the table.")
460+
stored_cls = self._get_discriminator_class(item_data)
461+
if stored_cls and stored_cls != type(self):
462+
raise ValueError("Cannot refresh this item from the returned class: {}".format(stored_cls.__name__))
456463
self.deserialize(item_data)
457464

458465
def get_operation_kwargs_from_instance(

0 commit comments

Comments
 (0)