@@ -66,7 +66,6 @@ class User:
66
66
)
67
67
68
68
import marshmallow
69
- import typing_extensions
70
69
import typing_inspect
71
70
72
71
from marshmallow_dataclass .generic_resolver import (
@@ -78,9 +77,9 @@ class User:
78
77
from marshmallow_dataclass .lazy_class_attribute import lazy_class_attribute
79
78
80
79
if sys .version_info >= (3 , 9 ):
81
- from typing import Annotated
80
+ from typing import Annotated , get_args , get_origin
82
81
else :
83
- from typing_extensions import Annotated
82
+ from typing_extensions import Annotated , get_args , get_origin
84
83
85
84
if sys .version_info >= (3 , 11 ):
86
85
from typing import dataclass_transform
@@ -540,7 +539,7 @@ def _internal_class_schema(
540
539
) -> Type [marshmallow .Schema ]:
541
540
schema_ctx = _schema_ctx_stack .top
542
541
543
- if typing_extensions . get_origin (clazz ) is Annotated and sys .version_info < (3 , 10 ):
542
+ if get_origin (clazz ) is Annotated and sys .version_info < (3 , 10 ):
544
543
# https://github.com/python/cpython/blob/3.10/Lib/typing.py#L977
545
544
class_name = clazz ._name or clazz .__origin__ .__name__ # type: ignore[attr-defined]
546
545
else :
@@ -597,7 +596,7 @@ def _internal_class_schema(
597
596
(
598
597
type_hints [field .name ]
599
598
if not is_generic_type (clazz )
600
- else _get_type_hint_of_generic_object (field .type , schema_ctx )
599
+ else _resolve_forward_type_refs (field .type , schema_ctx )
601
600
),
602
601
_get_field_default (field ),
603
602
field .metadata ,
@@ -659,8 +658,8 @@ def _field_by_supertype(
659
658
)
660
659
661
660
662
- def _generic_type_add_any (typ : type ) -> type :
663
- """if typ is generic type without arguments, replace them by Any."""
661
+ def _container_type_add_any (typ : type ) -> type :
662
+ """if typ is container type without arguments, replace them by Any."""
664
663
if typ is list or typ is List :
665
664
typ = List [Any ]
666
665
elif typ is dict or typ is Dict :
@@ -676,18 +675,20 @@ def _generic_type_add_any(typ: type) -> type:
676
675
return typ
677
676
678
677
679
- def _field_for_generic_type (
678
+ def _field_for_container_type (
680
679
typ : type ,
681
680
base_schema : Optional [Type [marshmallow .Schema ]],
682
681
** metadata : Any ,
683
682
) -> Optional [marshmallow .fields .Field ]:
684
683
"""
685
- If the type is a generic interface, resolve the arguments and construct the appropriate Field.
684
+ If the type is a container interface, resolve the arguments and construct the appropriate Field.
685
+
686
+ We use the term 'container' to differentiate from the Generic support
686
687
"""
687
- origin = typing_extensions . get_origin (typ )
688
- arguments = typing_extensions . get_args (typ )
688
+ origin = get_origin (typ )
689
+ arguments = get_args (typ )
689
690
if origin :
690
- # Override base_schema.TYPE_MAPPING to change the class used for generic types below
691
+ # Override base_schema.TYPE_MAPPING to change the class used for container types below
691
692
type_mapping = base_schema .TYPE_MAPPING if base_schema else {}
692
693
693
694
if origin in (list , List ):
@@ -749,18 +750,15 @@ def _field_for_annotated_type(
749
750
"""
750
751
If the type is an Annotated interface, resolve the arguments and construct the appropriate Field.
751
752
"""
752
- origin = typing_extensions . get_origin (typ )
753
- arguments = typing_extensions . get_args (typ )
753
+ origin = get_origin (typ )
754
+ arguments = get_args (typ )
754
755
if origin and origin is Annotated :
755
756
marshmallow_annotations = [
756
757
arg
757
758
for arg in arguments [1 :]
758
759
if _is_marshmallow_field (arg )
759
760
# Support `CustomGenericField[mf.String]`
760
- or (
761
- is_generic_type (arg )
762
- and _is_marshmallow_field (typing_extensions .get_origin (arg ))
763
- )
761
+ or (is_generic_type (arg ) and _is_marshmallow_field (get_origin (arg )))
764
762
]
765
763
if marshmallow_annotations :
766
764
if len (marshmallow_annotations ) > 1 :
@@ -782,7 +780,7 @@ def _field_for_union_type(
782
780
base_schema : Optional [Type [marshmallow .Schema ]],
783
781
** metadata : Any ,
784
782
) -> Optional [marshmallow .fields .Field ]:
785
- arguments = typing_extensions . get_args (typ )
783
+ arguments = get_args (typ )
786
784
if typing_inspect .is_union_type (typ ):
787
785
if typing_inspect .is_optional_type (typ ):
788
786
metadata ["allow_none" ] = metadata .get ("allow_none" , True )
@@ -886,8 +884,8 @@ def _field_for_schema(
886
884
if predefined_field :
887
885
return predefined_field
888
886
889
- # Generic types specified without type arguments
890
- typ = _generic_type_add_any (typ )
887
+ # Container types (generics like List) specified without type arguments
888
+ typ = _container_type_add_any (typ )
891
889
892
890
# Base types
893
891
field = _field_by_type (typ , base_schema )
@@ -900,7 +898,7 @@ def _field_for_schema(
900
898
901
899
# i.e.: Literal['abc']
902
900
if typing_inspect .is_literal_type (typ ):
903
- arguments = typing_extensions . get_args (typ )
901
+ arguments = get_args (typ )
904
902
return marshmallow .fields .Raw (
905
903
validate = (
906
904
marshmallow .validate .Equal (arguments [0 ])
@@ -912,7 +910,7 @@ def _field_for_schema(
912
910
913
911
# i.e.: Final[str] = 'abc'
914
912
if typing_inspect .is_final_type (typ ):
915
- arguments = typing_extensions . get_args (typ )
913
+ arguments = get_args (typ )
916
914
if arguments :
917
915
subtyp = arguments [0 ]
918
916
elif default is not marshmallow .missing :
@@ -953,10 +951,10 @@ def _field_for_schema(
953
951
if union_field :
954
952
return union_field
955
953
956
- # Generic types
957
- generic_field = _field_for_generic_type (typ , base_schema , ** metadata )
958
- if generic_field :
959
- return generic_field
954
+ # Container types
955
+ container_field = _field_for_container_type (typ , base_schema , ** metadata )
956
+ if container_field :
957
+ return container_field
960
958
961
959
# typing.NewType returns a function (in python <= 3.9) or a class (python >= 3.10) with a
962
960
# __supertype__ attribute
@@ -1034,9 +1032,7 @@ def is_generic_alias_of_dataclass(clazz: type) -> bool:
1034
1032
Check if given class is a generic alias of a dataclass, if the dataclass is
1035
1033
defined as `class A(Generic[T])`, this method will return true if `A[int]` is passed
1036
1034
"""
1037
- return is_generic_alias (clazz ) and dataclasses .is_dataclass (
1038
- typing_extensions .get_origin (clazz )
1039
- )
1035
+ return is_generic_alias (clazz ) and dataclasses .is_dataclass (get_origin (clazz ))
1040
1036
1041
1037
1042
1038
def _get_type_hints (
@@ -1058,11 +1054,13 @@ def _get_type_hints(
1058
1054
return type_hints
1059
1055
1060
1056
1061
- def _get_type_hint_of_generic_object (
1057
+ def _resolve_forward_type_refs (
1062
1058
obj ,
1063
1059
schema_ctx : _SchemaContext ,
1064
1060
) -> type :
1065
- """typing.get_type_hints doesn't work with generic aliases, i.e.: A[int]. But this 'hack' works."""
1061
+ """
1062
+ Resolve forward references, mainly applies to Generics i.e.: `A["int"]` -> `A[int]`
1063
+ """
1066
1064
1067
1065
class X :
1068
1066
x : obj # type: ignore[name-defined]
0 commit comments