Skip to content

Commit f3da098

Browse files
onursaticimvanderlee
authored andcommitted
support nested generic dataclasses
1 parent b0ce65b commit f3da098

File tree

2 files changed

+55
-55
lines changed

2 files changed

+55
-55
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 54 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class User:
3434
})
3535
Schema: ClassVar[Type[Schema]] = Schema # For the type checker
3636
"""
37+
3738
import collections.abc
3839
import dataclasses
3940
import inspect
@@ -43,14 +44,11 @@ class User:
4344
import warnings
4445
from enum import Enum
4546
from functools import lru_cache, partial
47+
from typing import Any, Callable, Dict, FrozenSet, List, Mapping
48+
from typing import NewType as typing_NewType
4649
from typing import (
47-
Any,
48-
Callable,
49-
Dict,
50-
List,
51-
Mapping,
52-
NewType as typing_NewType,
5350
Optional,
51+
Sequence,
5452
Set,
5553
Tuple,
5654
Type,
@@ -59,16 +57,13 @@ class User:
5957
cast,
6058
get_type_hints,
6159
overload,
62-
Sequence,
63-
FrozenSet,
6460
)
6561

6662
import marshmallow
6763
import typing_inspect
6864

6965
from marshmallow_dataclass.lazy_class_attribute import lazy_class_attribute
7066

71-
7267
if sys.version_info >= (3, 11):
7368
from typing import dataclass_transform
7469
elif sys.version_info >= (3, 7):
@@ -105,8 +100,7 @@ def dataclass(
105100
frozen: bool = False,
106101
base_schema: Optional[Type[marshmallow.Schema]] = None,
107102
cls_frame: Optional[types.FrameType] = None,
108-
) -> Type[_U]:
109-
...
103+
) -> Type[_U]: ...
110104

111105

112106
@overload
@@ -119,8 +113,7 @@ def dataclass(
119113
frozen: bool = False,
120114
base_schema: Optional[Type[marshmallow.Schema]] = None,
121115
cls_frame: Optional[types.FrameType] = None,
122-
) -> Callable[[Type[_U]], Type[_U]]:
123-
...
116+
) -> Callable[[Type[_U]], Type[_U]]: ...
124117

125118

126119
# _cls should never be specified by keyword, so start it with an
@@ -179,24 +172,21 @@ def dataclass(
179172

180173

181174
@overload
182-
def add_schema(_cls: Type[_U]) -> Type[_U]:
183-
...
175+
def add_schema(_cls: Type[_U]) -> Type[_U]: ...
184176

185177

186178
@overload
187179
def add_schema(
188180
base_schema: Optional[Type[marshmallow.Schema]] = None,
189-
) -> Callable[[Type[_U]], Type[_U]]:
190-
...
181+
) -> Callable[[Type[_U]], Type[_U]]: ...
191182

192183

193184
@overload
194185
def add_schema(
195186
_cls: Type[_U],
196187
base_schema: Optional[Type[marshmallow.Schema]] = None,
197188
cls_frame: Optional[types.FrameType] = None,
198-
) -> Type[_U]:
199-
...
189+
) -> Type[_U]: ...
200190

201191

202192
def add_schema(_cls=None, base_schema=None, cls_frame=None):
@@ -386,20 +376,27 @@ def class_schema(
386376
del current_frame
387377
_RECURSION_GUARD.seen_classes = {}
388378
try:
389-
return _internal_class_schema(clazz, base_schema, clazz_frame)
379+
return _internal_class_schema(clazz, base_schema, clazz_frame, None)
390380
finally:
391381
_RECURSION_GUARD.seen_classes.clear()
392382

393383

384+
def _dataclass_fields(clazz: type) -> Tuple[dataclasses.Field, ...]:
385+
if _is_generic_alias_of_dataclass(clazz):
386+
clazz = typing_inspect.get_origin(clazz)
387+
return dataclasses.fields(clazz)
388+
389+
394390
@lru_cache(maxsize=MAX_CLASS_SCHEMA_CACHE_SIZE)
395391
def _internal_class_schema(
396392
clazz: type,
397393
base_schema: Optional[Type[marshmallow.Schema]] = None,
398394
clazz_frame: Optional[types.FrameType] = None,
395+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
399396
) -> Type[marshmallow.Schema]:
400397
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
401398
try:
402-
class_name, fields = _dataclass_name_and_fields(clazz)
399+
fields = _dataclass_fields(clazz)
403400
except TypeError: # Not a dataclass
404401
try:
405402
warnings.warn(
@@ -414,7 +411,9 @@ def _internal_class_schema(
414411
"****** WARNING ******"
415412
)
416413
created_dataclass: type = dataclasses.dataclass(clazz)
417-
return _internal_class_schema(created_dataclass, base_schema, clazz_frame)
414+
return _internal_class_schema(
415+
created_dataclass, base_schema, clazz_frame, generic_params_to_args
416+
)
418417
except Exception as exc:
419418
raise TypeError(
420419
f"{getattr(clazz, '__name__', repr(clazz))} is not a dataclass and cannot be turned into one."
@@ -430,10 +429,11 @@ def _internal_class_schema(
430429
# Determine whether we should include non-init fields
431430
include_non_init = getattr(getattr(clazz, "Meta", None), "include_non_init", False)
432431

432+
if _is_generic_alias_of_dataclass(clazz) and generic_params_to_args is None:
433+
generic_params_to_args = _generic_params_to_args(clazz)
434+
435+
type_hints = _dataclass_type_hints(clazz, clazz_frame, generic_params_to_args)
433436
# Update the schema members to contain marshmallow fields instead of dataclass fields
434-
type_hints = get_type_hints(
435-
clazz, localns=clazz_frame.f_locals if clazz_frame else None
436-
)
437437
attributes.update(
438438
(
439439
field.name,
@@ -443,13 +443,14 @@ def _internal_class_schema(
443443
field.metadata,
444444
base_schema,
445445
clazz_frame,
446+
generic_params_to_args,
446447
),
447448
)
448449
for field in fields
449450
if field.init or include_non_init
450451
)
451452

452-
schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes)
453+
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
453454
return cast(Type[marshmallow.Schema], schema_class)
454455

455456

@@ -584,7 +585,7 @@ def _field_for_generic_type(
584585
),
585586
)
586587
return tuple_type(children, **metadata)
587-
elif origin in (dict, Dict, collections.abc.Mapping, Mapping):
588+
if origin in (dict, Dict, collections.abc.Mapping, Mapping):
588589
dict_type = type_mapping.get(Dict, marshmallow.fields.Dict)
589590
return dict_type(
590591
keys=field_for_schema(
@@ -636,6 +637,7 @@ def field_for_schema(
636637
metadata: Optional[Mapping[str, Any]] = None,
637638
base_schema: Optional[Type[marshmallow.Schema]] = None,
638639
typ_frame: Optional[types.FrameType] = None,
640+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
639641
) -> marshmallow.fields.Field:
640642
"""
641643
Get a marshmallow Field corresponding to the given python type.
@@ -769,7 +771,7 @@ def field_for_schema(
769771
nested_schema
770772
or forward_reference
771773
or _RECURSION_GUARD.seen_classes.get(typ)
772-
or _internal_class_schema(typ, base_schema, typ_frame) # type: ignore [arg-type]
774+
or _internal_class_schema(typ, base_schema, typ_frame, generic_params_to_args) # type: ignore [arg-type]
773775
)
774776

