Skip to content

Commit b0ce65b

Browse files
onursaticimvanderlee
authored andcommitted
support generic dataclasses
1 parent d6396c1 commit b0ce65b

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,9 @@ def class_schema(
374374
>>> class_schema(Custom)().load({})
375375
Custom(name=None)
376376
"""
377-
if not dataclasses.is_dataclass(clazz):
377+
if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass(
378+
clazz
379+
):
378380
clazz = dataclasses.dataclass(clazz)
379381
if not clazz_frame:
380382
current_frame = inspect.currentframe()
@@ -397,8 +399,7 @@ def _internal_class_schema(
397399
) -> Type[marshmallow.Schema]:
398400
_RECURSION_GUARD.seen_classes[clazz] = clazz.__name__
399401
try:
400-
# noinspection PyDataclass
401-
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
402+
class_name, fields = _dataclass_name_and_fields(clazz)
402403
except TypeError: # Not a dataclass
403404
try:
404405
warnings.warn(
@@ -448,7 +449,7 @@ def _internal_class_schema(
448449
if field.init or include_non_init
449450
)
450451

451-
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
452+
schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes)
452453
return cast(Type[marshmallow.Schema], schema_class)
453454

454455

@@ -812,6 +813,47 @@ def _get_field_default(field: dataclasses.Field):
812813
return field.default
813814

814815

816+
def _is_generic_alias_of_dataclass(clazz: type) -> bool:
817+
"""
818+
Check if given class is a generic alias of a dataclass, if the dataclass is
819+
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
820+
"""
821+
return typing_inspect.is_generic_type(clazz) and dataclasses.is_dataclass(
822+
typing_inspect.get_origin(clazz)
823+
)
824+
825+
826+
# noinspection PyDataclass
827+
def _dataclass_name_and_fields(
828+
clazz: type,
829+
) -> Tuple[str, Tuple[dataclasses.Field, ...]]:
830+
if not _is_generic_alias_of_dataclass(clazz):
831+
return clazz.__name__, dataclasses.fields(clazz)
832+
833+
base_dataclass = typing_inspect.get_origin(clazz)
834+
base_parameters = typing_inspect.get_parameters(base_dataclass)
835+
type_arguments = typing_inspect.get_args(clazz)
836+
params_to_args = dict(zip(base_parameters, type_arguments))
837+
non_generic_fields = [ # swap generic typed fields with types in given type arguments
838+
(
839+
f.name,
840+
params_to_args.get(f.type, f.type),
841+
dataclasses.field(
842+
default=f.default,
843+
# ignoring mypy: https://github.com/python/mypy/issues/6910
844+
default_factory=f.default_factory, # type: ignore
845+
init=f.init,
846+
metadata=f.metadata,
847+
),
848+
)
849+
for f in dataclasses.fields(base_dataclass)
850+
]
851+
non_generic_dataclass = dataclasses.make_dataclass(
852+
cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields
853+
)
854+
return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass)
855+
856+
815857
def NewType(
816858
name: str,
817859
typ: Type[_U],

tests/test_class_schema.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,33 @@ class Meta:
457457
self.assertNotIn("no_init", class_schema(NoInit)().fields)
458458
self.assertIn("no_init", class_schema(Init)().fields)
459459

460+
def test_generic_dataclass(self):
461+
T = typing.TypeVar("T")
462+
463+
@dataclasses.dataclass
464+
class SimpleGeneric(typing.Generic[T]):
465+
data: T
466+
467+
@dataclasses.dataclass
468+
class Nested:
469+
data: SimpleGeneric[int]
470+
471+
schema_s = class_schema(SimpleGeneric[str])()
472+
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
473+
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
474+
with self.assertRaises(ValidationError):
475+
schema_s.load({"data": 2})
476+
477+
schema_n = class_schema(Nested)()
478+
self.assertEqual(
479+
Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}})
480+
)
481+
self.assertEqual(
482+
schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}}
483+
)
484+
with self.assertRaises(ValidationError):
485+
schema_n.load({"data": {"data": "str"}})
486+
460487

461488
if __name__ == "__main__":
462489
unittest.main()

0 commit comments

Comments
 (0)