Skip to content

Commit 8b82276

Browse files
onursaticimvanderlee
authored andcommitted
add test for repeated fields, fix __name__ attr for py<3.10
1 parent db32163 commit 8b82276

File tree

2 files changed

+107
-18
lines changed

2 files changed

+107
-18
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 50 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,7 @@ def _internal_class_schema(
584584
if field.init or include_non_init
585585
)
586586

587-
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
587+
schema_class = type(_name, (_base_schema(clazz, base_schema),), attributes)
588588
return cast(Type[marshmallow.Schema], schema_class)
589589

590590

@@ -602,6 +602,7 @@ def _field_by_supertype(
602602
newtype_supertype: Type,
603603
metadata: dict,
604604
base_schema: Optional[Type[marshmallow.Schema]],
605+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
605606
) -> marshmallow.fields.Field:
606607
"""
607608
Return a new field for fields based on a super field. (Usually spawned from NewType)
@@ -632,6 +633,7 @@ def _field_by_supertype(
632633
metadata=metadata,
633634
default=default,
634635
base_schema=base_schema,
636+
generic_params_to_args=generic_params_to_args,
635637
)
636638

637639

@@ -655,6 +657,7 @@ def _generic_type_add_any(typ: type) -> type:
655657
def _field_for_generic_type(
656658
typ: type,
657659
base_schema: Optional[Type[marshmallow.Schema]],
660+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
658661
**metadata: Any,
659662
) -> Optional[marshmallow.fields.Field]:
660663
"""
@@ -667,7 +670,11 @@ def _field_for_generic_type(
667670
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
668671

669672
if origin in (list, List):
670-
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
673+
child_type = field_for_schema(
674+
arguments[0],
675+
base_schema=base_schema,
676+
generic_params_to_args=generic_params_to_args,
677+
)
671678
list_type = cast(
672679
Type[marshmallow.fields.List],
673680
type_mapping.get(List, marshmallow.fields.List),
@@ -680,25 +687,42 @@ def _field_for_generic_type(
680687
):
681688
from . import collection_field
682689

683-
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
690+
child_type = field_for_schema(
691+
arguments[0],
692+
base_schema=base_schema,
693+
generic_params_to_args=generic_params_to_args,
694+
)
684695
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
685696
if origin in (set, Set):
686697
from . import collection_field
687698

688-
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
699+
child_type = field_for_schema(
700+
arguments[0],
701+
base_schema=base_schema,
702+
generic_params_to_args=generic_params_to_args,
703+
)
689704
return collection_field.Set(
690705
cls_or_instance=child_type, frozen=False, **metadata
691706
)
692707
if origin in (frozenset, FrozenSet):
693708
from . import collection_field
694709

695-
child_type = _field_for_schema(arguments[0], base_schema=base_schema)
710+
child_type = field_for_schema(
711+
arguments[0],
712+
base_schema=base_schema,
713+
generic_params_to_args=generic_params_to_args,
714+
)
696715
return collection_field.Set(
697716
cls_or_instance=child_type, frozen=True, **metadata
698717
)
699718
if origin in (tuple, Tuple):
700719
children = tuple(
701-
_field_for_schema(arg, base_schema=base_schema) for arg in arguments
720+
field_for_schema(
721+
arg,
722+
base_schema=base_schema,
723+
generic_params_to_args=generic_params_to_args,
724+
)
725+
for arg in arguments
702726
)
703727
tuple_type = cast(
704728
Type[marshmallow.fields.Tuple],
@@ -710,8 +734,16 @@ def _field_for_generic_type(
710734
if origin in (dict, Dict, collections.abc.Mapping, Mapping):
711735
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
712736
return dict_type(
713-
keys=_field_for_schema(arguments[0], base_schema=base_schema),
714-
values=_field_for_schema(arguments[1], base_schema=base_schema),
737+
keys=field_for_schema(
738+
arguments[0],
739+
base_schema=base_schema,
740+
generic_params_to_args=generic_params_to_args,
741+
),
742+
values=field_for_schema(
743+
arguments[1],
744+
base_schema=base_schema,
745+
generic_params_to_args=generic_params_to_args,
746+
),
715747
**metadata,
716748
)
717749

@@ -768,6 +800,7 @@ def _field_for_union_type(
768800
subtypes[0],
769801
metadata=metadata,
770802
base_schema=base_schema,
803+
generic_params_to_args=generic_params_to_args,
771804
)
772805
from . import union_field
773806

@@ -779,6 +812,7 @@ def _field_for_union_type(
779812
subtyp,
780813
metadata={"required": True},
781814
base_schema=base_schema,
815+
generic_params_to_args=generic_params_to_args,
782816
),
783817
)
784818
for subtyp in subtypes
@@ -818,14 +852,17 @@ def field_for_schema(
818852
<class 'marshmallow.fields.Url'>
819853
"""
820854
with _SchemaContext(localns=typ_frame.f_locals if typ_frame is not None else None):
821-
return _field_for_schema(typ, default, metadata, base_schema)
855+
return _field_for_schema(
856+
typ, default, metadata, base_schema, generic_params_to_args
857+
)
822858

823859

824860
def _field_for_schema(
825861
typ: type,
826862
default: Any = marshmallow.missing,
827863
metadata: Optional[Mapping[str, Any]] = None,
828864
base_schema: Optional[Type[marshmallow.Schema]] = None,
865+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
829866
) -> marshmallow.fields.Field:
830867
"""
831868
Get a marshmallow Field corresponding to the given python type.
@@ -913,7 +950,9 @@ def _field_for_schema(
913950
)
914951
else:
915952
subtyp = Any
916-
return _field_for_schema(subtyp, default, metadata, base_schema)
953+
return field_for_schema(
954+
subtyp, default, metadata, base_schema, generic_params_to_args
955+
)
917956

918957
annotated_field = _field_for_annotated_type(typ, **metadata)
919958
if annotated_field:
@@ -938,6 +977,7 @@ def _field_for_schema(
938977
newtype_supertype=newtype_supertype,
939978
metadata=metadata,
940979
base_schema=base_schema,
980+
generic_params_to_args=generic_params_to_args,
941981
)
942982

943983
# enumerations

tests/test_class_schema.py

Lines changed: 57 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import inspect
22
import typing
33
import unittest
4-
from typing import Any, cast, TYPE_CHECKING
4+
from typing import TYPE_CHECKING, Any, cast
55
from uuid import UUID
66

77
try:
@@ -10,11 +10,14 @@
1010
from typing_extensions import Final, Literal # type: ignore[assignment]
1111

1212
import dataclasses
13+
1314
from marshmallow import Schema, ValidationError
14-
from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer
15+
from marshmallow.fields import UUID as UUIDField
16+
from marshmallow.fields import Field, Integer
17+
from marshmallow.fields import List as ListField
1518
from marshmallow.validate import Validator
1619

17-
from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass
20+
from marshmallow_dataclass import NewType, _is_generic_alias_of_dataclass, class_schema
1821

1922

2023
class TestClassSchema(unittest.TestCase):
@@ -465,24 +468,70 @@ class SimpleGeneric(typing.Generic[T]):
465468
data: T
466469

467470
@dataclasses.dataclass
468-
class Nested:
471+
class NestedFixed:
469472
data: SimpleGeneric[int]
470473

474+
@dataclasses.dataclass
475+
class NestedGeneric(typing.Generic[T]):
476+
data: SimpleGeneric[T]
477+
478+
self.assertTrue(_is_generic_alias_of_dataclass(SimpleGeneric[int]))
479+
self.assertFalse(_is_generic_alias_of_dataclass(SimpleGeneric))
480+
471481
schema_s = class_schema(SimpleGeneric[str])()
472482
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
473483
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
474484
with self.assertRaises(ValidationError):
475485
schema_s.load({"data": 2})
476486

477-
schema_n = class_schema(Nested)()
487+
schema_nested = class_schema(NestedFixed)()
488+
self.assertEqual(
489+
NestedFixed(data=SimpleGeneric(1)),
490+
schema_nested.load({"data": {"data": 1}}),
491+
)
492+
self.assertEqual(
493+
schema_nested.dump(NestedFixed(data=SimpleGeneric(data=1))),
494+
{"data": {"data": 1}},
495+
)
496+
with self.assertRaises(ValidationError):
497+
schema_nested.load({"data": {"data": "str"}})
498+
499+
schema_nested_generic = class_schema(NestedGeneric[int])()
478500
self.assertEqual(
479-
Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}})
501+
NestedGeneric(data=SimpleGeneric(1)),
502+
schema_nested_generic.load({"data": {"data": 1}}),
480503
)
481504
self.assertEqual(
482-
schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}}
505+
schema_nested_generic.dump(NestedGeneric(data=SimpleGeneric(data=1))),
506+
{"data": {"data": 1}},
483507
)
484508
with self.assertRaises(ValidationError):
485-
schema_n.load({"data": {"data": "str"}})
509+
schema_nested_generic.load({"data": {"data": "str"}})
510+
511+
def test_generic_dataclass_repeated_fields(self):
512+
T = typing.TypeVar("T")
513+
514+
@dataclasses.dataclass
515+
class AA:
516+
a: int
517+
518+
@dataclasses.dataclass
519+
class BB(typing.Generic[T]):
520+
b: T
521+
522+
@dataclasses.dataclass
523+
class Nested:
524+
x: BB[float]
525+
z: BB[float]
526+
# if y is the first field in this class, deserialisation will fail.
527+
# see https://github.com/lovasoa/marshmallow_dataclass/pull/172#issuecomment-1334024027
528+
y: BB[AA]
529+
530+
schema_nested = class_schema(Nested)()
531+
self.assertEqual(
532+
Nested(x=BB(b=1), z=BB(b=1), y=BB(b=AA(1))),
533+
schema_nested.load({"x": {"b": 1}, "z": {"b": 1}, "y": {"b": {"a": 1}}}),
534+
)
486535

487536

488537
if __name__ == "__main__":

0 commit comments

Comments
 (0)