@@ -374,7 +374,9 @@ def class_schema(
374
374
>>> class_schema(Custom)().load({})
375
375
Custom(name=None)
376
376
"""
377
- if not dataclasses .is_dataclass (clazz ):
377
+ if not dataclasses .is_dataclass (clazz ) and not _is_generic_alias_of_dataclass (
378
+ clazz
379
+ ):
378
380
clazz = dataclasses .dataclass (clazz )
379
381
if not clazz_frame :
380
382
current_frame = inspect .currentframe ()
@@ -397,8 +399,7 @@ def _internal_class_schema(
397
399
) -> Type [marshmallow .Schema ]:
398
400
_RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
399
401
try :
400
- # noinspection PyDataclass
401
- fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
402
+ class_name , fields = _dataclass_name_and_fields (clazz )
402
403
except TypeError : # Not a dataclass
403
404
try :
404
405
warnings .warn (
@@ -448,7 +449,7 @@ def _internal_class_schema(
448
449
if field .init or include_non_init
449
450
)
450
451
451
- schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
452
+ schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
452
453
return cast (Type [marshmallow .Schema ], schema_class )
453
454
454
455
@@ -812,6 +813,47 @@ def _get_field_default(field: dataclasses.Field):
812
813
return field .default
813
814
814
815
816
+ def _is_generic_alias_of_dataclass (clazz : type ) -> bool :
817
+ """
818
+ Check if given class is a generic alias of a dataclass, if the dataclass is
819
+ defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
820
+ """
821
+ return typing_inspect .is_generic_type (clazz ) and dataclasses .is_dataclass (
822
+ typing_inspect .get_origin (clazz )
823
+ )
824
+
825
+
826
+ # noinspection PyDataclass
827
+ def _dataclass_name_and_fields (
828
+ clazz : type ,
829
+ ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
830
+ if not _is_generic_alias_of_dataclass (clazz ):
831
+ return clazz .__name__ , dataclasses .fields (clazz )
832
+
833
+ base_dataclass = typing_inspect .get_origin (clazz )
834
+ base_parameters = typing_inspect .get_parameters (base_dataclass )
835
+ type_arguments = typing_inspect .get_args (clazz )
836
+ params_to_args = dict (zip (base_parameters , type_arguments ))
837
+ non_generic_fields = [ # swap generic typed fields with types in given type arguments
838
+ (
839
+ f .name ,
840
+ params_to_args .get (f .type , f .type ),
841
+ dataclasses .field (
842
+ default = f .default ,
843
+ # ignoring mypy: https://github.com/python/mypy/issues/6910
844
+ default_factory = f .default_factory , # type: ignore
845
+ init = f .init ,
846
+ metadata = f .metadata ,
847
+ ),
848
+ )
849
+ for f in dataclasses .fields (base_dataclass )
850
+ ]
851
+ non_generic_dataclass = dataclasses .make_dataclass (
852
+ cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
853
+ )
854
+ return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
855
+
856
+
815
857
def NewType (
816
858
name : str ,
817
859
typ : Type [_U ],
0 commit comments