775777
return marshmallow.fields.Nested(nested, **metadata)
@@ -823,35 +825,33 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
823825
)
824826

825827

826-
# noinspection PyDataclass
827-
def _dataclass_name_and_fields(
828-
clazz: type,
829-
) -> Tuple[str, Tuple[dataclasses.Field, ...]]:
830-
if not _is_generic_alias_of_dataclass(clazz):
831-
return clazz.__name__, dataclasses.fields(clazz)
832-
828+
def _generic_params_to_args(clazz: type) -> Tuple[Tuple[type, type], ...]:
833829
base_dataclass = typing_inspect.get_origin(clazz)
834830
base_parameters = typing_inspect.get_parameters(base_dataclass)
835831
type_arguments = typing_inspect.get_args(clazz)
836-
params_to_args = dict(zip(base_parameters, type_arguments))
837-
non_generic_fields = [ # swap generic typed fields with types in given type arguments
838-
(
839-
f.name,
840-
params_to_args.get(f.type, f.type),
841-
dataclasses.field(
842-
default=f.default,
843-
# ignoring mypy: https://github.com/python/mypy/issues/6910
844-
default_factory=f.default_factory, # type: ignore
845-
init=f.init,
846-
metadata=f.metadata,
847-
),
848-
)
849-
for f in dataclasses.fields(base_dataclass)
850-
]
851-
non_generic_dataclass = dataclasses.make_dataclass(
852-
cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields
853-
)
854-
return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass)
832+
return tuple(zip(base_parameters, type_arguments))
833+
834+
835+
def _dataclass_type_hints(
836+
clazz: type,
837+
clazz_frame: types.FrameType = None,
838+
generic_params_to_args: Optional[Tuple[Tuple[type, type], ...]] = None,
839+
) -> Mapping[str, type]:
840+
localns = clazz_frame.f_locals if clazz_frame else None
841+
if not _is_generic_alias_of_dataclass(clazz):
842+
return get_type_hints(clazz, localns=localns)
843+
# dataclass is generic
844+
generic_type_hints = get_type_hints(typing_inspect.get_origin(clazz), localns)
845+
generic_params_map = dict(generic_params_to_args if generic_params_to_args else {})
846+
847+
def _get_hint(_t: type) -> type:
848+
if isinstance(_t, TypeVar):
849+
return generic_params_map[_t]
850+
return _t
851+
852+
return {
853+
field_name: _get_hint(typ) for field_name, typ in generic_type_hints.items()
854+
}
855855

856856

857857
def NewType(

tests/test_class_schema.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from marshmallow.fields import Field, UUID as UUIDField, List as ListField, Integer
1515
from marshmallow.validate import Validator
1616

17-
from marshmallow_dataclass import class_schema, NewType
17+
from marshmallow_dataclass import class_schema, NewType, _is_generic_alias_of_dataclass
1818

1919

2020
class TestClassSchema(unittest.TestCase):

0 commit comments

Comments
 (0)