@@ -44,15 +44,9 @@ class User:
44
44
import warnings
45
45
from enum import Enum
46
46
from functools import lru_cache , partial
47
+ from typing import Any , Callable , Dict , FrozenSet , Generic , List , Mapping
48
+ from typing import NewType as typing_NewType
47
49
from typing import (
48
- Any ,
49
- Callable ,
50
- Dict ,
51
- FrozenSet ,
52
- Generic ,
53
- List ,
54
- Mapping ,
55
- NewType as typing_NewType ,
56
50
Optional ,
57
51
Sequence ,
58
52
Set ,
@@ -150,8 +144,7 @@ def dataclass(
150
144
frozen : bool = False ,
151
145
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
152
146
cls_frame : Optional [types .FrameType ] = None ,
153
- ) -> Type [_U ]:
154
- ...
147
+ ) -> Type [_U ]: ...
155
148
156
149
157
150
@overload
@@ -164,8 +157,7 @@ def dataclass(
164
157
frozen : bool = False ,
165
158
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
166
159
cls_frame : Optional [types .FrameType ] = None ,
167
- ) -> Callable [[Type [_U ]], Type [_U ]]:
168
- ...
160
+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
169
161
170
162
171
163
# _cls should never be specified by keyword, so start it with an
@@ -224,15 +216,13 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
224
216
225
217
226
218
@overload
227
- def add_schema (_cls : Type [_U ]) -> Type [_U ]:
228
- ...
219
+ def add_schema (_cls : Type [_U ]) -> Type [_U ]: ...
229
220
230
221
231
222
@overload
232
223
def add_schema (
233
224
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
234
- ) -> Callable [[Type [_U ]], Type [_U ]]:
235
- ...
225
+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
236
226
237
227
238
228
@overload
@@ -241,8 +231,7 @@ def add_schema(
241
231
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
242
232
cls_frame : Optional [types .FrameType ] = None ,
243
233
stacklevel : int = 1 ,
244
- ) -> Type [_U ]:
245
- ...
234
+ ) -> Type [_U ]: ...
246
235
247
236
248
237
def add_schema (_cls = None , base_schema = None , cls_frame = None , stacklevel = 1 ):
@@ -293,8 +282,7 @@ def class_schema(
293
282
* ,
294
283
globalns : Optional [Dict [str , Any ]] = None ,
295
284
localns : Optional [Dict [str , Any ]] = None ,
296
- ) -> Type [marshmallow .Schema ]:
297
- ...
285
+ ) -> Type [marshmallow .Schema ]: ...
298
286
299
287
300
288
@overload
@@ -304,8 +292,7 @@ def class_schema(
304
292
clazz_frame : Optional [types .FrameType ] = None ,
305
293
* ,
306
294
globalns : Optional [Dict [str , Any ]] = None ,
307
- ) -> Type [marshmallow .Schema ]:
308
- ...
295
+ ) -> Type [marshmallow .Schema ]: ...
309
296
310
297
311
298
def class_schema (
@@ -463,7 +450,7 @@ def class_schema(
463
450
if clazz_frame is not None :
464
451
localns = clazz_frame .f_locals
465
452
with _SchemaContext (globalns , localns ):
466
- return _internal_class_schema (clazz , base_schema )
453
+ return _internal_class_schema (clazz , base_schema , None )
467
454
468
455
469
456
class _SchemaContext :
@@ -509,10 +496,17 @@ def top(self) -> _U:
509
496
_schema_ctx_stack = _LocalStack [_SchemaContext ]()
510
497
511
498
499
+ def _dataclass_fields (clazz : type ) -> Tuple [dataclasses .Field , ...]:
500
+ if _is_generic_alias_of_dataclass (clazz ):
501
+ clazz = typing_inspect .get_origin (clazz )
502
+ return dataclasses .fields (clazz )
503
+
504
+
512
505
@lru_cache (maxsize = MAX_CLASS_SCHEMA_CACHE_SIZE )
513
506
def _internal_class_schema (
514
507
clazz : type ,
515
508
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
509
+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
516
510
) -> Type [marshmallow .Schema ]:
517
511
schema_ctx = _schema_ctx_stack .top
518
512
@@ -525,7 +519,7 @@ def _internal_class_schema(
525
519
schema_ctx .seen_classes [clazz ] = class_name
526
520
527
521
try :
528
- class_name , fields = _dataclass_name_and_fields (clazz )
522
+ fields = _dataclass_fields (clazz )
529
523
except TypeError : # Not a dataclass
530
524
try :
531
525
warnings .warn (
@@ -540,7 +534,9 @@ def _internal_class_schema(
540
534
"****** WARNING ******"
541
535
)
542
536
created_dataclass : type = dataclasses .dataclass (clazz )
543
- return _internal_class_schema (created_dataclass , base_schema )
537
+ return _internal_class_schema (
538
+ created_dataclass , base_schema , generic_params_to_args
539
+ )
544
540
except Exception as exc :
545
541
raise TypeError (
546
542
f"{ getattr (clazz , '__name__' , repr (clazz ))} is not a dataclass and cannot be turned into one."
@@ -556,6 +552,10 @@ def _internal_class_schema(
556
552
# Determine whether we should include non-init fields
557
553
include_non_init = getattr (getattr (clazz , "Meta" , None ), "include_non_init" , False )
558
554
555
+ if _is_generic_alias_of_dataclass (clazz ) and generic_params_to_args is None :
556
+ generic_params_to_args = _generic_params_to_args (clazz )
557
+
558
+ type_hints = _dataclass_type_hints (clazz , schema_ctx , generic_params_to_args )
559
559
# Update the schema members to contain marshmallow fields instead of dataclass fields
560
560
561
561
if sys .version_info >= (3 , 9 ):
@@ -577,13 +577,14 @@ def _internal_class_schema(
577
577
_get_field_default (field ),
578
578
field .metadata ,
579
579
base_schema ,
580
+ generic_params_to_args ,
580
581
),
581
582
)
582
583
for field in fields
583
584
if field .init or include_non_init
584
585
)
585
586
586
- schema_class = type (class_name , (_base_schema (clazz , base_schema ),), attributes )
587
+ schema_class = type (clazz . __name__ , (_base_schema (clazz , base_schema ),), attributes )
587
588
return cast (Type [marshmallow .Schema ], schema_class )
588
589
589
590
@@ -706,7 +707,7 @@ def _field_for_generic_type(
706
707
),
707
708
)
708
709
return tuple_type (children , ** metadata )
709
- elif origin in (dict , Dict , collections .abc .Mapping , Mapping ):
710
+ if origin in (dict , Dict , collections .abc .Mapping , Mapping ):
710
711
dict_type = type_mapping .get (Dict , marshmallow .fields .Dict )
711
712
return dict_type (
712
713
keys = _field_for_schema (arguments [0 ], base_schema = base_schema ),
@@ -794,6 +795,7 @@ def field_for_schema(
794
795
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
795
796
# FIXME: delete typ_frame from API?
796
797
typ_frame : Optional [types .FrameType ] = None ,
798
+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
797
799
) -> marshmallow .fields .Field :
798
800
"""
799
801
Get a marshmallow Field corresponding to the given python type.
@@ -953,7 +955,7 @@ def _field_for_schema(
953
955
nested_schema
954
956
or forward_reference
955
957
or _schema_ctx_stack .top .seen_classes .get (typ )
956
- or _internal_class_schema (typ , base_schema ) # type: ignore[arg-type] # FIXME
958
+ or _internal_class_schema (typ , base_schema , generic_params_to_args ) # type: ignore [arg-type]
957
959
)
958
960
959
961
return marshmallow .fields .Nested (nested , ** metadata )
@@ -1007,35 +1009,38 @@ def _is_generic_alias_of_dataclass(clazz: type) -> bool:
1007
1009
)
1008
1010
1009
1011
1010
- # noinspection PyDataclass
1011
- def _dataclass_name_and_fields (
1012
- clazz : type ,
1013
- ) -> Tuple [str , Tuple [dataclasses .Field , ...]]:
1014
- if not _is_generic_alias_of_dataclass (clazz ):
1015
- return clazz .__name__ , dataclasses .fields (clazz )
1016
-
1012
+ def _generic_params_to_args (clazz : type ) -> Tuple [Tuple [type , type ], ...]:
1017
1013
base_dataclass = typing_inspect .get_origin (clazz )
1018
1014
base_parameters = typing_inspect .get_parameters (base_dataclass )
1019
1015
type_arguments = typing_inspect .get_args (clazz )
1020
- params_to_args = dict (zip (base_parameters , type_arguments ))
1021
- non_generic_fields = [ # swap generic typed fields with types in given type arguments
1022
- (
1023
- f .name ,
1024
- params_to_args .get (f .type , f .type ),
1025
- dataclasses .field (
1026
- default = f .default ,
1027
- # ignoring mypy: https://github.com/python/mypy/issues/6910
1028
- default_factory = f .default_factory , # type: ignore
1029
- init = f .init ,
1030
- metadata = f .metadata ,
1031
- ),
1016
+ return tuple (zip (base_parameters , type_arguments ))
1017
+
1018
+
1019
+ def _dataclass_type_hints (
1020
+ clazz : type ,
1021
+ schema_ctx : _SchemaContext = None ,
1022
+ generic_params_to_args : Optional [Tuple [Tuple [type , type ], ...]] = None ,
1023
+ ) -> Mapping [str , type ]:
1024
+ if not _is_generic_alias_of_dataclass (clazz ):
1025
+ return get_type_hints (
1026
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
1032
1027
)
1033
- for f in dataclasses .fields (base_dataclass )
1034
- ]
1035
- non_generic_dataclass = dataclasses .make_dataclass (
1036
- cls_name = f"{ base_dataclass .__name__ } { type_arguments } " , fields = non_generic_fields
1028
+ # dataclass is generic
1029
+ generic_type_hints = get_type_hints (
1030
+ typing_inspect .get_origin (clazz ),
1031
+ globalns = schema_ctx .globalns ,
1032
+ localns = schema_ctx .localns ,
1037
1033
)
1038
- return base_dataclass .__name__ , dataclasses .fields (non_generic_dataclass )
1034
+ generic_params_map = dict (generic_params_to_args if generic_params_to_args else {})
1035
+
1036
+ def _get_hint (_t : type ) -> type :
1037
+ if isinstance (_t , TypeVar ):
1038
+ return generic_params_map [_t ]
1039
+ return _t
1040
+
1041
+ return {
1042
+ field_name : _get_hint (typ ) for field_name , typ in generic_type_hints .items ()
1043
+ }
1039
1044
1040
1045
1041
1046
def NewType (
0 commit comments