@@ -34,6 +34,7 @@ class User:
34
34
})
35
35
Schema: ClassVar[Type[Schema]] = Schema # For the type checker
36
36
"""
37
+
37
38
import collections .abc
38
39
import dataclasses
39
40
import inspect
@@ -61,13 +62,12 @@ class User:
61
62
TypeVar ,
62
63
Union ,
63
64
cast ,
64
- get_args ,
65
- get_origin ,
66
65
get_type_hints ,
67
66
overload ,
68
67
)
69
68
70
69
import marshmallow
70
+ import typing_extensions
71
71
import typing_inspect
72
72
73
73
from marshmallow_dataclass .lazy_class_attribute import lazy_class_attribute
@@ -151,8 +151,7 @@ def dataclass(
151
151
frozen : bool = False ,
152
152
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
153
153
cls_frame : Optional [types .FrameType ] = None ,
154
- ) -> Type [_U ]:
155
- ...
154
+ ) -> Type [_U ]: ...
156
155
157
156
158
157
@overload
@@ -165,8 +164,7 @@ def dataclass(
165
164
frozen : bool = False ,
166
165
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
167
166
cls_frame : Optional [types .FrameType ] = None ,
168
- ) -> Callable [[Type [_U ]], Type [_U ]]:
169
- ...
167
+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
170
168
171
169
172
170
# _cls should never be specified by keyword, so start it with an
@@ -225,15 +223,13 @@ def decorator(cls: Type[_U], stacklevel: int = 1) -> Type[_U]:
225
223
226
224
227
225
@overload
228
- def add_schema (_cls : Type [_U ]) -> Type [_U ]:
229
- ...
226
+ def add_schema (_cls : Type [_U ]) -> Type [_U ]: ...
230
227
231
228
232
229
@overload
233
230
def add_schema (
234
231
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
235
- ) -> Callable [[Type [_U ]], Type [_U ]]:
236
- ...
232
+ ) -> Callable [[Type [_U ]], Type [_U ]]: ...
237
233
238
234
239
235
@overload
@@ -242,8 +238,7 @@ def add_schema(
242
238
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
243
239
cls_frame : Optional [types .FrameType ] = None ,
244
240
stacklevel : int = 1 ,
245
- ) -> Type [_U ]:
246
- ...
241
+ ) -> Type [_U ]: ...
247
242
248
243
249
244
def add_schema (_cls = None , base_schema = None , cls_frame = None , stacklevel = 1 ):
@@ -294,8 +289,7 @@ def class_schema(
294
289
* ,
295
290
globalns : Optional [Dict [str , Any ]] = None ,
296
291
localns : Optional [Dict [str , Any ]] = None ,
297
- ) -> Type [marshmallow .Schema ]:
298
- ...
292
+ ) -> Type [marshmallow .Schema ]: ...
299
293
300
294
301
295
@overload
@@ -305,8 +299,7 @@ def class_schema(
305
299
clazz_frame : Optional [types .FrameType ] = None ,
306
300
* ,
307
301
globalns : Optional [Dict [str , Any ]] = None ,
308
- ) -> Type [marshmallow .Schema ]:
309
- ...
302
+ ) -> Type [marshmallow .Schema ]: ...
310
303
311
304
312
305
def class_schema (
@@ -514,7 +507,15 @@ def _internal_class_schema(
514
507
base_schema : Optional [Type [marshmallow .Schema ]] = None ,
515
508
) -> Type [marshmallow .Schema ]:
516
509
schema_ctx = _schema_ctx_stack .top
517
- schema_ctx .seen_classes [clazz ] = clazz .__name__
510
+
511
+ if typing_extensions .get_origin (clazz ) is Annotated and sys .version_info < (3 , 10 ):
512
+ # https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
513
+ class_name = clazz ._name or clazz .__origin__ .__name__ # type: ignore[attr-defined]
514
+ else :
515
+ class_name = clazz .__name__
516
+
517
+ schema_ctx .seen_classes [clazz ] = class_name
518
+
518
519
try :
519
520
# noinspection PyDataclass
520
521
fields : Tuple [dataclasses .Field , ...] = dataclasses .fields (clazz )
@@ -549,9 +550,18 @@ def _internal_class_schema(
549
550
include_non_init = getattr (getattr (clazz , "Meta" , None ), "include_non_init" , False )
550
551
551
552
# Update the schema members to contain marshmallow fields instead of dataclass fields
552
- type_hints = get_type_hints (
553
- clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns , include_extras = True ,
554
- )
553
+
554
+ if sys .version_info >= (3 , 9 ):
555
+ type_hints = get_type_hints (
556
+ clazz ,
557
+ globalns = schema_ctx .globalns ,
558
+ localns = schema_ctx .localns ,
559
+ include_extras = True ,
560
+ )
561
+ else :
562
+ type_hints = get_type_hints (
563
+ clazz , globalns = schema_ctx .globalns , localns = schema_ctx .localns
564
+ )
555
565
attributes .update (
556
566
(
557
567
field .name ,
@@ -642,8 +652,8 @@ def _field_for_generic_type(
642
652
"""
643
653
If the type is a generic interface, resolve the arguments and construct the appropriate Field.
644
654
"""
645
- origin = get_origin (typ )
646
- arguments = get_args (typ )
655
+ origin = typing_extensions . get_origin (typ )
656
+ arguments = typing_extensions . get_args (typ )
647
657
if origin :
648
658
# Override base_schema.TYPE_MAPPING to change the class used for generic types below
649
659
type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
@@ -889,7 +899,7 @@ def _field_for_schema(
889
899
)
890
900
891
901
# enumerations
892
- if issubclass (typ , Enum ):
902
+ if inspect . isclass ( typ ) and issubclass (typ , Enum ):
893
903
return marshmallow .fields .Enum (typ , ** metadata )
894
904
895
905
# Nested marshmallow dataclass
0 commit comments