Skip to content

Commit db32163

Browse files
onursaticimvanderlee
authored andcommitted
support nested generic dataclasses
1 parent bc46a23 commit db32163

File tree

2 files changed

+58
-53
lines changed

2 files changed

+58
-53
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 57 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -44,15 +44,9 @@ class User:
4444
import warnings
4545
from enum import Enum
4646
from functools import lru_cache, partial
47+
from typing import Any, Callable, Dict, FrozenSet, Generic, List, Mapping
48+
from typing import NewType as typing_NewType
4749
from typing import (
48-
Any,
49-
Callable,
50-
Dict,
51-
FrozenSet,
52-
Generic,
53-
List,
54-
Mapping,
55-
NewType as typing_NewType,
5650
Optional,
5751
Sequence,
5852
Set,
@@ -150,8 +144,7 @@ def dataclass(
150144
frozen: bool = False,
151145
base_schema: Optional[Type[marshmallow.Schema]] = None,
152146
cls_frame: Optional[types.FrameType] = None,
153-
) -> Type[_U]:
154-
...
147+
) -> Type[_U]: ...
155148

156149

157150
@overload
@@ -164,8 +157,7 @@ def dataclass(
164157
frozen: bool = False,
165158
base_schema: Optional[Type[marshmallow.Schema]] = None,
166159
cls_frame: Optional[types.FrameType] = None,
167-
) -> Callable[[Type[_U]], Type[_U]]:
168-
...
160+
) -> Callable[[Type[_U]], Type[_U]]: ...
169161

170162

171163
# _cls should never be specified by keyword, so start it with an
@@ -224,15 +216,13 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
224216

225217

226218
@overload
227-
def add_schema(_cls: Type[_U]) -> Type[_U]:
228-
...
219+
def add_schema(_cls: Type[_U]) -> Type[_U]: ...
229220

230221

231222
@overload
232223
def add_schema(
233224
base_schema: Optional[Type[marshmallow.Schema]] = None,
234-
) -> Callable[[Type[_U]], Type[_U]]:
235-
...
225+
) -> Callable[[Type[_U]], Type[_U]]: ...
236226

237227

238228
@overload
@@ -241,8 +231,7 @@ def add_schema(
241231
base_schema: Optional[Type[marshmallow.Schema]] = None,
242232
cls_frame: Optional[types.FrameType] = None,
243233
stacklevel: int = 1,
244-
) -> Type[_U]:
245-
...
234+
) -> Type[_U]: ...
246235

247236

248237
def add_schema(_cls=None, base_schema=None, cls_frame=None, stacklevel=1):
@@ -293,8 +282,7 @@ def class_schema(
293282
*,
294283
globalns: Optional[Dict[str, Any]] = None,
295284
localns: Optional[Dict[str, Any]] = None,
296-
) -> Type[marshmallow.Schema]:
297-
...
285+
) -> Type[marshmallow.Schema]: ...
298286

299287

300288
@overload
@@ -304,8 +292,7 @@ def class_schema(
304292
clazz_frame: Optional[types.FrameType] = None,
305293
*,
306294
globalns: Optional[Dict[str, Any]] = None,
307-
) -> Type[marshmallow.Schema]:
308-
...
295+
) -> Type[marshmallow.Schema]: ...
309296

310297

