@@ -453,7 +453,9 @@ def class_schema(
453
453
>>> class_schema(Custom)().load({})
454
454
Custom(name=None)
455
455
"""
456
- if not dataclasses .is_dataclass (clazz ):
456
+ if not dataclasses .is_dataclass (clazz ) and not _is_generic_alias_of_dataclass (
457
+ clazz
458
+ ):
457
459
clazz = dataclasses .dataclass (clazz )
458
460
if localns is None :
459
461
if clazz_frame is None :
@@ -523,8 +525,7 @@ def _internal_class_schema(
523
525
schema_ctx .seen_classes [clazz ] = class_name
524
526
525
527
try :
526
- # noinspection PyDataclass
527
- fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
528
+ class_name , fields = _dataclass_name_and_fields (clazz )
528
529
except TypeError : # Not a dataclass
529
530
try :
530
531
warnings .warn (
@@ -582,7 +583,7 @@ def _internal_class_schema(
582
583
if field .init or include_non_init
583
584
)
584
585
585
- schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
586
+ schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
586
587
return cast (Type [marshmallow .Schema ], schema_class )
587
588
588
589
@@ -996,6 +997,47 @@ def _get_field_default(field: dataclasses.Field):
996
997
return field .default
997
998
998
999
1000
+ def _is_generic_alias_of_dataclass (clazz : type ) -> bool :
1001
+ """
1002
+ Check if given class is a generic alias of a dataclass, if the dataclass is
1003
+ defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
1004
+ """
1005
+ return typing_inspect .is_generic_type (clazz ) and dataclasses .is_dataclass (
1006
+ typing_inspect .get_origin (clazz )
1007
+ )
1008
+
1009
+
1010
+ # noinspection PyDataclass
1011
+ def _dataclass_name_and_fields (
1012
+ clazz : type ,
1013
+ ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
1014
+ if not _is_generic_alias_of_dataclass (clazz ):
1015
+ return clazz .__name__ , dataclasses .fields (clazz )
1016
+
1017
+ base_dataclass = typing_inspect .get_origin (clazz )
1018
+ base_parameters = typing_inspect .get_parameters (base_dataclass )
1019
+ type_arguments = typing_inspect .get_args (clazz )
1020
+ params_to_args = dict (zip (base_parameters , type_arguments ))
1021
+ non_generic_fields = [ # swap generic typed fields with types in given type arguments
1022
+ (
1023
+ f .name ,
1024
+ params_to_args .get (f .type , f .type ),
1025
+ dataclasses .field (
1026
+ default = f .default ,
1027
+ # ignoring mypy: https://github.com/python/mypy/issues/6910
1028
+ default_factory = f .default_factory , # type: ignore
1029
+ init = f .init ,
1030
+ metadata = f .metadata ,
1031
+ ),
1032
+ )
1033
+ for f in dataclasses .fields (base_dataclass )
1034
+ ]
1035
+ non_generic_dataclass = dataclasses .make_dataclass (
1036
+ cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
1037
+ )
1038
+ return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
1039
+
1040
+
999
1041
def NewType (
1000
1042
name : str ,
1001
1043
typ : Type [_U ],
0 commit comments