Skip to content
This repository was archived by the owner on Jan 28, 2022. It is now read-only.

Commit db293c2

Browse files
authored
Merge pull request #65 from lmignon/fix-load-nested-many
FIx load of nested many models
2 parents 908a786 + bfa8588 commit db293c2

File tree

2 files changed

+21
-2
lines changed

2 files changed

+21
-2
lines changed

marshmallow_objects/models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525

2626
@marshmallow.post_load
2727
def __make_object__(self, data, **kwargs):
28-
return self.__model_class__(__post_load__=True, __schema__=self, **data)
28+
data["many"] = kwargs.pop("many", None)
29+
return self.__model_class__(
30+
__post_load__=True, __schema__=self, **data)
2931

3032

3133
class ModelMeta(type):
@@ -69,6 +71,7 @@ def __new__(mcs, name, parents, dct):
6971

7072
def __call__(cls, *args, **kwargs):
7173
if kwargs.pop('__post_load__', False):
74+
kwargs.pop("many")
7275
schema = kwargs.pop('__schema__')
7376
obj = cls.__new__(cls, *args, **kwargs)
7477
obj.__dump_lock__ = threading.RLock()
@@ -83,7 +86,8 @@ def __call__(cls, *args, **kwargs):
8386
else:
8487
context = kwargs.pop('context', None)
8588
partial = kwargs.pop('partial', None)
86-
obj = cls.load(kwargs, context=context, partial=partial)
89+
many = kwargs.pop("many", None)
90+
obj = cls.load(kwargs, many=many, context=context, partial=partial)
8791
return obj
8892

8993

@@ -92,6 +96,8 @@ def __init__(self, nested, **kwargs):
9296
super(NestedModel, self).__init__(nested.__schema_class__, **kwargs)
9397

9498
def _deserialize(self, value, attr, data, **kwargs):
99+
if self.many and value and isinstance(value[0], Model):
100+
return value
95101
if isinstance(value, Model):
96102
return value
97103
return super(NestedModel, self)._deserialize(value, attr, data,

tests/test_models.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,19 @@ def test_nested_dict_many(self):
148148
c = C(a=[dict(test_field='1'), dict(test_field='2')])
149149
self.assertEqual(2, len(c.a))
150150

151+
def test_nested_model_many(self):
152+
c = C(a=[A(test_field='1'), A(test_field='2')])
153+
self.assertEqual(2, len(c.a))
154+
155+
def test_load_model_many(self):
156+
a_list = A.load(
157+
[dict(test_field='1'), dict(test_field='2')],
158+
many=True
159+
)
160+
self.assertEqual(2, len(a_list))
161+
self.assertEqual('1', a_list[0].test_field)
162+
self.assertEqual('2', a_list[1].test_field)
163+
151164
def test_partial(self):
152165
self.assertRaises(marshmallow.ValidationError, B)
153166
b = B(partial=True)

0 commit comments

Comments
 (0)