311298
def class_schema(
@@ -463,7 +450,7 @@ def class_schema(
463450
if clazz_frame is not None:
464451
localns = clazz_frame.f_locals
465452
with _SchemaContext(globalns, localns):
466-
return _internal_class_schema(clazz, base_schema)
453+
return _internal_class_schema(clazz, base_schema, None)
467454

468455

469456
class _SchemaContext:
@@ -509,10 +496,17 @@ def top(self) -> _U:
509496
_schema_ctx_stack = _LocalStack[_SchemaContext]()
510497

511498

499+
def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
500+
if _is_generic_alias_of_dataclass(clazz):
501+
clazz = typing_inspect.get_origin(clazz)
502+
return dataclasses.fields(clazz)
503+
504+
512505
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
513506
def _internal_class_schema(
514507
clazz: type,
515508
base_schema: Optional[Type[marshmallow.Schema]] = None,
509+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
516510
) -> Type[marshmallow.Schema]:
517511
schema_ctx = _schema_ctx_stack.top
518512

@@ -525,7 +519,7 @@ def _internal_class_schema(
525519
schema_ctx.seen_classes[clazz] = class_name
526520

527521
try:
528-
class_name, fields = _dataclass_name_and_fields(clazz)
522+
fields = _dataclass_fields(clazz)
529523
except TypeError: # Not a dataclass
530524
try:
531525
warnings.warn(
@@ -540,7 +534,9 @@ def _internal_class_schema(
540534
"****** WARNING ******"
541535
)
542536
created_dataclass: type = dataclasses.dataclass(clazz)
543-
return _internal_class_schema(created_dataclass, base_schema)
537+
return _internal_class_schema(
538+
created_dataclass, base_schema, generic_params_to_args
539+
)
544540
except Exception as exc:
545541
raise TypeError(
546542
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
@@ -556,6 +552,10 @@ def _internal_class_schema(
556552
# Determine whether we should include non-init fields
557553
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)
558554

555+
if _is_generic_alias_of_dataclass(clazz) and generic_params_to_args is None:
556+
generic_params_to_args = _generic_params_to_args(clazz)
557+
558+
type_hints = _dataclass_type_hints(clazz, schema_ctx, generic_params_to_args)
559559
# Update the schema members to contain marshmallow fields instead of dataclass fields
560560

561561
if sys.version_info >= (3, 9):
@@ -577,13 +577,14 @@ def _internal_class_schema(
577577
_get_field_default(field),
578578
field.metadata,
579579
base_schema,
580+
generic_params_to_args,
580581
),
581582
)
582583
for field in fields
583584
if field.init or include_non_init
584585
)
585586

586-
schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes)
587+
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
587588
return cast(Type[marshmallow.Schema], schema_class)
588589

589590

@@ -706,7 +707,7 @@ def _field_for_generic_type(
706707
),
707708
)
708709
return tuple_type(children, **metadata)
709-
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
710+
if origin in (dict, Dict, collections.abc.Mapping, Mapping):
710711
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
711712
return dict_type(
712713
keys=_field_for_schema(arguments[0], base_schema=base_schema),
@@ -794,6 +795,7 @@ def field_for_schema(
794795
base_schema: Optional[Type[marshmallow.Schema]] = None,
795796
# FIXME: delete typ_frame from API?
796797
typ_frame: Optional[types.FrameType] = None,
798+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
797799
) -> marshmallow.fields.Field:
798800
"""
799801
Get a marshmallow Field corresponding to the given python type.
@@ -953,7 +955,7 @@ def _field_for_schema(
953955
nested_schema
954956
or forward_reference
955957
or _schema_ctx_stack.top.seen_classes.get(typ)
956-
or _internal_class_schema(typ, base_schema) # type: ignore[arg-type] # FIXME
958+
or _internal_class_schema(typ, base_schema, generic_params_to_args) # type: ignore [arg-type]
957959
)
958960

959961
return marshmallow.fields.Nested(nested, **metadata)
@@ -1007,35 +1009,38 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
10071009
)
10081010

10091011

1010-
# noinspection PyDataclass
1011-
def _dataclass_name_and_fields(
1012-
clazz: type,
1013-
) -> Tuple[str, Tuple[dataclasses.Field, ...]]:
1014-
if not _is_generic_alias_of_dataclass(clazz):
1015-
return clazz.__name__, dataclasses.fields(clazz)
1016-
1012+
def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]:
10171013
base_dataclass = typing_inspect.get_origin(clazz)
10181014
base_parameters = typing_inspect.get_parameters(base_dataclass)
10191015
type_arguments = typing_inspect.get_args(clazz)
1020-
params_to_args = dict(zip(base_parameters, type_arguments))
1021-
non_generic_fields = [ # swap generic typed fields with types in given type arguments
1022-
(
1023-
f.name,
1024-
params_to_args.get(f.type, f.type),
1025-
dataclasses.field(
1026-
default=f.default,
1027-
# ignoring mypy: https://github.com/python/mypy/issues/6910
1028-
default_factory=f.default_factory, # type: ignore
1029-
init=f.init,
1030-
metadata=f.metadata,
1031-
),
1016+
return tuple(zip(base_parameters, type_arguments))
1017+
1018+
1019+
def _dataclass_type_hints(
1020+
clazz: type,
1021+
schema_ctx: _SchemaContext = None,
1022+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
1023+
) -> Mapping[str, type]:
1024+
if not _is_generic_alias_of_dataclass(clazz):
1025+
return get_type_hints(
1026+
clazz, globalns=schema_ctx.globalns, localns=schema_ctx.localns
10321027
)
1033-
for f in dataclasses.fields(base_dataclass)
1034-
]
1035-
non_generic_dataclass = dataclasses.make_dataclass(
1036-
cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields
1028+
# dataclass is generic
1029+
generic_type_hints = get_type_hints(
1030+
typing_inspect.get_origin(clazz),
1031+
globalns=schema_ctx.globalns,
1032+
localns=schema_ctx.localns,
10371033
)
1038-
return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass)
1034+
generic_params_map = dict(generic_params_to_args if generic_params_to_args else {})
1035+
1036+
def _get_hint(_t: type) -> type:
1037+
if isinstance(_t, TypeVar):
1038+
return generic_params_map[_t]
1039+
return _t
1040+
1041+
return {
1042+
field_name: _get_hint(typ) for field_name, typ in generic_type_hints.items()
1043+
}
10391044

10401045

10411046
def NewType(

tests/test_class_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer
1515
from marshmallow.validate import Validator
1616

17-
from marshmallow_dataclass import class_schema, NewType
17+
from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass
1818

1919

2020
class TestClassSchema(unittest.TestCase):

0 commit comments

Comments
 (0)