Skip to content

Commit bc46a23

Browse files
onursaticimvanderlee
authored andcommitted
support generic dataclasses
1 parent 4edbfb4 commit bc46a23

File tree

2 files changed

+73
-4
lines changed

2 files changed

+73
-4
lines changed

marshmallow_dataclass/__init__.py

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,9 @@ def class_schema(
453453
>>> class_schema(Custom)().load({})
454454
Custom(name=None)
455455
"""
456-
if not dataclasses.is_dataclass(clazz):
456+
if not dataclasses.is_dataclass(clazz) and not _is_generic_alias_of_dataclass(
457+
clazz
458+
):
457459
clazz = dataclasses.dataclass(clazz)
458460
if localns is None:
459461
if clazz_frame is None:
@@ -523,8 +525,7 @@ def _internal_class_schema(
523525
schema_ctx.seen_classes[clazz] = class_name
524526

525527
try:
526-
# noinspection PyDataclass
527-
fields: Tuple[dataclasses.Field, ...] = dataclasses.fields(clazz)
528+
class_name, fields = _dataclass_name_and_fields(clazz)
528529
except TypeError: # Not a dataclass
529530
try:
530531
warnings.warn(
@@ -582,7 +583,7 @@ def _internal_class_schema(
582583
if field.init or include_non_init
583584
)
584585

585-
schema_class = type(clazz.__name__, (_base_schema(clazz, base_schema),), attributes)
586+
schema_class = type(class_name, (_base_schema(clazz, base_schema),), attributes)
586587
return cast(Type[marshmallow.Schema], schema_class)
587588

588589

@@ -996,6 +997,47 @@ def _get_field_default(field: dataclasses.Field):
996997
return field.default
997998

998999

1000+
def _is_generic_alias_of_dataclass(clazz: type) -> bool:
1001+
"""
1002+
Check if given class is a generic alias of a dataclass, if the dataclass is
1003+
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
1004+
"""
1005+
return typing_inspect.is_generic_type(clazz) and dataclasses.is_dataclass(
1006+
typing_inspect.get_origin(clazz)
1007+
)
1008+
1009+
1010+
# noinspection PyDataclass
1011+
def _dataclass_name_and_fields(
1012+
clazz: type,
1013+
) -> Tuple[str, Tuple[dataclasses.Field, ...]]:
1014+
if not _is_generic_alias_of_dataclass(clazz):
1015+
return clazz.__name__, dataclasses.fields(clazz)
1016+
1017+
base_dataclass = typing_inspect.get_origin(clazz)
1018+
base_parameters = typing_inspect.get_parameters(base_dataclass)
1019+
type_arguments = typing_inspect.get_args(clazz)
1020+
params_to_args = dict(zip(base_parameters, type_arguments))
1021+
non_generic_fields = [ # swap generic typed fields with types in given type arguments
1022+
(
1023+
f.name,
1024+
params_to_args.get(f.type, f.type),
1025+
dataclasses.field(
1026+
default=f.default,
1027+
# ignoring mypy: https://github.com/python/mypy/issues/6910
1028+
default_factory=f.default_factory, # type: ignore
1029+
init=f.init,
1030+
metadata=f.metadata,
1031+
),
1032+
)
1033+
for f in dataclasses.fields(base_dataclass)
1034+
]
1035+
non_generic_dataclass = dataclasses.make_dataclass(
1036+
cls_name=f"{base_dataclass.__name__}{type_arguments}", fields=non_generic_fields
1037+
)
1038+
return base_dataclass.__name__, dataclasses.fields(non_generic_dataclass)
1039+
1040+
9991041
def NewType(
10001042
name: str,
10011043
typ: Type[_U],

tests/test_class_schema.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,33 @@ class Meta:
457457
self.assertNotIn("no_init", class_schema(NoInit)().fields)
458458
self.assertIn("no_init", class_schema(Init)().fields)
459459

460+
def test_generic_dataclass(self):
461+
T = typing.TypeVar("T")
462+
463+
@dataclasses.dataclass
464+
class SimpleGeneric(typing.Generic[T]):
465+
data: T
466+
467+
@dataclasses.dataclass
468+
class Nested:
469+
data: SimpleGeneric[int]
470+
471+
schema_s = class_schema(SimpleGeneric[str])()
472+
self.assertEqual(SimpleGeneric(data="a"), schema_s.load({"data": "a"}))
473+
self.assertEqual(schema_s.dump(SimpleGeneric(data="a")), {"data": "a"})
474+
with self.assertRaises(ValidationError):
475+
schema_s.load({"data": 2})
476+
477+
schema_n = class_schema(Nested)()
478+
self.assertEqual(
479+
Nested(data=SimpleGeneric(1)), schema_n.load({"data": {"data": 1}})
480+
)
481+
self.assertEqual(
482+
schema_n.dump(Nested(data=SimpleGeneric(data=1))), {"data": {"data": 1}}
483+
)
484+
with self.assertRaises(ValidationError):
485+
schema_n.load({"data": {"data": "str"}})
486+
460487

461488
if __name__ == "__main__":
462489
unittest.main()

0 commit comments

Comments
 (0)