Skip to content

Commit 51ceec0

Browse files
Jean-Charles BERTINgarrettheel
authored andcommitted
Ensure correct serialize is called for List and Map attributes (#286)
1 parent 2d815f6 commit 51ceec0

File tree

3 files changed

+104
-9
lines changed

3 files changed

+104
-9
lines changed

pynamodb/attributes.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,10 +487,13 @@ def serialize(self, values):
487487
rval = {}
488488
for k in values:
489489
v = values[k]
490-
attr_class = _get_class_for_serialize(v)
491-
attr_key = _get_key_for_serialize(v)
490+
attr_class = self._get_serialize_class(k, v)
492491
if attr_class is None:
493492
continue
493+
if attr_class.attr_type:
494+
attr_key = ATTR_TYPE_MAP[attr_class.attr_type]
495+
else:
496+
attr_key = _get_key_for_serialize(v)
494497

495498
# If this is a subclassed MapAttribute, there may be an alternate attr name
496499
attr = self._get_attributes().get(k)
@@ -531,6 +534,12 @@ def as_dict(self):
531534
result[key] = value.as_dict() if isinstance(value, MapAttribute) else value
532535
return result
533536

537+
@classmethod
538+
def _get_serialize_class(cls, key, value):
539+
if not cls.is_raw():
540+
return cls._get_attributes().get(key)
541+
return _get_class_for_serialize(value)
542+
534543
@classmethod
535544
def _get_deserialize_class(cls, key, value):
536545
if not cls.is_raw():
@@ -595,8 +604,13 @@ def serialize(self, values):
595604
"""
596605
rval = []
597606
for v in values:
598-
attr_class = _get_class_for_serialize(v)
599-
attr_key = _get_key_for_serialize(v)
607+
attr_class = (self.element_type()
608+
if self.element_type
609+
else _get_class_for_serialize(v))
610+
if attr_class.attr_type:
611+
attr_key = ATTR_TYPE_MAP[attr_class.attr_type]
612+
else:
613+
attr_key = _get_key_for_serialize(v)
600614
rval.append({attr_key: attr_class.serialize(v)})
601615
return rval
602616

pynamodb/tests/test_attributes.py

Lines changed: 85 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from dateutil.parser import parse
1111
from dateutil.tz import tzutc
1212

13-
from mock import patch
13+
from mock import patch, Mock, call
1414

1515
from pynamodb.compat import CompatTestCase as TestCase
1616
from pynamodb.constants import UTC, DATETIME_FORMAT
@@ -19,7 +19,7 @@
1919
from pynamodb.attributes import (
2020
BinarySetAttribute, BinaryAttribute, NumberSetAttribute, NumberAttribute,
2121
UnicodeAttribute, UnicodeSetAttribute, UTCDateTimeAttribute, BooleanAttribute, LegacyBooleanAttribute,
22-
MapAttribute, ListAttribute,
22+
MapAttribute, ListAttribute, Attribute,
2323
JSONAttribute, DEFAULT_ENCODING, NUMBER, STRING, STRING_SET, NUMBER_SET, BINARY_SET,
2424
BINARY, MAP, LIST, BOOLEAN, _get_value_for_deserialize)
2525

@@ -395,7 +395,7 @@ def test_unicode_set_deserialize(self):
395395
value
396396
)
397397

398-
def test_unicode_set_deserialize(self):
398+
def test_unicode_set_deserialize_old_way(self):
399399
"""
400400
UnicodeSetAttribute.deserialize old way
401401
"""
@@ -655,6 +655,34 @@ class SomeModel(Model):
655655
self.assertEqual(json.dumps({'map_attr': {'foo': 'bar'}}),
656656
json.dumps(item.typed_map.as_dict()))
657657

658+
def test_json_serialize(self):
659+
class JSONMapAttribute(MapAttribute):
660+
arbitrary_data = JSONAttribute()
661+
662+
def __eq__(self, other):
663+
return self.arbitrary_data == other.arbitrary_data
664+
665+
item = {'foo': 'bar', 'bool': True, 'number': 3.141}
666+
json_map = JSONMapAttribute(arbitrary_data=item)
667+
serialized = json_map.serialize(json_map)
668+
deserialized = json_map.deserialize(serialized)
669+
self.assertTrue(isinstance(deserialized, JSONMapAttribute))
670+
self.assertEqual(deserialized, json_map)
671+
self.assertEqual(deserialized.arbitrary_data, item)
672+
673+
def test_serialize_datetime(self):
674+
class CustomMapAttribute(MapAttribute):
675+
date_attr = UTCDateTimeAttribute()
676+
677+
cm = CustomMapAttribute(date_attr=datetime(2017, 1, 1))
678+
serialized_datetime = cm.serialize(cm)
679+
expected_serialized_value = {
680+
'date_attr': {
681+
'S': u'2017-01-01T00:00:00.000000+0000'
682+
}
683+
}
684+
self.assertEquals(serialized_datetime, expected_serialized_value)
685+
658686

659687
class ValueDeserializeTestCase(TestCase):
660688
def test__get_value_for_deserialize(self):
@@ -728,8 +756,8 @@ def __lt__(self, other):
728756
return self.name < other.name
729757

730758
def __eq__(self, other):
731-
return self.name == other.name and \
732-
self.age == other.age
759+
return (self.name == other.name and
760+
self.age == other.age)
733761

734762
person1 = Person()
735763
person1.name = 'john'
@@ -738,13 +766,65 @@ def __eq__(self, other):
738766
person2 = Person()
739767
person2.name = 'Dana'
740768
person2.age = 41
769+
741770
inp = [person1, person2]
742771

743772
list_attribute = ListAttribute(default=[], of=Person)
744773
serialized = list_attribute.serialize(inp)
745774
deserialized = list_attribute.deserialize(serialized)
746775
self.assertEqual(sorted(deserialized), sorted(inp))
747776

777+
def test_list_of_map_with_of_and_custom_attribute(self):
778+
779+
# Create a couple of mock functions to use
780+
# to test that the CustomAttribute serialize/deserialize are called
781+
serialize_mock = Mock()
782+
deserialize_mock = Mock()
783+
784+
class CustomAttribute(Attribute):
785+
attr_type = STRING
786+
787+
def serialize(self, value):
788+
serialize_mock(value)
789+
return value.upper()
790+
791+
def deserialize(self, value):
792+
deserialize_mock(value)
793+
return value.lower()
794+
795+
class CustomMapAttribute(MapAttribute):
796+
custom = CustomAttribute()
797+
798+
def __lt__(self, other):
799+
return self.custom < other.custom
800+
801+
def __eq__(self, other):
802+
return self.custom == other.custom
803+
804+
attribute1 = CustomMapAttribute()
805+
attribute1.custom = 'test-value1'
806+
807+
attribute2 = CustomMapAttribute()
808+
attribute2.custom = 'test-value2'
809+
810+
inp = [attribute1, attribute2]
811+
812+
list_attribute = ListAttribute(default=[], of=CustomMapAttribute)
813+
serialized = list_attribute.serialize(inp)
814+
deserialized = list_attribute.deserialize(serialized)
815+
self.assertEqual(sorted(deserialized), sorted(inp))
816+
817+
# Confirm that the the serialize/deserialize are called
818+
# with the expected values
819+
serialize_mock.assert_has_calls([
820+
call('test-value1'),
821+
call('test-value2'),
822+
])
823+
deserialize_mock.assert_has_calls([
824+
call('TEST-VALUE1'),
825+
call('TEST-VALUE2'),
826+
])
827+
748828
def test_list_of_unicode_with_of(self):
749829
with self.assertRaises(ValueError):
750830
ListAttribute(default=[], of=UnicodeAttribute)

requirements-dev.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
-rrequirements.txt
22
coverage==3.7.1
3+
mock==2.0.0
34
pytest==3.0.7
45
pytest-cov==2.4.0
56
python-coveralls==2.5.0

0 commit comments

Comments
 (0)