Skip to content

Commit 52e25f8

Browse files
onursaticimvanderlee
authored andcommitted
add test for repeated fields, fix __name__ attr for py<3.10
1 parent f3da098 commit 52e25f8

File tree

2 files changed

+101
-18
lines changed

2 files changed

+101
-18
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 44 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -394,7 +394,10 @@ def _internal_class_schema(
394394
clazz_frame: Optional[types.FrameType] = None,
395395
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
396396
) -> Type[marshmallow.Schema]:
397-
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
397+
# generic aliases do not have a __name__ prior python 3.10
398+
_name = getattr(clazz, "__name__", repr(clazz))
399+
400+
_RECURSION_GUARD.seen_classes[clazz] = _name
398401
try:
399402
fields = _dataclass_fields(clazz)
400403
except TypeError: # Not a dataclass
@@ -450,7 +453,7 @@ def _internal_class_schema(
450453
if field.init or include_non_init
451454
)
452455

453-
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
456+
schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes)
454457
return cast(Type[marshmallow.Schema], schema_class)
455458

456459

@@ -469,6 +472,7 @@ def _field_by_supertype(
469472
metadata: dict,
470473
base_schema: Optional[Type[marshmallow.Schema]],
471474
typ_frame: Optional[types.FrameType],
475+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
472476
) -> marshmallow.fields.Field:
473477
"""
474478
Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -500,6 +504,7 @@ def _field_by_supertype(
500504
default=default,
501505
base_schema=base_schema,
502506
typ_frame=typ_frame,
507+
generic_params_to_args=generic_params_to_args,
503508
)
504509

505510

@@ -524,6 +529,7 @@ def _field_for_generic_type(
524529
typ: type,
525530
base_schema: Optional[Type[marshmallow.Schema]],
526531
typ_frame: Optional[types.FrameType],
532+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
527533
**metadata: Any,
528534
) -> Optional[marshmallow.fields.Field]:
529535
"""
@@ -537,7 +543,10 @@ def _field_for_generic_type(
537543

538544
if origin in (list, List):
539545
child_type = field_for_schema(
540-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
546+
arguments[0],
547+
base_schema=base_schema,
548+
typ_frame=typ_frame,
549+
generic_params_to_args=generic_params_to_args,
541550
)
542551
list_type = cast(
543552
Type[marshmallow.fields.List],
@@ -552,14 +561,20 @@ def _field_for_generic_type(
552561
from . import collection_field
553562

554563
child_type = field_for_schema(
555-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
564+
arguments[0],
565+
base_schema=base_schema,
566+
typ_frame=typ_frame,
567+
generic_params_to_args=generic_params_to_args,
556568
)
557569
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
558570
if origin in (set, Set):
559571
from . import collection_field
560572

561573
child_type = field_for_schema(
562-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
574+
arguments[0],
575+
base_schema=base_schema,
576+
typ_frame=typ_frame,
577+
generic_params_to_args=generic_params_to_args,
563578
)
564579
return collection_field.Set(
565580
cls_or_instance=child_type, frozen=False, **metadata
@@ -568,14 +583,22 @@ def _field_for_generic_type(
568583
from . import collection_field
569584

570585
child_type = field_for_schema(
571-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
586+
arguments[0],
587+
base_schema=base_schema,
588+
typ_frame=typ_frame,
589+
generic_params_to_args=generic_params_to_args,
572590
)
573591
return collection_field.Set(
574592
cls_or_instance=child_type, frozen=True, **metadata
575593
)
576594
if origin in (tuple, Tuple):
577595
children = tuple(
578-
field_for_schema(arg, base_schema=base_schema, typ_frame=typ_frame)
596+
field_for_schema(
597+
arg,
598+
base_schema=base_schema,
599+
typ_frame=typ_frame,
600+
generic_params_to_args=generic_params_to_args,
601+
)
579602
for arg in arguments
580603
)
581604
tuple_type = cast(
@@ -589,10 +612,16 @@ def _field_for_generic_type(
589612
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
590613
return dict_type(
591614
keys=field_for_schema(
592-
arguments[0], base_schema=base_schema, typ_frame=typ_frame
615+
arguments[0],
616+
base_schema=base_schema,
617+
typ_frame=typ_frame,
618+
generic_params_to_args=generic_params_to_args,
593619
),
594620
values=field_for_schema(
595-
arguments[1], base_schema=base_schema, typ_frame=typ_frame
621+
arguments[1],
622+
base_schema=base_schema,
623+
typ_frame=typ_frame,
624+
generic_params_to_args=generic_params_to_args,
596625
),
597626
**metadata,
598627
)
@@ -610,6 +639,7 @@ def _field_for_generic_type(
610639
metadata=metadata,
611640
base_schema=base_schema,
612641
typ_frame=typ_frame,
642+
generic_params_to_args=generic_params_to_args,
613643
)
614644
from . import union_field
615645

@@ -622,6 +652,7 @@ def _field_for_generic_type(
622652
metadata={"required": True},
623653
base_schema=base_schema,
624654
typ_frame=typ_frame,
655+
generic_params_to_args=generic_params_to_args,
625656
),
626657
)
627658
for subtyp in subtypes
@@ -730,7 +761,9 @@ def field_for_schema(
730761
)
731762
else:
732763
subtyp = Any
733-
return field_for_schema(subtyp, default, metadata, base_schema, typ_frame)
764+
return field_for_schema(
765+
subtyp, default, metadata, base_schema, typ_frame, generic_params_to_args
766+
)
734767

735768
# Generic types
736769
generic_field = _field_for_generic_type(typ, base_schema, typ_frame, **metadata)
@@ -748,6 +781,7 @@ def field_for_schema(
748781
metadata=metadata,
749782
base_schema=base_schema,
750783
typ_frame=typ_frame,
784+
generic_params_to_args=generic_params_to_args,
751785
)
752786

753787
# enumerations

tests/test_class_schema.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import typing
33
import unittest
4-
from typing import Any, cast, TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any, cast
55
from uuid import UUID
66

77
try:
@@ -10,11 +10,14 @@
1010
from typing_extensions import Final, Literal # type: ignore[assignment]
1111

1212
import dataclasses
13+
1314
from marshmallow import Schema, ValidationError
14-
from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer
15+
from marshmallow.fields import UUID as UUIDField
16+
from marshmallow.fields import Field, Integer
17+
from marshmallow.fields import List as ListField
1518
from marshmallow.validate import Validator
1619

17-
from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass
20+
from marshmallow_dataclass import NewType, _is_generic_alias_of_dataclass, class_schema
1821

1922

2023
class TestClassSchema(unittest.TestCase):
@@ -465,24 +468,70 @@ class SimpleGeneric(typing.Generic[T]):
465468
data: T
466469

467470
@dataclasses.dataclass
468-
class Nested:
471+
class NestedFixed:
469472
data: SimpleGeneric[int]
470473

474+
@dataclasses.dataclass
475+
class NestedGeneric(typing.Generic[T]):
476+
data: SimpleGeneric[T]
477+
478+
self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int]))
479+
self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric))
480+
471481
schema_s = class_schema(SimpleGeneric[str])()
472482
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
473483
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
474484
with self.assertRaises(ValidationError):
475485
schema_s.load({"data": 2})
476486

477-
schema_n = class_schema(Nested)()
487+
schema_nested = class_schema(NestedFixed)()
488+
self.assertEqual(
489+
NestedFixed(data=SimpleGeneric(1)),
490+
schema_nested.load({"data": {"data": 1}}),
491+
)
492+
self.assertEqual(
493+
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
494+
{"data": {"data": 1}},
495+
)
496+
with self.assertRaises(ValidationError):
497+
schema_nested.load({"data": {"data": "str"}})
498+
499+
schema_nested_generic = class_schema(NestedGeneric[int])()
478500
self.assertEqual(
479-
Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}})
501+
NestedGeneric(data=SimpleGeneric(1)),
502+
schema_nested_generic.load({"data": {"data": 1}}),
480503
)
481504
self.assertEqual(
482-
schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}}
505+
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
506+
{"data": {"data": 1}},
483507
)
484508
with self.assertRaises(ValidationError):
485-
schema_n.load({"data": {"data": "str"}})
509+
schema_nested_generic.load({"data": {"data": "str"}})
510+
511+
def test_generic_dataclass_repeated_fields(self):
512+
T = typing.TypeVar("T")
513+
514+
@dataclasses.dataclass
515+
class AA:
516+
a: int
517+
518+
@dataclasses.dataclass
519+
class BB(typing.Generic[T]):
520+
b: T
521+
522+
@dataclasses.dataclass
523+
class Nested:
524+
x: BB[float]
525+
z: BB[float]
526+
# if y is the first field in this class, deserialisation will fail.
527+
# see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027
528+
y: BB[AA]
529+
530+
schema_nested = class_schema(Nested)()
531+
self.assertEqual(
532+
Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))),
533+
schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}),
534+
)
486535

487536

488537
if __name__ == "__main__":

0 commit comments

Comments
 (0)