Skip to content

Commit 8a0f837

Browse files
committed
Fix tests after rebase
1 parent 8443336 commit 8a0f837

File tree

1 file changed

+53
-29
lines changed

1 file changed

+53
-29
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def dataclass(
202202
frozen: bool = False,
203203
base_schema: Optional[Type[marshmallow.Schema]] = None,
204204
cls_frame: Optional[types.FrameType] = None,
205-
) -> Type[_U]: ...
205+
) -> Type[_U]:
206+
...
206207

207208

208209
@overload
@@ -215,7 +216,8 @@ def dataclass(
215216
frozen: bool = False,
216217
base_schema: Optional[Type[marshmallow.Schema]] = None,
217218
cls_frame: Optional[types.FrameType] = None,
218-
) -> Callable[[Type[_U]], Type[_U]]: ...
219+
) -> Callable[[Type[_U]], Type[_U]]:
220+
...
219221

220222

221223
# _cls should never be specified by keyword, so start it with an
@@ -280,13 +282,15 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
280282

281283

282284
@overload
283-
def add_schema(_cls: Type[_U]) -> Type[_U]: ...
285+
def add_schema(_cls: Type[_U]) -> Type[_U]:
286+
...
284287

285288

286289
@overload
287290
def add_schema(
288291
base_schema: Optional[Type[marshmallow.Schema]] = None,
289-
) -> Callable[[Type[_U]], Type[_U]]: ...
292+
) -> Callable[[Type[_U]], Type[_U]]:
293+
...
290294

291295

292296
@overload
@@ -295,7 +299,8 @@ def add_schema(
295299
base_schema: Optional[Type[marshmallow.Schema]] = None,
296300
cls_frame: Optional[types.FrameType] = None,
297301
stacklevel: int = 1,
298-
) -> Type[_U]: ...
302+
) -> Type[_U]:
303+
...
299304

300305

301306
def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1):
@@ -348,7 +353,8 @@ def class_schema(
348353
*,
349354
globalns: Optional[Dict[str, Any]] = None,
350355
localns: Optional[Dict[str, Any]] = None,
351-
) -> Type[marshmallow.Schema]: ...
356+
) -> Type[marshmallow.Schema]:
357+
...
352358

353359

354360
@overload
@@ -358,7 +364,8 @@ def class_schema(
358364
clazz_frame: Optional[types.FrameType] = None,
359365
*,
360366
globalns: Optional[Dict[str, Any]] = None,
361-
) -> Type[marshmallow.Schema]: ...
367+
) -> Type[marshmallow.Schema]:
368+
...
362369

363370

364371
def class_schema(
@@ -573,7 +580,8 @@ def _internal_class_schema(
573580
# https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
574581
class_name = clazz._name or clazz.__origin__.__name__ # type: ignore[attr-defined]
575582
else:
576-
class_name = clazz.__name__
583+
# generic aliases do not have a __name__ prior python 3.10
584+
class_name = getattr(clazz, "__name__", repr(clazz))
577585

578586
schema_ctx.seen_classes[clazz] = class_name
579587

@@ -613,11 +621,20 @@ def _internal_class_schema(
613621
# Determine whether we should include non-init fields
614622
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)
615623

624+
# Update the schema members to contain marshmallow fields instead of dataclass fields
625+
type_hints = {}
626+
if not is_generic_type(clazz):
627+
type_hints = _get_type_hints(clazz, schema_ctx)
628+
616629
attributes.update(
617630
(
618631
field.name,
619-
field_for_schema(
620-
_get_field_type_hints(field, schema_ctx),
632+
_field_for_schema(
633+
(
634+
type_hints[field.name]
635+
if not is_generic_type(clazz)
636+
else _get_generic_type_hints(field.type, schema_ctx)
637+
),
621638
_get_field_default(field),
622639
field.metadata,
623640
base_schema,
@@ -710,7 +727,7 @@ def _field_for_generic_type(
710727
type_mapping = base_schema.TYPE_MAPPING if base_schema else {}
711728

712729
if origin in (list, List):
713-
child_type = field_for_schema(
730+
child_type = _field_for_schema(
714731
arguments[0],
715732
base_schema=base_schema,
716733
)
@@ -726,15 +743,15 @@ def _field_for_generic_type(
726743
):
727744
from . import collection_field
728745

729-
child_type = field_for_schema(
746+
child_type = _field_for_schema(
730747
arguments[0],
731748
base_schema=base_schema,
732749
)
733750
return collection_field.Sequence(cls_or_instance=child_type, **metadata)
734751
if origin in (set, Set):
735752
from . import collection_field
736753

737-
child_type = field_for_schema(
754+
child_type = _field_for_schema(
738755
arguments[0],
739756
base_schema=base_schema,
740757
)
@@ -744,7 +761,7 @@ def _field_for_generic_type(
744761
if origin in (frozenset, FrozenSet):
745762
from . import collection_field
746763

747-
child_type = field_for_schema(
764+
child_type = _field_for_schema(
748765
arguments[0],
749766
base_schema=base_schema,
750767
)
@@ -753,7 +770,7 @@ def _field_for_generic_type(
753770
)
754771
if origin in (tuple, Tuple):
755772
children = tuple(
756-
field_for_schema(
773+
_field_for_schema(
757774
arg,
758775
base_schema=base_schema,
759776
)
@@ -980,7 +997,7 @@ def _field_for_schema(
980997
)
981998
else:
982999
subtyp = Any
983-
return field_for_schema(subtyp, default, metadata, base_schema)
1000+
return _field_for_schema(subtyp, default, metadata, base_schema)
9841001

9851002
annotated_field = _field_for_annotated_type(typ, **metadata)
9861003
if annotated_field:
@@ -1081,30 +1098,37 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
10811098
)
10821099

10831100

1084-
def _get_field_type_hints(
1085-
field: dataclasses.Field,
1086-
schema_ctx: Optional[_SchemaContext] = None,
1087-
) -> type:
1088-
"""typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works."""
1089-
1090-
class X:
1091-
x: field.type # type: ignore[name-defined]
1092-
1101+
def _get_type_hints(
1102+
obj,
1103+
schema_ctx: _SchemaContext,
1104+
):
10931105
if sys.version_info >= (3, 9):
10941106
type_hints = get_type_hints(
1095-
X,
1107+
obj,
10961108
globalns=schema_ctx.globalns,
10971109
localns=schema_ctx.localns,
10981110
include_extras=True,
1099-
)["x"]
1111+
)
11001112
else:
11011113
type_hints = get_type_hints(
1102-
X, globalns=schema_ctx.globalns, localns=schema_ctx.localns
1103-
)["x"]
1114+
obj, globalns=schema_ctx.globalns, localns=schema_ctx.localns
1115+
)
11041116

11051117
return type_hints
11061118

11071119

1120+
def _get_generic_type_hints(
1121+
obj,
1122+
schema_ctx: _SchemaContext,
1123+
) -> type:
1124+
"""typing.get_type_hints doesn't work with generic aliasses. But this 'hack' works."""
1125+
1126+
class X:
1127+
x: obj # type: ignore[name-defined]
1128+
1129+
return _get_type_hints(X, schema_ctx)["x"]
1130+
1131+
11081132
def _is_generic_alias(clazz: type) -> bool:
11091133
"""
11101134
Check if given class is a generic alias of a class is

0 commit comments

Comments
 (0)