@@ -34,6 +34,7 @@ class User:
34
34
})
35
35
Schema: ClassVar[Type[Schema]] = Schema # For the type checker
36
36
"""
37
+
37
38
import collections .abc
38
39
import dataclasses
39
40
import inspect
@@ -43,14 +44,11 @@ class User:
43
44
import warnings
44
45
from enum import Enum
45
46
from functools import lru_cache , partial
47
+ from typing import Any , Callable , Dict , FrozenSet , List , Mapping
48
+ from typing import NewType as typing_NewType
46
49
from typing import (
47
- Any ,
48
- Callable ,
49
- Dict ,
50
- List ,
51
- Mapping ,
52
- NewType as typing_NewType ,
53
50
Optional ,
51
+ Sequence ,
54
52
Set ,
55
53
Tuple ,
56
54
Type ,
@@ -59,16 +57,13 @@ class User:
59
57
cast ,
60
58
get_type_hints ,
61
59
overload ,
62
- Sequence ,
63
- FrozenSet ,
64
60
)
65
61
66
62
import marshmallow
67
63
import typing_inspect
68
64
69
65
from marshmallow_dataclass .lazy_class_attribute import lazy_class_attribute
70
66
71
-
72
67
if sys .version_info >= (3 , 11 ):
73
68
from typing import dataclass_transform
74
69
elif sys .version_info >= (3 , 7 ):
@@ -105,8 +100,7 @@ def dataclass(
105
100
frozen : bool = False ,
106
101
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
107
102
cls_frame : Optional [types .FrameType ] = None ,
108
- ) -> Type [_U ]:
109
- ...
103
+ ) -> Type [_U ]: ...
110
104
111
105
112
106
@overload
@@ -119,8 +113,7 @@ def dataclass(
119
113
frozen : bool = False ,
120
114
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
121
115
cls_frame : Optional [types .FrameType ] = None ,
122
- ) -> Callable [[Type [_U ]], Type [_U ]]:
123
- ...
116
+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
124
117
125
118
126
119
# _cls should never be specified by keyword, so start it with an
@@ -179,24 +172,21 @@ def dataclass(
179
172
180
173
181
174
@overload
182
- def add_schema (_cls : Type [_U ]) -> Type [_U ]:
183
- ...
175
+ def add_schema (_cls : Type [_U ]) -> Type [_U ]: ...
184
176
185
177
186
178
@overload
187
179
def add_schema (
188
180
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
189
- ) -> Callable [[Type [_U ]], Type [_U ]]:
190
- ...
181
+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
191
182
192
183
193
184
@overload
194
185
def add_schema (
195
186
_cls : Type [_U ],
196
187
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
197
188
cls_frame : Optional [types .FrameType ] = None ,
198
- ) -> Type [_U ]:
199
- ...
189
+ ) -> Type [_U ]: ...
200
190
201
191
202
192
def add_schema (_cls = None , base_schema = None , cls_frame = None ):
@@ -386,20 +376,27 @@ def class_schema(
386
376
del current_frame
387
377
_RECURSION_GUARD .seen_classes = {}
388
378
try :
389
- return _internal_class_schema (clazz , base_schema , clazz_frame )
379
+ return _internal_class_schema (clazz , base_schema , clazz_frame , None )
390
380
finally :
391
381
_RECURSION_GUARD .seen_classes .clear ()
392
382
393
383
384
+ def _dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
385
+ if _is_generic_alias_of_dataclass (clazz ):
386
+ clazz = typing_inspect .get_origin (clazz )
387
+ return dataclasses .fields (clazz )
388
+
389
+
394
390
@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
395
391
def _internal_class_schema (
396
392
clazz : type ,
397
393
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
398
394
clazz_frame : Optional [types .FrameType ] = None ,
395
+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
399
396
) -> Type [marshmallow .Schema ]:
400
397
_RECURSION_GUARD .seen_classes [clazz ] = clazz .__name__
401
398
try :
402
- class_name , fields = _dataclass_name_and_fields (clazz )
399
+ fields = _dataclass_fields (clazz )
403
400
except TypeError : # Not a dataclass
404
401
try :
405
402
warnings .warn (
@@ -414,7 +411,9 @@ def _internal_class_schema(
414
411
"****** WARNING ******"
415
412
)
416
413
created_dataclass : type = dataclasses .dataclass (clazz )
417
- return _internal_class_schema (created_dataclass , base_schema , clazz_frame )
414
+ return _internal_class_schema (
415
+ created_dataclass , base_schema , clazz_frame , generic_params_to_args
416
+ )
418
417
except Exception as exc :
419
418
raise TypeError (
420
419
f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -430,10 +429,11 @@ def _internal_class_schema(
430
429
# Determine whether we should include non-init fields
431
430
include_non_init = getattr (getattr (clazz , "Meta" , None ), "include_non_init" , False )
432
431
432
+ if _is_generic_alias_of_dataclass (clazz ) and generic_params_to_args is None :
433
+ generic_params_to_args = _generic_params_to_args (clazz )
434
+
435
+ type_hints = _dataclass_type_hints (clazz , clazz_frame , generic_params_to_args )
433
436
# Update the schema members to contain marshmallow fields instead of dataclass fields
434
- type_hints = get_type_hints (
435
- clazz , localns = clazz_frame .f_locals if clazz_frame else None
436
- )
437
437
attributes .update (
438
438
(
439
439
field .name ,
@@ -443,13 +443,14 @@ def _internal_class_schema(
443
443
field .metadata ,
444
444
base_schema ,
445
445
clazz_frame ,
446
+ generic_params_to_args ,
446
447
),
447
448
)
448
449
for field in fields
449
450
if field .init or include_non_init
450
451
)
451
452
452
- schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
453
+ schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
453
454
return cast (Type [marshmallow .Schema ], schema_class )
454
455
455
456
@@ -584,7 +585,7 @@ def _field_for_generic_type(
584
585
),
585
586
)
586
587
return tuple_type (children , ** metadata )
587
- elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
588
+ if origin in (dict , Dict , collections .abc .Mapping , Mapping ):
588
589
dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
589
590
return dict_type (
590
591
keys = field_for_schema (
@@ -636,6 +637,7 @@ def field_for_schema(
636
637
metadata : Optional [Mapping [str , Any ]] = None ,
637
638
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
638
639
typ_frame : Optional [types .FrameType ] = None ,
640
+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
639
641
) -> marshmallow .fields .Field :
640
642
"""
641
643
Get a marshmallow Field corresponding to the given python type.
@@ -769,7 +771,7 @@ def field_for_schema(
769
771
nested_schema
770
772
or forward_reference
771
773
or _RECURSION_GUARD .seen_classes .get (typ )
772
- or _internal_class_schema (typ , base_schema , typ_frame ) # type: ignore [arg-type]
774
+ or _internal_class_schema (typ , base_schema , typ_frame , generic_params_to_args ) # type: ignore [arg-type]
773
775
)
774
776
775
777
return marshmallow .fields .Nested (nested , ** metadata )
@@ -823,35 +825,33 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
823
825
)
824
826
825
827
826
- # noinspection PyDataclass
827
- def _dataclass_name_and_fields (
828
- clazz : type ,
829
- ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
830
- if not _is_generic_alias_of_dataclass (clazz ):
831
- return clazz .__name__ , dataclasses .fields (clazz )
832
-
828
+ def _generic_params_to_args (clazz : type ) -> Tuple [Tuple [type , type ], ...]:
833
829
base_dataclass = typing_inspect .get_origin (clazz )
834
830
base_parameters = typing_inspect .get_parameters (base_dataclass )
835
831
type_arguments = typing_inspect .get_args (clazz )
836
- params_to_args = dict (zip (base_parameters , type_arguments ))
837
- non_generic_fields = [ # swap generic typed fields with types in given type arguments
838
- (
839
- f .name ,
840
- params_to_args .get (f .type , f .type ),
841
- dataclasses .field (
842
- default = f .default ,
843
- # ignoring mypy: https://github.com/python/mypy/issues/6910
844
- default_factory = f .default_factory , # type: ignore
845
- init = f .init ,
846
- metadata = f .metadata ,
847
- ),
848
- )
849
- for f in dataclasses .fields (base_dataclass )
850
- ]
851
- non_generic_dataclass = dataclasses .make_dataclass (
852
- cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
853
- )
854
- return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
832
+ return tuple (zip (base_parameters , type_arguments ))
833
+
834
+
835
+ def _dataclass_type_hints (
836
+ clazz : type ,
837
+ clazz_frame : types .FrameType = None ,
838
+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
839
+ ) -> Mapping [str , type ]:
840
+ localns = clazz_frame .f_locals if clazz_frame else None
841
+ if not _is_generic_alias_of_dataclass (clazz ):
842
+ return get_type_hints (clazz , localns = localns )
843
+ # dataclass is generic
844
+ generic_type_hints = get_type_hints (typing_inspect .get_origin (clazz ), localns )
845
+ generic_params_map = dict (generic_params_to_args if generic_params_to_args else {})
846
+
847
+ def _get_hint (_t : type ) -> type :
848
+ if isinstance (_t , TypeVar ):
849
+ return generic_params_map [_t ]
850
+ return _t
851
+
852
+ return {
853
+ field_name : _get_hint (typ ) for field_name , typ in generic_type_hints .items ()
854
+ }
855
855
856
856
857
857
def NewType (
0 commit comments