Skip to content

Commit 61271e8

Browse files
authored
Improve typed list attribute expression creation. (#841)
1 parent 14f81c7 commit 61271e8

File tree

3 files changed

+55
-29
lines changed

3 files changed

+55
-29
lines changed

pynamodb/attributes.py

Lines changed: 42 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ def __gt__(self, other: Any) -> 'Comparison':
156156
def __ge__(self, other: Any) -> 'Comparison':
157157
return Path(self).__ge__(other)
158158

159-
def __getitem__(self, idx: int) -> Any:
160-
return Path(self).__getitem__(idx)
159+
def __getitem__(self, item: Union[int, str]) -> Path:
160+
return Path(self).__getitem__(item)
161161

162162
def between(self, lower: Any, upper: Any) -> 'Between':
163163
return Path(self).between(lower, upper)
@@ -838,12 +838,10 @@ def deserialize(self, values):
838838
instance._deserialize(values)
839839
return instance
840840

841-
deserialized_dict: Dict[str, Any] = dict()
842-
for k, v in values.items():
843-
attr_type, attr_value = next(iter(v.items()))
844-
attr_class = DESERIALIZE_CLASS_MAP[attr_type]
845-
deserialized_dict[k] = attr_class.deserialize(attr_value)
846-
return deserialized_dict
841+
return {
842+
k: DESERIALIZE_CLASS_MAP[attr_type].deserialize(attr_value)
843+
for k, v in values.items() for attr_type, attr_value in v.items()
844+
}
847845

848846
@classmethod
849847
def is_raw(cls):
@@ -900,7 +898,7 @@ def _fast_parse_utc_datestring(datestring):
900898

901899
class ListAttribute(Generic[_T], Attribute[List[_T]]):
902900
attr_type = LIST
903-
element_type: Any = None
901+
element_type: Optional[Type[Attribute]] = None
904902

905903
def __init__(
906904
self,
@@ -945,18 +943,41 @@ def deserialize(self, values):
945943
"""
946944
Decode from list of AttributeValue types.
947945
"""
948-
deserialized_lst = []
949-
for v in values:
950-
attr_type, attr_value = next(iter(v.items()))
951-
attr_class = self._get_deserialize_class(attr_type)
952-
if attr_class.attr_type != attr_type:
953-
raise ValueError("Cannot deserialize {} elements from type: {}".format(
954-
attr_class.__class__.__name__, attr_type))
955-
deserialized_lst.append(attr_class.deserialize(attr_value))
956-
return deserialized_lst
957-
958-
def __getitem__(self, idx: int) -> Path:
959-
# for typing only
946+
if self.element_type:
947+
element_attr = self.element_type()
948+
if isinstance(element_attr, MapAttribute):
949+
element_attr._make_attribute() # ensure attr_name exists
950+
deserialized_lst = []
951+
for idx, attribute_value in enumerate(values):
952+
value = None
953+
if NULL not in attribute_value:
954+
# set attr_name in case `get_value` raises an exception
955+
element_attr.attr_name = '{}[{}]'.format(self.attr_name, idx)
956+
value = element_attr.deserialize(element_attr.get_value(attribute_value))
957+
deserialized_lst.append(value)
958+
return deserialized_lst
959+
960+
return [
961+
DESERIALIZE_CLASS_MAP[attr_type].deserialize(attr_value)
962+
for v in values for attr_type, attr_value in v.items()
963+
]
964+
965+
def __getitem__(self, idx: int) -> Path: # type: ignore
966+
if not isinstance(idx, int):
967+
raise TypeError("list indices must be integers, not {}".format(type(idx).__name__))
968+
969+
if self.element_type:
970+
# If this instance is typed, return a properly configured attribute on list element access.
971+
element_attr = self.element_type()
972+
if isinstance(element_attr, MapAttribute):
973+
element_attr._make_attribute()
974+
element_attr.attr_path = list(self.attr_path) # copy the document path before indexing last element
975+
element_attr.attr_name = '{}[{}]'.format(element_attr.attr_name, idx)
976+
if isinstance(element_attr, MapAttribute):
977+
for path_segment in reversed(element_attr.attr_path):
978+
element_attr._update_attribute_paths(path_segment)
979+
return element_attr # type: ignore
980+
960981
return super().__getitem__(idx)
961982

962983
def _get_serialize_class(self, value):
@@ -968,11 +989,6 @@ def _get_serialize_class(self, value):
968989
return self.element_type()
969990
return SERIALIZE_CLASS_MAP[type(value)]
970991

971-
def _get_deserialize_class(self, attr_type):
972-
if self.element_type and attr_type != NULL:
973-
return self.element_type()
974-
return DESERIALIZE_CLASS_MAP[attr_type]
975-
976992

977993
DESERIALIZE_CLASS_MAP: Dict[str, Attribute] = {
978994
BINARY: BinaryAttribute(),

tests/test_attributes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -911,7 +911,7 @@ def test_list_type_error(self):
911911
with pytest.raises(ValueError):
912912
string_list_attribute.serialize([MapAttribute(foo='bar')])
913913

914-
with pytest.raises(ValueError):
914+
with pytest.raises(TypeError):
915915
string_list_attribute.deserialize([{'M': {'foo': {'S': 'bar'}}}])
916916

917917
def test_serialize_null(self):

tests/test_expressions.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def test_indexing(self):
325325
expression = condition.serialize(placeholder_names, expression_attribute_values)
326326
assert expression == "#0[0] = :0"
327327
assert placeholder_names == {'foo': '#0'}
328-
assert expression_attribute_values == {':0': {'S' : 'bar'}}
328+
assert expression_attribute_values == {':0': {'S': 'bar'}}
329329

330330
def test_invalid_indexing(self):
331331
with self.assertRaises(TypeError):
@@ -337,7 +337,17 @@ def test_double_indexing(self):
337337
expression = condition.serialize(placeholder_names, expression_attribute_values)
338338
assert expression == "#0[0][1] = :0"
339339
assert placeholder_names == {'foo': '#0'}
340-
assert expression_attribute_values == {':0': {'S' : 'bar'}}
340+
assert expression_attribute_values == {':0': {'S': 'bar'}}
341+
342+
def test_typed_list_indexing(self):
343+
class StringMap(MapAttribute):
344+
bar = UnicodeAttribute()
345+
condition = ListAttribute(attr_name='foo', of=StringMap)[0].bar == 'baz'
346+
placeholder_names, expression_attribute_values = {}, {}
347+
expression = condition.serialize(placeholder_names, expression_attribute_values)
348+
assert expression == "#0[0].#1 = :0"
349+
assert placeholder_names == {'foo': '#0', 'bar': '#1'}
350+
assert expression_attribute_values == {':0': {'S': 'baz'}}
341351

342352
def test_map_comparison(self):
343353
# Simulate initialization from inside an AttributeContainer

0 commit comments

Comments
 (0)